Feat: validate / cancel task while downloading file in aria2

This commit is contained in:
HFO4
2020-02-05 12:58:26 +08:00
parent 8c7e3883ee
commit 3ed84ad5ec
7 changed files with 119 additions and 4 deletions

View File

@@ -1,9 +1,13 @@
package aria2
import (
"context"
"encoding/json"
"errors"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/filesystem"
"github.com/HFO4/cloudreve/pkg/filesystem/driver/local"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/task"
"github.com/HFO4/cloudreve/pkg/util"
"github.com/zyxar/argo/rpc"
@@ -71,9 +75,18 @@ func (monitor *Monitor) Update() bool {
return true
}
// 磁力链下载需要跟随
if len(status.FollowedBy) > 0 {
util.Log().Debug("离线下载[%s]重定向至[%s]", monitor.Task.GID, status.FollowedBy[0])
monitor.Task.GID = status.FollowedBy[0]
monitor.Task.Save()
return false
}
// 更新任务信息
if err := monitor.UpdateTaskInfo(status); err != nil {
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s]", monitor.Task.GID, err)
monitor.setErrorStatus(err)
return true
}
@@ -96,6 +109,9 @@ func (monitor *Monitor) Update() bool {
// UpdateTaskInfo 更新数据库中的任务信息
func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize
originPath := monitor.Task.Path
monitor.Task.GID = status.Gid
monitor.Task.Status = getStatus(status.Status)
@@ -126,7 +142,68 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
attrs, _ := json.Marshal(status)
monitor.Task.Attrs = string(attrs)
return monitor.Task.Save()
if err := monitor.Task.Save(); err != nil {
return nil
}
if originSize != monitor.Task.TotalSize || originPath != monitor.Task.Path {
// 大小、文件名更新后,对文件限制等进行校验
if err := monitor.ValidateFile(); err != nil {
// 验证失败时取消任务
monitor.Cancel()
return err
}
}
return nil
}
// Cancel 取消上传并尝试删除临时文件
func (monitor *Monitor) Cancel() {
if err := Instance.Cancel(monitor.Task); err != nil {
util.Log().Warning("无法取消离线下载任务[%s], %s", monitor.Task.GID, err)
}
util.Log().Debug("离线下载任务[%s]已取消1 分钟后删除临时文件", monitor.Task.GID)
go func(monitor *Monitor) {
select {
case <-time.After(time.Duration(60) * time.Second):
monitor.RemoveTempFolder()
}
}(monitor)
}
// ValidateFile 上传过程中校验文件大小、文件名
func (monitor *Monitor) ValidateFile() error {
// 找到任务创建者
user := monitor.Task.GetOwner()
if user == nil {
return ErrUserNotFound
}
// 创建文件系统
fs, err := filesystem.NewFileSystem(user)
if err != nil {
return err
}
defer fs.Recycle()
// 创建上下文环境
ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{
Size: monitor.Task.TotalSize,
Name: filepath.Base(monitor.Task.Path),
})
// 验证文件
if err := filesystem.HookValidateFile(ctx, fs); err != nil {
return err
}
// 验证用户容量
if err := filesystem.HookValidateCapacityWithoutIncrease(ctx, fs); err != nil {
return err
}
return nil
}
// Error 任务下载出错处理,返回是否中断监控

View File

@@ -20,6 +20,8 @@ type Aria2 interface {
CreateTask(task *model.Download) error
// 返回状态信息
Status(task *model.Download) (rpc.StatusInfo, error)
// 取消任务
Cancel(task *model.Download) error
}
const (
@@ -48,7 +50,8 @@ const (
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
@@ -65,6 +68,11 @@ func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error)
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel 返回未开启错误
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Init 初始化
func Init() {
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")

View File

@@ -40,6 +40,12 @@ func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
return client.caller.TellStatus(task.GID)
}
// Cancel 取消下载
func (client *RPCService) Cancel(task *model.Download) error {
_, err := client.caller.Remove(task.GID)
return err
}
// CreateTask 创建新任务
func (client *RPCService) CreateTask(task *model.Download) error {
// 生成存储路径