Feat: add local policy

This commit is contained in:
HFO4
2020-02-26 15:11:06 +08:00
parent c1d2b933aa
commit f1ef21e195
13 changed files with 270 additions and 35 deletions

View File

@@ -93,8 +93,10 @@ func Init(isReload bool) {
// 关闭上个初始连接
if previousClient, ok := Instance.(*RPCService); ok {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.caller.Close()
if previousClient.Caller != nil {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.Caller.Close()
}
}
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")

View File

@@ -14,7 +14,7 @@ import (
// RPCService 通过RPC服务的Aria2任务管理器
type RPCService struct {
options *clientOptions
caller rpc.Client
Caller rpc.Client
}
type clientOptions struct {
@@ -24,8 +24,8 @@ type clientOptions struct {
// Init 初始化
func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
// 客户端已存在,则关闭先前连接
if client.caller != nil {
client.caller.Close()
if client.Caller != nil {
client.Caller.Close()
}
client.options = &clientOptions{
@@ -33,18 +33,18 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s
}
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
EventNotifier)
client.caller = caller
client.Caller = caller
return err
}
// Status 查询下载状态
func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := client.caller.TellStatus(task.GID)
res, err := client.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("无法获取离线下载状态,%s10秒钟后重试", err)
time.Sleep(time.Duration(10) * time.Second)
res, err = client.caller.TellStatus(task.GID)
res, err = client.Caller.TellStatus(task.GID)
}
return res, err
@@ -53,7 +53,7 @@ func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
// Cancel 取消下载
func (client *RPCService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := client.caller.Remove(task.GID)
_, err := client.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
}
@@ -79,7 +79,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error {
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := client.caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
@@ -103,7 +103,7 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri
options[k] = v
}
gid, err := client.caller.AddURI(task.Source, options)
gid, err := client.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return err
}

View File

@@ -20,7 +20,7 @@ type InstanceMock struct {
testMock.Mock
}
func (m InstanceMock) CreateTask(task *model.Download, options []interface{}) error {
func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error {
args := m.Called(task, options)
return args.Error(0)
}
@@ -307,13 +307,16 @@ func TestMonitor_Complete(t *testing.T) {
}
cache.Set("setting_max_worker_num", "1", 0)
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
task.Init()
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.True(monitor.Complete(rpc.StatusInfo{}))
asserts.NoError(mock.ExpectationsWereMet())