Refactor: move slave pkg inside of cluster

Test: middleware for node communication
This commit is contained in:
HFO4
2021-11-08 19:54:26 +08:00
parent eaa0f6be91
commit e41ec9defa
16 changed files with 135 additions and 43 deletions

View File

@@ -37,6 +37,7 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
c.Abort()
return
}
c.Next()
}
}

View File

@@ -90,15 +90,27 @@ func TestSignRequired(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
SignRequiredFunc := SignRequired(authInstance)
// 鉴权失败
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
// Sign verify success
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
c.Request = auth.SignRequest(authInstance, c.Request, 0)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.False(c.IsAborted())
}
func TestWebDAVAuth(t *testing.T) {
@@ -780,8 +792,6 @@ func TestS3CallbackAuth(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
@@ -789,5 +799,6 @@ func TestS3CallbackAuth(t *testing.T) {
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
AuthFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
}

View File

@@ -3,7 +3,6 @@ package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
"strconv"
)
@@ -19,11 +18,11 @@ func MasterMetadata() gin.HandlerFunc {
}
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
func UseSlaveAria2Instance() gin.HandlerFunc {
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
return func(c *gin.Context) {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
caller, err := clusterController.GetAria2Instance(siteID.(string))
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
c.Abort()
@@ -40,7 +39,7 @@ func UseSlaveAria2Instance() gin.HandlerFunc {
}
}
func SlaveRPCSignRequired() gin.HandlerFunc {
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
if err != nil {
@@ -49,7 +48,7 @@ func SlaveRPCSignRequired() gin.HandlerFunc {
return
}
slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
slaveNode := nodePool.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()

View File

@@ -0,0 +1,80 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestMasterMetadata(t *testing.T) {
a := assert.New(t)
masterMetaDataFunc := MasterMetadata()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header = map[string][]string{
"X-Site-Id": {"expectedSiteID"},
"X-Site-Url": {"expectedSiteURL"},
"X-Cloudreve-Version": {"expectedMasterVersion"},
}
masterMetaDataFunc(c)
siteID, _ := c.Get("MasterSiteID")
siteURL, _ := c.Get("MasterSiteURL")
siteVersion, _ := c.Get("MasterVersion")
a.Equal("expectedSiteID", siteID.(string))
a.Equal("expectedSiteURL", siteURL.(string))
a.Equal("expectedMasterVersion", siteVersion.(string))
}
func TestSlaveRPCSignRequired(t *testing.T) {
a := assert.New(t)
np := &cluster.NodePool{}
np.Init()
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
rec := httptest.NewRecorder()
// id parse failed
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Node-Id", "unknown")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// node id not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Node-Id", "38")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// success
{
authInstance := auth.HMACAuth{SecretKey: []byte("")}
np.Add(&model.Node{Model: gorm.Model{
ID: 38,
}})
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("POST", "/", nil)
c.Request.Header.Set("X-Node-Id", "38")
c.Request = auth.SignRequest(authInstance, c.Request, 0)
slaveRPCSignRequiredFunc(c)
a.False(c.IsAborted())
}
}
func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t)
}