mirror of
https://github.com/halejohn/Cloudreve.git
synced 2026-01-27 01:51:56 +08:00
Feat: support using SharePoint site to store files
This commit is contained in:
@@ -53,12 +53,23 @@ func (err RespError) Error() string {
|
||||
return err.APIError.Message
|
||||
}
|
||||
|
||||
func (client *Client) getRequestURL(api string) string {
|
||||
func (client *Client) getRequestURL(api string, opts ...Option) string {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
base, _ := url.Parse(client.Endpoints.EndpointURL)
|
||||
if base == nil {
|
||||
return ""
|
||||
}
|
||||
base.Path = path.Join(base.Path, api)
|
||||
|
||||
if options.useDriverResource {
|
||||
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
|
||||
} else {
|
||||
base.Path = path.Join(base.Path, api)
|
||||
}
|
||||
|
||||
return base.String()
|
||||
}
|
||||
|
||||
@@ -67,9 +78,9 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
var requestURL string
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
if dst == "" {
|
||||
requestURL = client.getRequestURL("me/drive/root/children")
|
||||
requestURL = client.getRequestURL("root/children")
|
||||
} else {
|
||||
requestURL = client.getRequestURL("me/drive/root:/" + dst + ":/children")
|
||||
requestURL = client.getRequestURL("root:/" + dst + ":/children")
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
|
||||
@@ -103,10 +114,10 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
|
||||
var requestURL string
|
||||
if id != "" {
|
||||
requestURL = client.getRequestURL("/me/drive/items/" + id)
|
||||
requestURL = client.getRequestURL("items/" + id)
|
||||
} else {
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
requestURL = client.getRequestURL("me/drive/root:/" + dst)
|
||||
requestURL = client.getRequestURL("root:/" + dst)
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
|
||||
@@ -129,14 +140,13 @@ func (client *Client) Meta(ctx context.Context, id string, path string) (*FileIn
|
||||
|
||||
// CreateUploadSession 创建分片上传会话
|
||||
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
|
||||
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/createUploadSession")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
|
||||
body := map[string]map[string]interface{}{
|
||||
"item": {
|
||||
"@microsoft.graph.conflictBehavior": options.conflictBehavior,
|
||||
@@ -161,6 +171,33 @@ func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts
|
||||
return uploadSession.UploadURL, nil
|
||||
}
|
||||
|
||||
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
|
||||
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
|
||||
siteUrlParsed, err := url.Parse(siteUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hostName := siteUrlParsed.Hostname()
|
||||
relativePath := strings.Trim(siteUrlParsed.Path, "/")
|
||||
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
|
||||
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
if reqErr != nil {
|
||||
return "", reqErr
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
siteInfo Site
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
|
||||
if decodeErr != nil {
|
||||
return "", decodeErr
|
||||
}
|
||||
|
||||
return siteInfo.ID, nil
|
||||
}
|
||||
|
||||
// GetUploadSessionStatus 查询上传会话状态
|
||||
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
|
||||
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
|
||||
@@ -300,7 +337,7 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/content")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/content")
|
||||
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
|
||||
|
||||
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
|
||||
@@ -357,7 +394,8 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string,
|
||||
// 由于API限制,最多删除20个
|
||||
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
|
||||
body := client.makeBatchDeleteRequestsBody(dst)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch"), body, 200)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
|
||||
WithDriverResource(false)), body, 200)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
@@ -396,7 +434,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
}
|
||||
for i, v := range files {
|
||||
v = strings.TrimPrefix(v, "/")
|
||||
filePath, _ := url.Parse("/me/drive/root:/")
|
||||
filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
|
||||
filePath.Path = path.Join(filePath.Path, v)
|
||||
req.Requests[i] = BatchRequest{
|
||||
ID: v,
|
||||
@@ -418,10 +456,10 @@ func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (s
|
||||
)
|
||||
if client.Endpoints.isInChina {
|
||||
cropOption = "large"
|
||||
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails/0") + "/" + cropOption
|
||||
requestURL = client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/" + cropOption
|
||||
} else {
|
||||
cropOption = fmt.Sprintf("c%dx%d_Crop", w, h)
|
||||
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails") + "?select=" + cropOption
|
||||
requestURL = client.getRequestURL("root:/"+dst+":/thumbnails") + "?select=" + cropOption
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
|
||||
@@ -167,6 +167,82 @@ func TestClient_GetRequestURL(t *testing.T) {
|
||||
client.Endpoints.EndpointURL = string([]byte{0x7f})
|
||||
asserts.Equal("", client.getRequestURL("123"))
|
||||
}
|
||||
|
||||
// 使用DriverResource
|
||||
{
|
||||
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
|
||||
asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123"))
|
||||
}
|
||||
|
||||
// 不使用DriverResource
|
||||
{
|
||||
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
|
||||
asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetSiteIDByURL(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
client, _ := NewClient(&model.Policy{})
|
||||
client.Credential.AccessToken = "AccessToken"
|
||||
|
||||
// 请求失败
|
||||
{
|
||||
client.Credential.ExpiresIn = 0
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
asserts.Error(err)
|
||||
asserts.Empty(res)
|
||||
|
||||
}
|
||||
|
||||
// 返回未知响应
|
||||
{
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"GET",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`???`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
clientMock.AssertExpectations(t)
|
||||
asserts.Error(err)
|
||||
asserts.Empty(res)
|
||||
}
|
||||
|
||||
// 返回正常
|
||||
{
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"GET",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
clientMock.AssertExpectations(t)
|
||||
asserts.NoError(err)
|
||||
asserts.NotEmpty(res)
|
||||
asserts.Equal("123321", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Meta(t *testing.T) {
|
||||
|
||||
@@ -37,14 +37,16 @@ type Endpoints struct {
|
||||
OAuthEndpoints *oauthEndpoint
|
||||
EndpointURL string // 接口请求的基URL
|
||||
isInChina bool // 是否为世纪互联
|
||||
DriverResource string // 要使用的驱动器
|
||||
}
|
||||
|
||||
// NewClient 根据存储策略获取新的client
|
||||
func NewClient(policy *model.Policy) (*Client, error) {
|
||||
client := &Client{
|
||||
Endpoints: &Endpoints{
|
||||
OAuthURL: policy.BaseURL,
|
||||
EndpointURL: policy.Server,
|
||||
OAuthURL: policy.BaseURL,
|
||||
EndpointURL: policy.Server,
|
||||
DriverResource: policy.OptionsSerialized.OdDriver,
|
||||
},
|
||||
Credential: &Credential{
|
||||
RefreshToken: policy.AccessKey,
|
||||
@@ -56,6 +58,10 @@ func NewClient(policy *model.Policy) (*Client, error) {
|
||||
Request: request.HTTPClient{},
|
||||
}
|
||||
|
||||
if client.Endpoints.DriverResource == "" {
|
||||
client.Endpoints.DriverResource = "me/drive"
|
||||
}
|
||||
|
||||
oauthBase := client.getOAuthEndpoint()
|
||||
if oauthBase == nil {
|
||||
return nil, ErrAuthEndpoint
|
||||
|
||||
@@ -160,7 +160,8 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
|
||||
client.Credential = credential
|
||||
|
||||
// 更新存储策略的 RefreshToken
|
||||
client.Policy.UpdateAccessKey(credential.RefreshToken)
|
||||
client.Policy.AccessKey = credential.RefreshToken
|
||||
client.Policy.SaveAndClearCache()
|
||||
|
||||
// 更新缓存
|
||||
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
|
||||
|
||||
@@ -8,11 +8,12 @@ type Option interface {
|
||||
}
|
||||
|
||||
type options struct {
|
||||
redirect string
|
||||
code string
|
||||
refreshToken string
|
||||
conflictBehavior string
|
||||
expires time.Time
|
||||
redirect string
|
||||
code string
|
||||
refreshToken string
|
||||
conflictBehavior string
|
||||
expires time.Time
|
||||
useDriverResource bool
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
@@ -38,13 +39,21 @@ func WithConflictBehavior(t string) Option {
|
||||
})
|
||||
}
|
||||
|
||||
// WithConflictBehavior 设置文件重名后的处理方式
|
||||
func WithDriverResource(t bool) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.useDriverResource = t
|
||||
})
|
||||
}
|
||||
|
||||
func (f optionFunc) apply(o *options) {
|
||||
f(o)
|
||||
}
|
||||
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
conflictBehavior: "fail",
|
||||
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
|
||||
conflictBehavior: "fail",
|
||||
useDriverResource: true,
|
||||
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,6 +131,15 @@ type OAuthError struct {
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
}
|
||||
|
||||
// Site SharePoint 站点信息
|
||||
type Site struct {
|
||||
Description string `json:"description"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
WebUrl string `json:"webUrl"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Credential{})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user