mirror of
https://github.com/halejohn/Cloudreve.git
synced 2026-01-26 09:34:57 +08:00
Feat: RWMutex / reload for aira2
This commit is contained in:
@@ -7,18 +7,22 @@ import (
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Instance 默认使用的Aria2处理实例
|
||||
var Instance Aria2 = &DummyAria2{}
|
||||
|
||||
// Lock Instance的读写锁
|
||||
var Lock sync.RWMutex
|
||||
|
||||
// EventNotifier 任务状态更新通知处理器
|
||||
var EventNotifier = &Notifier{}
|
||||
|
||||
// Aria2 离线下载处理接口
|
||||
type Aria2 interface {
|
||||
// CreateTask 创建新的任务
|
||||
CreateTask(task *model.Download, options []interface{}) error
|
||||
CreateTask(task *model.Download, options map[string]interface{}) error
|
||||
// 返回状态信息
|
||||
Status(task *model.Download) (rpc.StatusInfo, error)
|
||||
// 取消任务
|
||||
@@ -63,7 +67,7 @@ type DummyAria2 struct {
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务,此处直接返回未开启错误
|
||||
func (instance *DummyAria2) CreateTask(model *model.Download, options []interface{}) error {
|
||||
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
@@ -83,7 +87,16 @@ func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func Init() {
|
||||
func Init(isReload bool) {
|
||||
Lock.Lock()
|
||||
defer Lock.Unlock()
|
||||
|
||||
// 关闭上个初始连接
|
||||
if previousClient, ok := Instance.(*RPCService); ok {
|
||||
util.Log().Debug("关闭上个 aria2 连接")
|
||||
previousClient.caller.Close()
|
||||
}
|
||||
|
||||
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
|
||||
timeout := model.GetIntSetting("aria2_call_timeout", 5)
|
||||
if options["aria2_rpcurl"] == "" {
|
||||
@@ -93,9 +106,6 @@ func Init() {
|
||||
|
||||
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
|
||||
client := &RPCService{}
|
||||
if previousClient, ok := Instance.(*RPCService); ok {
|
||||
client = previousClient
|
||||
}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(options["aria2_rpcurl"])
|
||||
@@ -107,7 +117,7 @@ func Init() {
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions []interface{}
|
||||
var globalOptions map[string]interface{}
|
||||
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
|
||||
@@ -123,13 +133,16 @@ func Init() {
|
||||
|
||||
Instance = client
|
||||
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
|
||||
if !isReload {
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
|
||||
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
NewMonitor(&unfinished[i])
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
NewMonitor(&unfinished[i])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// getStatus 将给定的状态字符串转换为状态标识数字
|
||||
|
||||
Reference in New Issue
Block a user