Files
memora/memora-api/internal/service/auth.go

120 lines
2.7 KiB
Go

package service
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"os"
"strconv"
"strings"
"time"
"memora-api/internal/model"
"memora-api/internal/request"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// AuthService implements account registration and login on top of a GORM
// database handle.
type AuthService struct {
	db *gorm.DB // handle used for all user lookups and inserts
}

// NewAuthService returns an AuthService backed by db. The handle is stored
// as-is; no connection check is performed here.
func NewAuthService(db *gorm.DB) *AuthService {
	svc := new(AuthService)
	svc.db = db
	return svc
}
// jwtSecret returns the HMAC key used to sign and verify auth tokens.
//
// The key is read from the MEMORA_JWT_SECRET environment variable on every
// call. When that variable is unset or empty, a fixed development key is
// returned instead.
//
// NOTE(review): the fallback key is public (it is in the source), so any
// deployment that forgets to set MEMORA_JWT_SECRET lets anyone mint valid
// tokens. Consider failing at startup instead of falling back.
func jwtSecret() []byte {
	if fromEnv := os.Getenv("MEMORA_JWT_SECRET"); fromEnv != "" {
		return []byte(fromEnv)
	}
	const devFallback = "memora-dev-secret-change-me"
	return []byte(devFallback)
}
// Register creates a new user account from req.
//
// The password is stored only as a bcrypt hash (bcrypt.DefaultCost). It
// returns the created user on success. It returns the error "邮箱已注册"
// when a user with the same email already exists, and any database or
// hashing error otherwise.
//
// NOTE(review): the existence check and the insert are not atomic, so two
// concurrent registrations for the same email can both pass the check. This
// presumably relies on a unique index on the email column; confirm in the model.
func (s *AuthService) Register(req request.RegisterRequest) (*model.User, error) {
	var exists model.User
	err := s.db.Where("email = ?", req.Email).First(&exists).Error
	switch {
	case err == nil:
		return nil, errors.New("邮箱已注册")
	case !errors.Is(err, gorm.ErrRecordNotFound):
		// A real lookup failure (connection loss, bad query, ...) must not be
		// mistaken for "email is free" and fall through to the insert.
		return nil, fmt.Errorf("checking existing email: %w", err)
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}
	user := model.User{
		Email:        req.Email,
		Name:         req.Name,
		PasswordHash: string(hash),
	}
	if err := s.db.Create(&user).Error; err != nil {
		return nil, err
	}
	return &user, nil
}
// sign returns the unpadded base64url encoding of HMAC-SHA256(data), keyed
// with jwtSecret().
func sign(data string) string {
	h := hmac.New(sha256.New, jwtSecret())
	_, _ = h.Write([]byte(data)) // hash.Hash.Write is documented never to return an error
	sum := h.Sum(nil)
	return base64.RawURLEncoding.EncodeToString(sum)
}
// makeToken issues a signed token for uid that expires seven days from now.
//
// Layout: "<payload>.<signature>", where payload is the unpadded base64url
// encoding of "<uid>:<unix-expiry-seconds>" and signature is sign(payload).
func makeToken(uid int64) string {
	const ttl = 7 * 24 * time.Hour
	expiresAt := time.Now().Add(ttl)
	raw := fmt.Sprintf("%d:%d", uid, expiresAt.Unix())
	body := base64.RawURLEncoding.EncodeToString([]byte(raw))
	return body + "." + sign(body)
}
// Login checks req.Email / req.Password and, on success, returns a freshly
// issued token (from makeToken) together with the matching user.
//
// An unknown email and a wrong password both yield the same "账号或密码错误"
// error message. Response timing still differs, because bcrypt runs only for
// existing users. Any other database error is returned wrapped rather than
// being disguised as bad credentials.
func (s *AuthService) Login(req request.LoginRequest) (string, *model.User, error) {
	var user model.User
	if err := s.db.Where("email = ?", req.Email).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return "", nil, errors.New("账号或密码错误")
		}
		return "", nil, fmt.Errorf("looking up user: %w", err)
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
		return "", nil, errors.New("账号或密码错误")
	}
	return makeToken(user.ID), &user, nil
}
// ParseToken verifies a token produced by makeToken and returns the user ID
// it carries.
//
// The token must have the form "<payload>.<signature>". Payload is the
// unpadded base64url encoding of "<uid>:<unix-expiry-seconds>", and
// signature must equal sign(payload). It returns "token无效" for any
// malformed or forged token, and "token已过期" once the expiry time has passed.
func ParseToken(tokenString string) (int64, error) {
	errInvalid := errors.New("token无效")

	parts := strings.Split(tokenString, ".")
	if len(parts) != 2 {
		return 0, errInvalid
	}
	encPayload, gotSig := parts[0], parts[1]
	// Compare in constant time. A plain != returns as soon as the first byte
	// differs, which lets an attacker recover a valid signature byte by byte
	// from response timings.
	if !hmac.Equal([]byte(sign(encPayload)), []byte(gotSig)) {
		return 0, errInvalid
	}
	payloadBytes, err := base64.RawURLEncoding.DecodeString(encPayload)
	if err != nil {
		return 0, errInvalid
	}
	payloadParts := strings.Split(string(payloadBytes), ":")
	if len(payloadParts) != 2 {
		return 0, errInvalid
	}
	uid, err := strconv.ParseInt(payloadParts[0], 10, 64)
	if err != nil {
		return 0, errInvalid
	}
	exp, err := strconv.ParseInt(payloadParts[1], 10, 64)
	if err != nil {
		return 0, errInvalid
	}
	if time.Now().Unix() > exp {
		return 0, errors.New("token已过期")
	}
	return uid, nil
}