package service

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"memora-api/internal/model"
	"memora-api/internal/request"

	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// AuthService implements account registration and login on top of a gorm
// database handle.
type AuthService struct {
	db *gorm.DB
}

// NewAuthService returns an AuthService that uses db for all user lookups
// and inserts.
func NewAuthService(db *gorm.DB) *AuthService {
	return &AuthService{db: db}
}

// jwtSecret returns the HMAC key used to sign and verify tokens. It is read
// from the MEMORA_JWT_SECRET environment variable on every call.
//
// NOTE(review): when the variable is unset or empty, a fixed, publicly known
// development secret is used instead. Anyone who knows that string can forge
// tokens, so deployments must set MEMORA_JWT_SECRET. The fallback is kept
// here only to preserve existing behaviour.
func jwtSecret() []byte {
	secret := os.Getenv("MEMORA_JWT_SECRET")
	if secret == "" {
		secret = "memora-dev-secret-change-me"
	}
	return []byte(secret)
}

// Register creates a new user from req, storing a bcrypt hash of the
// password rather than the password itself.
//
// It returns an error if a user with the same email already exists, if the
// existence check fails for any reason other than "record not found", if
// hashing fails, or if the insert fails.
//
// NOTE(review): the lookup and the insert are separate statements, so two
// concurrent registrations for the same email can both pass the check;
// presumably a unique index on the email column backs this up — confirm.
func (s *AuthService) Register(req request.RegisterRequest) (*model.User, error) {
	var exists model.User
	err := s.db.Where("email = ?", req.Email).First(&exists).Error
	switch {
	case err == nil:
		return nil, errors.New("邮箱已注册")
	case !errors.Is(err, gorm.ErrRecordNotFound):
		// A real database failure must not be mistaken for "email is free".
		return nil, err
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}

	user := model.User{
		Email:        req.Email,
		Name:         req.Name,
		PasswordHash: string(hash),
	}
	if err := s.db.Create(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// sign returns the unpadded base64url encoding of the HMAC-SHA256 of data,
// keyed with jwtSecret.
func sign(data string) string {
	mac := hmac.New(sha256.New, jwtSecret())
	// hash.Hash documents that Write never returns an error.
	_, _ = mac.Write([]byte(data))
	return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

// makeToken issues a token for uid that expires seven days from now.
//
// The token has the form
//
//	base64url("<uid>:<unix-expiry>") + "." + base64url(HMAC-SHA256(first part))
//
// Despite the name of the secret, this is a two-part custom format and not a
// standard three-part JWT.
func makeToken(uid string) string {
	const ttl = 7 * 24 * time.Hour

	exp := time.Now().Add(ttl).Unix()
	payload := fmt.Sprintf("%s:%d", uid, exp)
	encPayload := base64.RawURLEncoding.EncodeToString([]byte(payload))
	sig := sign(encPayload)
	return encPayload + "." + sig
}

// Login verifies the email/password pair in req and, on success, returns a
// freshly issued token together with the matching user.
//
// Any lookup error (including an unknown email) and a wrong password both
// yield the same error message, so the response does not reveal which of the
// two was incorrect.
func (s *AuthService) Login(req request.LoginRequest) (string, *model.User, error) {
	var user model.User
	if err := s.db.Where("email = ?", req.Email).First(&user).Error; err != nil {
		return "", nil, errors.New("账号或密码错误")
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
		return "", nil, errors.New("账号或密码错误")
	}
	return makeToken(user.ID), &user, nil
}

// ParseToken validates a token produced by makeToken and returns the user ID
// it carries.
//
// It returns a "token无效" error when the token is malformed, its signature
// does not match, or its payload cannot be decoded, and a "token已过期" error
// when the signature is valid but the expiry time has passed.
func ParseToken(tokenString string) (string, error) {
	errInvalid := errors.New("token无效")

	parts := strings.Split(tokenString, ".")
	if len(parts) != 2 {
		return "", errInvalid
	}
	encPayload, gotSig := parts[0], parts[1]

	// Compare in constant time so that response timing does not reveal how
	// many leading characters of a forged signature were correct.
	if !hmac.Equal([]byte(sign(encPayload)), []byte(gotSig)) {
		return "", errInvalid
	}

	payloadBytes, err := base64.RawURLEncoding.DecodeString(encPayload)
	if err != nil {
		return "", errInvalid
	}
	payload := string(payloadBytes)

	// Split at the LAST colon: the expiry is a plain decimal integer and can
	// never contain one, whereas the user ID might.
	sep := strings.LastIndexByte(payload, ':')
	if sep <= 0 { // -1: no separator; 0: empty user ID
		return "", errInvalid
	}
	uid := payload[:sep]

	exp, err := strconv.ParseInt(payload[sep+1:], 10, 64)
	if err != nil {
		return "", errInvalid
	}
	if time.Now().Unix() > exp {
		return "", errors.New("token已过期")
	}
	return uid, nil
}