mirror of
https://github.com/halejohn/Cloudreve.git
synced 2026-01-26 09:34:57 +08:00
Add: upload controller in slave mode
This commit is contained in:
@@ -3,6 +3,7 @@ package filesystem
|
||||
import (
|
||||
"context"
|
||||
"github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/local"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -90,12 +91,15 @@ func NewAnonymousFileSystem() (*FileSystem, error) {
|
||||
User: &model.User{},
|
||||
}
|
||||
|
||||
anonymousGroup, err := model.GetGroupByID(3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 如果是主机模式下,则为匿名文件系统分配游客用户组
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
anonymousGroup, err := model.GetGroupByID(3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fs.User.Group = anonymousGroup
|
||||
}
|
||||
|
||||
fs.User.Group = anonymousGroup
|
||||
return fs, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,4 +15,6 @@ const (
|
||||
FileModelCtx
|
||||
// HTTPCtx HTTP请求的上下文
|
||||
HTTPCtx
|
||||
// UploadPolicyCtx 上传策略,一般为slave模式下使用
|
||||
UploadPolicyCtx
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
@@ -52,6 +53,30 @@ func HookIsFileExist(ctx context.Context, fs *FileSystem) error {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
|
||||
// HookSlaveUploadValidate Slave模式下对文件上传的一系列验证
|
||||
// TODO 测试
|
||||
func HookSlaveUploadValidate(ctx context.Context, fs *FileSystem) error {
|
||||
file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)
|
||||
policy := ctx.Value(fsctx.UploadPolicyCtx).(serializer.UploadPolicy)
|
||||
|
||||
// 验证单文件尺寸
|
||||
if file.GetSize() > policy.MaxSize {
|
||||
return ErrFileSizeTooBig
|
||||
}
|
||||
|
||||
// 验证文件名
|
||||
if !fs.ValidateLegalName(ctx, file.GetFileName()) {
|
||||
return ErrIllegalObjectName
|
||||
}
|
||||
|
||||
// 验证扩展名
|
||||
if len(policy.AllowedExtension) > 0 && !IsInExtensionList(policy.AllowedExtension, file.GetFileName()) {
|
||||
return ErrFileExtensionNotAllowed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookValidateFile 一系列对文件检验的集合
|
||||
func HookValidateFile(ctx context.Context, fs *FileSystem) error {
|
||||
file := ctx.Value(fsctx.FileHeaderCtx).(FileHeader)
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/local"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
@@ -477,3 +479,64 @@ func TestGenericAfterUpdate(t *testing.T) {
|
||||
asserts.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookSlaveUploadValidate(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
conf.SystemConfig.Mode = "slave"
|
||||
fs, err := NewAnonymousFileSystem()
|
||||
conf.SystemConfig.Mode = "master"
|
||||
asserts.NoError(err)
|
||||
|
||||
// 正常
|
||||
{
|
||||
policy := serializer.UploadPolicy{
|
||||
SavePath: "",
|
||||
MaxSize: 10,
|
||||
AllowedExtension: nil,
|
||||
}
|
||||
file := local.FileStream{Name: "1.txt", Size: 10}
|
||||
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
|
||||
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
|
||||
asserts.NoError(HookSlaveUploadValidate(ctx, fs))
|
||||
}
|
||||
|
||||
// 尺寸太大
|
||||
{
|
||||
policy := serializer.UploadPolicy{
|
||||
SavePath: "",
|
||||
MaxSize: 10,
|
||||
AllowedExtension: nil,
|
||||
}
|
||||
file := local.FileStream{Name: "1.txt", Size: 11}
|
||||
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
|
||||
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
|
||||
asserts.Equal(ErrFileSizeTooBig, HookSlaveUploadValidate(ctx, fs))
|
||||
}
|
||||
|
||||
// 文件名非法
|
||||
{
|
||||
policy := serializer.UploadPolicy{
|
||||
SavePath: "",
|
||||
MaxSize: 10,
|
||||
AllowedExtension: nil,
|
||||
}
|
||||
file := local.FileStream{Name: "/1.txt", Size: 10}
|
||||
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
|
||||
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
|
||||
asserts.Equal(ErrIllegalObjectName, HookSlaveUploadValidate(ctx, fs))
|
||||
}
|
||||
|
||||
// 扩展名非法
|
||||
{
|
||||
policy := serializer.UploadPolicy{
|
||||
SavePath: "",
|
||||
MaxSize: 10,
|
||||
AllowedExtension: []string{"jpg"},
|
||||
}
|
||||
file := local.FileStream{Name: "1.txt", Size: 10}
|
||||
ctx := context.WithValue(context.Background(), fsctx.UploadPolicyCtx, policy)
|
||||
ctx = context.WithValue(ctx, fsctx.FileHeaderCtx, file)
|
||||
asserts.Equal(ErrFileExtensionNotAllowed, HookSlaveUploadValidate(ctx, fs))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -24,14 +24,15 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成文件名和路径, 如果是更新操作就从原始文件获取
|
||||
// 生成文件名和路径,
|
||||
var savePath string
|
||||
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
|
||||
if ok {
|
||||
// 如果是更新操作就从上下文中获取
|
||||
if originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
savePath = originFile.SourceName
|
||||
} else {
|
||||
savePath = fs.GenerateSavePath(ctx, file)
|
||||
}
|
||||
ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath)
|
||||
|
||||
// 处理客户端未完成上传时,关闭连接
|
||||
go fs.CancelUpload(ctx, savePath, file)
|
||||
@@ -43,7 +44,6 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) {
|
||||
}
|
||||
|
||||
// 上传完成后的钩子
|
||||
ctx = context.WithValue(ctx, fsctx.SavePathCtx, savePath)
|
||||
err = fs.Trigger(ctx, fs.AfterUpload)
|
||||
|
||||
if err != nil {
|
||||
@@ -57,21 +57,42 @@ func (fs *FileSystem) Upload(ctx context.Context, file FileHeader) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
util.Log().Info("新文件PUT:%s , 大小:%d, 上传者:%s", file.GetFileName(), file.GetSize(), fs.User.Nick)
|
||||
util.Log().Info(
|
||||
"新文件PUT:%s , 大小:%d, 上传者:%s",
|
||||
file.GetFileName(),
|
||||
file.GetSize(),
|
||||
fs.User.Nick,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSavePath 生成要存放文件的路径
|
||||
// TODO 完善测试
|
||||
func (fs *FileSystem) GenerateSavePath(ctx context.Context, file FileHeader) string {
|
||||
if fs.User.Model.ID != 0 {
|
||||
return filepath.Join(
|
||||
fs.User.Policy.GeneratePath(
|
||||
fs.User.Model.ID,
|
||||
file.GetVirtualPath(),
|
||||
),
|
||||
fs.User.Policy.GenerateFileName(
|
||||
fs.User.Model.ID,
|
||||
file.GetFileName(),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// 匿名文件系统使用空上传策略生成路径
|
||||
nilPolicy := model.Policy{}
|
||||
return filepath.Join(
|
||||
fs.User.Policy.GeneratePath(
|
||||
fs.User.Model.ID,
|
||||
file.GetVirtualPath(),
|
||||
nilPolicy.GeneratePath(
|
||||
0,
|
||||
"",
|
||||
),
|
||||
fs.User.Policy.GenerateFileName(
|
||||
fs.User.Model.ID,
|
||||
file.GetFileName(),
|
||||
nilPolicy.GenerateFileName(
|
||||
0,
|
||||
"",
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,33 @@
|
||||
package serializer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// UploadPolicy slave模式下传递的上传策略
|
||||
type UploadPolicy struct {
|
||||
SavePath string `json:"save_path"`
|
||||
MaxSize uint64 `json:"save_path"`
|
||||
MaxSize uint64 `json:"max_size"`
|
||||
AllowedExtension []string `json:"allowed_extension"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
CallbackKey string `json:"callback_key"`
|
||||
}
|
||||
|
||||
// DecodeUploadPolicy 反序列化Header中携带的上传策略
|
||||
// TODO 测试
|
||||
func DecodeUploadPolicy(raw string) (*UploadPolicy, error) {
|
||||
var res UploadPolicy
|
||||
|
||||
rawJSON, err := base64.StdEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(rawJSON, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &res, err
|
||||
}
|
||||
|
||||
55
pkg/serializer/file_test.go
Normal file
55
pkg/serializer/file_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package serializer
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecodeUploadPolicy(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectError bool
|
||||
expectNil bool
|
||||
expectRes *UploadPolicy
|
||||
}{
|
||||
{
|
||||
"错误的base64字符",
|
||||
true,
|
||||
true,
|
||||
&UploadPolicy{},
|
||||
},
|
||||
{
|
||||
"6ZSZ6K+v55qESlNPTuWtl+espg==",
|
||||
true,
|
||||
true,
|
||||
&UploadPolicy{},
|
||||
},
|
||||
{
|
||||
"e30=",
|
||||
false,
|
||||
false,
|
||||
&UploadPolicy{},
|
||||
},
|
||||
{
|
||||
"eyJjYWxsYmFja19rZXkiOiJ0ZXN0In0=",
|
||||
false,
|
||||
false,
|
||||
&UploadPolicy{CallbackKey: "test"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
res, err := DecodeUploadPolicy(testCase.input)
|
||||
if testCase.expectError {
|
||||
asserts.Error(err)
|
||||
}
|
||||
if testCase.expectNil {
|
||||
asserts.Nil(res)
|
||||
}
|
||||
if !testCase.expectNil {
|
||||
asserts.Equal(testCase.expectRes, res)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user