// Source listing metadata: 361 lines, 8.6 KiB, Go.

package services
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"gorm.io/gorm"
|
|
"goyco/internal/config"
|
|
"goyco/internal/database"
|
|
"goyco/internal/repositories"
|
|
)
|
|
|
|
// Token type identifiers stored in the "type" claim of TokenClaims.
const (
	// TokenTypeAccess is the type stamped on tokens issued by
	// GenerateAccessToken and required by VerifyAccessToken.
	TokenTypeAccess = "access"

	// TokenTypeRefresh is the identifier reserved for refresh tokens.
	TokenTypeRefresh = "refresh"
)
|
|
|
|
// Sentinel errors returned by JWTService. Callers can match them with
// errors.Is.
var (
	// ErrInvalidTokenType is returned by VerifyAccessToken when the "type"
	// claim is not TokenTypeAccess.
	ErrInvalidTokenType = errors.New("invalid token type")

	// ErrTokenExpired is returned by parseToken when a JWT is past its
	// expiry or carries no expiry at all.
	ErrTokenExpired = errors.New("token expired")

	// ErrInvalidIssuer is returned when the issuer claim differs from the
	// configured issuer.
	ErrInvalidIssuer = errors.New("invalid issuer")

	// ErrInvalidAudience is returned when the audience claim does not
	// contain the configured audience.
	ErrInvalidAudience = errors.New("invalid audience")

	// ErrInvalidKeyID is returned when the key ID in the header or claims
	// is inconsistent with the key rotation configuration.
	ErrInvalidKeyID = errors.New("invalid key ID")

	// ErrRefreshTokenExpired is returned by RefreshAccessToken when the
	// stored refresh token is past its expiry.
	ErrRefreshTokenExpired = errors.New("refresh token expired")

	// ErrRefreshTokenInvalid is returned by RefreshAccessToken when the
	// refresh token or its owning user cannot be found.
	ErrRefreshTokenInvalid = errors.New("refresh token invalid")
)
|
|
|
|
type TokenClaims struct {
|
|
UserID uint `json:"sub"`
|
|
Username string `json:"username"`
|
|
SessionVersion uint `json:"session_version"`
|
|
TokenType string `json:"type"`
|
|
KeyID string `json:"kid,omitempty"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
type JWTService struct {
|
|
config *config.JWTConfig
|
|
userRepo UserRepository
|
|
refreshRepo repositories.RefreshTokenRepositoryInterface
|
|
}
|
|
|
|
// verificationKey wraps the raw bytes of one secret that parseToken may try
// when checking a token signature.
type verificationKey struct {
	// key is the secret handed to the JWT library as the verification key.
	key []byte
}
|
|
|
|
type UserRepository interface {
|
|
GetByID(id uint) (*database.User, error)
|
|
GetByUsername(username string) (*database.User, error)
|
|
Update(user *database.User) error
|
|
}
|
|
|
|
func NewJWTService(cfg *config.JWTConfig, userRepo UserRepository, refreshRepo repositories.RefreshTokenRepositoryInterface) *JWTService {
|
|
return &JWTService{
|
|
config: cfg,
|
|
userRepo: userRepo,
|
|
refreshRepo: refreshRepo,
|
|
}
|
|
}
|
|
|
|
func (j *JWTService) GenerateAccessToken(user *database.User) (string, error) {
|
|
return j.generateToken(user, TokenTypeAccess, time.Duration(j.config.Expiration)*time.Hour)
|
|
}
|
|
|
|
func (j *JWTService) GenerateRefreshToken(user *database.User) (string, error) {
|
|
if user == nil {
|
|
return "", ErrInvalidCredentials
|
|
}
|
|
|
|
tokenBytes := make([]byte, 32)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return "", fmt.Errorf("generate refresh token: %w", err)
|
|
}
|
|
|
|
tokenString := hex.EncodeToString(tokenBytes)
|
|
tokenHash := j.hashToken(tokenString)
|
|
|
|
refreshToken := &database.RefreshToken{
|
|
UserID: user.ID,
|
|
TokenHash: tokenHash,
|
|
ExpiresAt: time.Now().Add(time.Duration(j.config.RefreshExpiration) * time.Hour),
|
|
}
|
|
|
|
if err := j.refreshRepo.Create(refreshToken); err != nil {
|
|
return "", fmt.Errorf("store refresh token: %w", err)
|
|
}
|
|
|
|
return tokenString, nil
|
|
}
|
|
|
|
func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
|
|
claims, err := j.parseToken(tokenString)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if claims.TokenType != TokenTypeAccess {
|
|
return 0, ErrInvalidTokenType
|
|
}
|
|
|
|
user, err := j.userRepo.GetByID(claims.UserID)
|
|
if err != nil {
|
|
if IsRecordNotFound(err) {
|
|
return 0, ErrInvalidToken
|
|
}
|
|
return 0, fmt.Errorf("lookup user: %w", err)
|
|
}
|
|
|
|
if user.Locked {
|
|
return 0, ErrAccountLocked
|
|
}
|
|
|
|
if user.SessionVersion != claims.SessionVersion {
|
|
return 0, ErrInvalidToken
|
|
}
|
|
|
|
return claims.UserID, nil
|
|
}
|
|
|
|
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) {
|
|
|
|
tokenHash := j.hashToken(refreshTokenString)
|
|
|
|
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return "", ErrRefreshTokenInvalid
|
|
}
|
|
return "", fmt.Errorf("lookup refresh token: %w", err)
|
|
}
|
|
|
|
if time.Now().After(refreshToken.ExpiresAt) {
|
|
|
|
j.refreshRepo.DeleteByID(refreshToken.ID)
|
|
return "", ErrRefreshTokenExpired
|
|
}
|
|
|
|
user, err := j.userRepo.GetByID(refreshToken.UserID)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
j.refreshRepo.DeleteByID(refreshToken.ID)
|
|
return "", ErrRefreshTokenInvalid
|
|
}
|
|
return "", fmt.Errorf("lookup user: %w", err)
|
|
}
|
|
|
|
if user.Locked {
|
|
|
|
j.refreshRepo.DeleteByID(refreshToken.ID)
|
|
return "", ErrAccountLocked
|
|
}
|
|
|
|
accessToken, err := j.GenerateAccessToken(user)
|
|
if err != nil {
|
|
return "", fmt.Errorf("generate access token: %w", err)
|
|
}
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
|
|
tokenHash := j.hashToken(refreshTokenString)
|
|
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("lookup refresh token: %w", err)
|
|
}
|
|
|
|
return j.refreshRepo.DeleteByID(refreshToken.ID)
|
|
}
|
|
|
|
func (j *JWTService) RevokeAllRefreshTokens(userID uint) error {
|
|
return j.refreshRepo.DeleteByUserID(userID)
|
|
}
|
|
|
|
func (j *JWTService) CleanupExpiredTokens() error {
|
|
return j.refreshRepo.DeleteExpired()
|
|
}
|
|
|
|
func (j *JWTService) generateToken(user *database.User, tokenType string, expiration time.Duration) (string, error) {
|
|
if user == nil {
|
|
return "", ErrInvalidCredentials
|
|
}
|
|
|
|
jtiBytes := make([]byte, 16)
|
|
if _, err := rand.Read(jtiBytes); err != nil {
|
|
return "", fmt.Errorf("generate token ID: %w", err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := TokenClaims{
|
|
UserID: user.ID,
|
|
Username: user.Username,
|
|
SessionVersion: user.SessionVersion,
|
|
TokenType: tokenType,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ID: hex.EncodeToString(jtiBytes),
|
|
Issuer: j.config.Issuer,
|
|
Audience: []string{j.config.Audience},
|
|
Subject: fmt.Sprint(user.ID),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
|
},
|
|
}
|
|
|
|
if j.config.KeyRotation.Enabled {
|
|
claims.KeyID = j.config.KeyRotation.KeyID
|
|
}
|
|
|
|
var signingMethod jwt.SigningMethod
|
|
var key any
|
|
|
|
if j.config.KeyRotation.Enabled {
|
|
signingMethod = jwt.SigningMethodHS256
|
|
key = []byte(j.config.KeyRotation.CurrentKey)
|
|
} else {
|
|
signingMethod = jwt.SigningMethodHS256
|
|
key = []byte(j.config.Secret)
|
|
}
|
|
|
|
token := jwt.NewWithClaims(signingMethod, claims)
|
|
|
|
if j.config.KeyRotation.Enabled {
|
|
token.Header["kid"] = j.config.KeyRotation.KeyID
|
|
}
|
|
|
|
return token.SignedString(key)
|
|
}
|
|
|
|
func (j *JWTService) parseToken(tokenString string) (*TokenClaims, error) {
|
|
if TrimString(tokenString) == "" {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
parser := jwt.NewParser()
|
|
unverified, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{})
|
|
if err != nil {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
headerKid, _ := unverified.Header["kid"].(string)
|
|
|
|
if j.config.KeyRotation.Enabled {
|
|
if headerKid != "" && headerKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
|
|
return nil, ErrInvalidKeyID
|
|
}
|
|
} else if headerKid != "" {
|
|
return nil, ErrInvalidKeyID
|
|
}
|
|
|
|
keys := j.verificationKeys()
|
|
if len(keys) == 0 {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
var lastErr error
|
|
for _, candidate := range keys {
|
|
claims := &TokenClaims{}
|
|
token, err := jwt.ParseWithClaims(tokenString, claims, func(t *jwt.Token) (any, error) {
|
|
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
|
}
|
|
return candidate.key, nil
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
return nil, ErrTokenExpired
|
|
}
|
|
lastErr = ErrInvalidToken
|
|
continue
|
|
}
|
|
|
|
if claims.ExpiresAt == nil || time.Until(claims.ExpiresAt.Time) < 0 {
|
|
return nil, ErrTokenExpired
|
|
}
|
|
|
|
if !token.Valid {
|
|
lastErr = ErrInvalidToken
|
|
continue
|
|
}
|
|
|
|
if err := j.validateTokenMetadata(token, claims, headerKid); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
if lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
func (j *JWTService) verificationKeys() []verificationKey {
|
|
if j.config == nil {
|
|
return nil
|
|
}
|
|
|
|
if j.config.KeyRotation.Enabled {
|
|
keys := []verificationKey{{key: []byte(j.config.KeyRotation.CurrentKey)}}
|
|
if j.config.KeyRotation.PreviousKey != "" {
|
|
keys = append(keys, verificationKey{key: []byte(j.config.KeyRotation.PreviousKey)})
|
|
}
|
|
return keys
|
|
}
|
|
|
|
return []verificationKey{{key: []byte(j.config.Secret)}}
|
|
}
|
|
|
|
func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims, headerKid string) error {
|
|
actualKid, _ := token.Header["kid"].(string)
|
|
if actualKid == "" {
|
|
actualKid = headerKid
|
|
}
|
|
|
|
if j.config.KeyRotation.Enabled {
|
|
if actualKid == "" {
|
|
if claims.KeyID != "" {
|
|
return ErrInvalidKeyID
|
|
}
|
|
} else {
|
|
if claims.KeyID == "" || claims.KeyID != actualKid {
|
|
return ErrInvalidKeyID
|
|
}
|
|
}
|
|
|
|
if actualKid != "" && actualKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
|
|
return ErrInvalidKeyID
|
|
}
|
|
} else {
|
|
if actualKid != "" || claims.KeyID != "" {
|
|
return ErrInvalidKeyID
|
|
}
|
|
}
|
|
|
|
if claims.Issuer != j.config.Issuer {
|
|
return ErrInvalidIssuer
|
|
}
|
|
|
|
if !slices.Contains(claims.Audience, j.config.Audience) {
|
|
return ErrInvalidAudience
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (j *JWTService) hashToken(token string) string {
|
|
hash := sha256.Sum256([]byte(token))
|
|
return hex.EncodeToString(hash[:])
|
|
}
|