mirror of
https://github.com/halejohn/Cloudreve.git
synced 2026-01-26 09:34:57 +08:00
feat(kv): persist cache and session into disk before shutdown
This commit is contained in:
18
pkg/cache/driver.go
vendored
18
pkg/cache/driver.go
vendored
@@ -1,11 +1,16 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(map[string]itemWithTTL{})
|
||||
}
|
||||
|
||||
// Store 缓存存储器
|
||||
var Store Driver = NewMemoStore()
|
||||
|
||||
@@ -22,6 +27,13 @@ func Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// Restore restores cache from given disk file
|
||||
func Restore(persistFile string) {
|
||||
if err := Store.Restore(persistFile); err != nil {
|
||||
util.Log().Warning("Failed to restore cache from disk: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func InitSlaveOverwrites() {
|
||||
err := Store.Sets(conf.OptionOverwrite, "setting_")
|
||||
if err != nil {
|
||||
@@ -45,6 +57,12 @@ type Driver interface {
|
||||
|
||||
// 删除值
|
||||
Delete(keys []string, prefix string) error
|
||||
|
||||
// Save in-memory cache to disk
|
||||
Persist(path string) error
|
||||
|
||||
// Restore cache from disk
|
||||
Restore(path string) error
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
|
||||
79
pkg/cache/memo.go
vendored
79
pkg/cache/memo.go
vendored
@@ -1,6 +1,9 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -14,18 +17,20 @@ type MemoStore struct {
|
||||
|
||||
// item 存储的对象
|
||||
type itemWithTTL struct {
|
||||
expires int64
|
||||
value interface{}
|
||||
Expires int64
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
const DefaultCacheFile = "cache_persist.bin"
|
||||
|
||||
func newItem(value interface{}, expires int) itemWithTTL {
|
||||
expires64 := int64(expires)
|
||||
if expires > 0 {
|
||||
expires64 = time.Now().Unix() + expires64
|
||||
}
|
||||
return itemWithTTL{
|
||||
value: value,
|
||||
expires: expires64,
|
||||
Value: value,
|
||||
Expires: expires64,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,11 +45,11 @@ func getValue(item interface{}, ok bool) (interface{}, bool) {
|
||||
return item, true
|
||||
}
|
||||
|
||||
if itemObj.expires > 0 && itemObj.expires < time.Now().Unix() {
|
||||
if itemObj.Expires > 0 && itemObj.Expires < time.Now().Unix() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return itemObj.value, ok
|
||||
return itemObj.Value, ok
|
||||
|
||||
}
|
||||
|
||||
@@ -52,7 +57,7 @@ func getValue(item interface{}, ok bool) (interface{}, bool) {
|
||||
func (store *MemoStore) GarbageCollect() {
|
||||
store.Store.Range(func(key, value interface{}) bool {
|
||||
if item, ok := value.(itemWithTTL); ok {
|
||||
if item.expires > 0 && item.expires < time.Now().Unix() {
|
||||
if item.Expires > 0 && item.Expires < time.Now().Unix() {
|
||||
util.Log().Debug("Cache %q is garbage collected.", key.(string))
|
||||
store.Store.Delete(key)
|
||||
}
|
||||
@@ -98,7 +103,7 @@ func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface
|
||||
// Sets 批量设置值
|
||||
func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error {
|
||||
for key, value := range values {
|
||||
store.Store.Store(prefix+key, value)
|
||||
store.Store.Store(prefix+key, newItem(value, 0))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -110,3 +115,61 @@ func (store *MemoStore) Delete(keys []string, prefix string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist write memory store into cache
|
||||
func (store *MemoStore) Persist(path string) error {
|
||||
persisted := make(map[string]itemWithTTL)
|
||||
store.Store.Range(func(key, value interface{}) bool {
|
||||
v, ok := store.Store.Load(key)
|
||||
if _, ok := getValue(v, ok); ok {
|
||||
persisted[key.(string)] = v.(itemWithTTL)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
res, err := serializer(persisted)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize cache: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(path, res, 0644)
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore memory cache from disk file
|
||||
func (store *MemoStore) Restore(path string) error {
|
||||
if !util.Exists(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read cache file: %s", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
f.Close()
|
||||
os.Remove(path)
|
||||
}()
|
||||
|
||||
persisted := &item{}
|
||||
dec := gob.NewDecoder(f)
|
||||
if err := dec.Decode(&persisted); err != nil {
|
||||
return fmt.Errorf("unknown cache file format: %s", err)
|
||||
}
|
||||
|
||||
items := persisted.Value.(map[string]itemWithTTL)
|
||||
loaded := 0
|
||||
for k, v := range items {
|
||||
if _, ok := getValue(v, true); ok {
|
||||
loaded++
|
||||
store.Store.Store(k, v)
|
||||
} else {
|
||||
util.Log().Debug("Persisted cache %q is expired.", k)
|
||||
}
|
||||
}
|
||||
|
||||
util.Log().Info("Restored %d items from %q into memory cache.", loaded, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
46
pkg/cache/memo_test.go
vendored
46
pkg/cache/memo_test.go
vendored
@@ -2,6 +2,7 @@ package cache
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -23,7 +24,7 @@ func TestMemoStore_Set(t *testing.T) {
|
||||
|
||||
val, ok := store.Store.Load("KEY")
|
||||
asserts.True(ok)
|
||||
asserts.Equal("vAL", val.(itemWithTTL).value)
|
||||
asserts.Equal("vAL", val.(itemWithTTL).Value)
|
||||
}
|
||||
|
||||
func TestMemoStore_Get(t *testing.T) {
|
||||
@@ -145,3 +146,46 @@ func TestMemoStore_GarbageCollect(t *testing.T) {
|
||||
_, ok := store.Get("test")
|
||||
asserts.False(ok)
|
||||
}
|
||||
|
||||
func TestMemoStore_PersistFailed(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
store := NewMemoStore()
|
||||
type testStruct struct{ v string }
|
||||
store.Set("test", 1, 0)
|
||||
store.Set("test2", testStruct{v: "test"}, 0)
|
||||
err := store.Persist(filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed"))
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
func TestMemoStore_PersistAndRestore(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
store := NewMemoStore()
|
||||
store.Set("test", 1, 0)
|
||||
// already expired
|
||||
store.Store.Store("test2", itemWithTTL{Value: "test", Expires: 1})
|
||||
// expired after persist
|
||||
store.Set("test3", 1, 1)
|
||||
temp := filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed")
|
||||
|
||||
// Persist
|
||||
err := store.Persist(temp)
|
||||
a.NoError(err)
|
||||
a.FileExists(temp)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// Restore
|
||||
store2 := NewMemoStore()
|
||||
err = store2.Restore(temp)
|
||||
a.NoError(err)
|
||||
test, testOk := store2.Get("test")
|
||||
a.EqualValues(1, test)
|
||||
a.True(testOk)
|
||||
test2, test2Ok := store2.Get("test2")
|
||||
a.Nil(test2)
|
||||
a.False(test2Ok)
|
||||
test3, test3Ok := store2.Get("test3")
|
||||
a.Nil(test3)
|
||||
a.False(test3Ok)
|
||||
|
||||
a.NoFileExists(temp)
|
||||
}
|
||||
|
||||
10
pkg/cache/redis.go
vendored
10
pkg/cache/redis.go
vendored
@@ -215,3 +215,13 @@ func (store *RedisStore) DeleteAll() error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist Dummy implementation
|
||||
func (store *RedisStore) Persist(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore dummy implementation
|
||||
func (store *RedisStore) Restore(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
136
pkg/sessionstore/kv.go
Normal file
136
pkg/sessionstore/kv.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package sessionstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base32"
|
||||
"encoding/gob"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type kvStore struct {
|
||||
Codecs []securecookie.Codec
|
||||
Options *sessions.Options
|
||||
DefaultMaxAge int
|
||||
|
||||
prefix string
|
||||
serializer SessionSerializer
|
||||
store cache.Driver
|
||||
}
|
||||
|
||||
func newKvStore(prefix string, store cache.Driver, keyPairs ...[]byte) *kvStore {
|
||||
return &kvStore{
|
||||
prefix: prefix,
|
||||
store: store,
|
||||
DefaultMaxAge: 60 * 20,
|
||||
serializer: GobSerializer{},
|
||||
Codecs: securecookie.CodecsFromPairs(keyPairs...),
|
||||
Options: &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: 86400 * 30,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a session for the given name after adding it to the registry.
|
||||
//
|
||||
// It returns a new session if the sessions doesn't exist. Access IsNew on
|
||||
// the session to check if it is an existing session or a new one.
|
||||
//
|
||||
// It returns a new session and an error if the session exists but could
|
||||
// not be decoded.
|
||||
func (s *kvStore) Get(r *http.Request, name string) (*sessions.Session, error) {
|
||||
return sessions.GetRegistry(r).Get(s, name)
|
||||
}
|
||||
|
||||
// New returns a session for the given name without adding it to the registry.
|
||||
//
|
||||
// The difference between New() and Get() is that calling New() twice will
|
||||
// decode the session data twice, while Get() registers and reuses the same
|
||||
// decoded session after the first call.
|
||||
func (s *kvStore) New(r *http.Request, name string) (*sessions.Session, error) {
|
||||
var (
|
||||
err error
|
||||
)
|
||||
session := sessions.NewSession(s, name)
|
||||
// make a copy
|
||||
options := *s.Options
|
||||
session.Options = &options
|
||||
session.IsNew = true
|
||||
if c, errCookie := r.Cookie(name); errCookie == nil {
|
||||
err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...)
|
||||
if err == nil {
|
||||
res, ok := s.store.Get(s.prefix + session.ID)
|
||||
if ok {
|
||||
err = s.serializer.Deserialize(res.([]byte), session)
|
||||
}
|
||||
|
||||
session.IsNew = !(err == nil && ok) // not new if no error and data available
|
||||
}
|
||||
}
|
||||
return session, err
|
||||
}
|
||||
func (s *kvStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
||||
// Marked for deletion.
|
||||
if session.Options.MaxAge <= 0 {
|
||||
if err := s.store.Delete([]string{session.ID}, s.prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
|
||||
} else {
|
||||
// Build an alphanumeric key for the redis store.
|
||||
if session.ID == "" {
|
||||
session.ID = strings.TrimRight(base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32)), "=")
|
||||
}
|
||||
|
||||
b, err := s.serializer.Serialize(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
age := session.Options.MaxAge
|
||||
if age == 0 {
|
||||
age = s.DefaultMaxAge
|
||||
}
|
||||
|
||||
if err := s.store.Set(s.prefix+session.ID, b, age); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionSerializer provides an interface hook for alternative serializers
|
||||
type SessionSerializer interface {
|
||||
Deserialize(d []byte, ss *sessions.Session) error
|
||||
Serialize(ss *sessions.Session) ([]byte, error)
|
||||
}
|
||||
|
||||
// GobSerializer uses gob package to encode the session map
|
||||
type GobSerializer struct{}
|
||||
|
||||
// Serialize using gob
|
||||
func (s GobSerializer) Serialize(ss *sessions.Session) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(ss.Values)
|
||||
if err == nil {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Deserialize back to map[interface{}]interface{}
|
||||
func (s GobSerializer) Deserialize(d []byte, ss *sessions.Session) error {
|
||||
dec := gob.NewDecoder(bytes.NewBuffer(d))
|
||||
return dec.Decode(&ss.Values)
|
||||
}
|
||||
22
pkg/sessionstore/sessionstore.go
Normal file
22
pkg/sessionstore/sessionstore.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package sessionstore
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/gin-contrib/sessions"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
sessions.Store
|
||||
}
|
||||
|
||||
func NewStore(driver cache.Driver, keyPairs ...[]byte) Store {
|
||||
return &store{newKvStore("cd_session_", driver, keyPairs...)}
|
||||
}
|
||||
|
||||
type store struct {
|
||||
*kvStore
|
||||
}
|
||||
|
||||
func (c *store) Options(options sessions.Options) {
|
||||
c.kvStore.Options = options.ToGorillaOptions()
|
||||
}
|
||||
Reference in New Issue
Block a user