feat: add auth flow and login guard for api/web
This commit is contained in:
119
memora-api/internal/service/auth.go
Normal file
119
memora-api/internal/service/auth.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"memora-api/internal/model"
|
||||
"memora-api/internal/request"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AuthService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAuthService(db *gorm.DB) *AuthService {
|
||||
return &AuthService{db: db}
|
||||
}
|
||||
|
||||
func jwtSecret() []byte {
|
||||
secret := os.Getenv("MEMORA_JWT_SECRET")
|
||||
if secret == "" {
|
||||
secret = "memora-dev-secret-change-me"
|
||||
}
|
||||
return []byte(secret)
|
||||
}
|
||||
|
||||
func (s *AuthService) Register(req request.RegisterRequest) (*model.User, error) {
|
||||
var exists model.User
|
||||
if err := s.db.Where("email = ?", req.Email).First(&exists).Error; err == nil {
|
||||
return nil, errors.New("邮箱已注册")
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
PasswordHash: string(hash),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func sign(data string) string {
|
||||
mac := hmac.New(sha256.New, jwtSecret())
|
||||
_, _ = mac.Write([]byte(data))
|
||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
func makeToken(uid int64) string {
|
||||
exp := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||
payload := fmt.Sprintf("%d:%d", uid, exp)
|
||||
encPayload := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
||||
sig := sign(encPayload)
|
||||
return encPayload + "." + sig
|
||||
}
|
||||
|
||||
func (s *AuthService) Login(req request.LoginRequest) (string, *model.User, error) {
|
||||
var user model.User
|
||||
if err := s.db.Where("email = ?", req.Email).First(&user).Error; err != nil {
|
||||
return "", nil, errors.New("账号或密码错误")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
||||
return "", nil, errors.New("账号或密码错误")
|
||||
}
|
||||
|
||||
return makeToken(user.ID), &user, nil
|
||||
}
|
||||
|
||||
func ParseToken(tokenString string) (int64, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 2 {
|
||||
return 0, errors.New("token无效")
|
||||
}
|
||||
encPayload, gotSig := parts[0], parts[1]
|
||||
if sign(encPayload) != gotSig {
|
||||
return 0, errors.New("token无效")
|
||||
}
|
||||
|
||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(encPayload)
|
||||
if err != nil {
|
||||
return 0, errors.New("token无效")
|
||||
}
|
||||
|
||||
payloadParts := strings.Split(string(payloadBytes), ":")
|
||||
if len(payloadParts) != 2 {
|
||||
return 0, errors.New("token无效")
|
||||
}
|
||||
uid, err := strconv.ParseInt(payloadParts[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, errors.New("token无效")
|
||||
}
|
||||
exp, err := strconv.ParseInt(payloadParts[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, errors.New("token无效")
|
||||
}
|
||||
if time.Now().Unix() > exp {
|
||||
return 0, errors.New("token已过期")
|
||||
}
|
||||
|
||||
return uid, nil
|
||||
}
|
||||
Reference in New Issue
Block a user