To gitea and beyond, let's go(-yco)
This commit is contained in:
360
internal/services/jwt_service.go
Normal file
360
internal/services/jwt_service.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenTypeAccess = "access"
|
||||
TokenTypeRefresh = "refresh"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidTokenType = errors.New("invalid token type")
|
||||
ErrTokenExpired = errors.New("token expired")
|
||||
ErrInvalidIssuer = errors.New("invalid issuer")
|
||||
ErrInvalidAudience = errors.New("invalid audience")
|
||||
ErrInvalidKeyID = errors.New("invalid key ID")
|
||||
ErrRefreshTokenExpired = errors.New("refresh token expired")
|
||||
ErrRefreshTokenInvalid = errors.New("refresh token invalid")
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
UserID uint `json:"sub"`
|
||||
Username string `json:"username"`
|
||||
SessionVersion uint `json:"session_version"`
|
||||
TokenType string `json:"type"`
|
||||
KeyID string `json:"kid,omitempty"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTService struct {
|
||||
config *config.JWTConfig
|
||||
userRepo UserRepository
|
||||
refreshRepo repositories.RefreshTokenRepositoryInterface
|
||||
}
|
||||
|
||||
type verificationKey struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
type UserRepository interface {
|
||||
GetByID(id uint) (*database.User, error)
|
||||
GetByUsername(username string) (*database.User, error)
|
||||
Update(user *database.User) error
|
||||
}
|
||||
|
||||
func NewJWTService(cfg *config.JWTConfig, userRepo UserRepository, refreshRepo repositories.RefreshTokenRepositoryInterface) *JWTService {
|
||||
return &JWTService{
|
||||
config: cfg,
|
||||
userRepo: userRepo,
|
||||
refreshRepo: refreshRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (j *JWTService) GenerateAccessToken(user *database.User) (string, error) {
|
||||
return j.generateToken(user, TokenTypeAccess, time.Duration(j.config.Expiration)*time.Hour)
|
||||
}
|
||||
|
||||
func (j *JWTService) GenerateRefreshToken(user *database.User) (string, error) {
|
||||
if user == nil {
|
||||
return "", ErrInvalidCredentials
|
||||
}
|
||||
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", fmt.Errorf("generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
tokenString := hex.EncodeToString(tokenBytes)
|
||||
tokenHash := j.hashToken(tokenString)
|
||||
|
||||
refreshToken := &database.RefreshToken{
|
||||
UserID: user.ID,
|
||||
TokenHash: tokenHash,
|
||||
ExpiresAt: time.Now().Add(time.Duration(j.config.RefreshExpiration) * time.Hour),
|
||||
}
|
||||
|
||||
if err := j.refreshRepo.Create(refreshToken); err != nil {
|
||||
return "", fmt.Errorf("store refresh token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
|
||||
claims, err := j.parseToken(tokenString)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if claims.TokenType != TokenTypeAccess {
|
||||
return 0, ErrInvalidTokenType
|
||||
}
|
||||
|
||||
user, err := j.userRepo.GetByID(claims.UserID)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
return 0, fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
if user.Locked {
|
||||
return 0, ErrAccountLocked
|
||||
}
|
||||
|
||||
if user.SessionVersion != claims.SessionVersion {
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
|
||||
return claims.UserID, nil
|
||||
}
|
||||
|
||||
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) {
|
||||
|
||||
tokenHash := j.hashToken(refreshTokenString)
|
||||
|
||||
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrRefreshTokenInvalid
|
||||
}
|
||||
return "", fmt.Errorf("lookup refresh token: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(refreshToken.ExpiresAt) {
|
||||
|
||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
return "", ErrRefreshTokenExpired
|
||||
}
|
||||
|
||||
user, err := j.userRepo.GetByID(refreshToken.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
return "", ErrRefreshTokenInvalid
|
||||
}
|
||||
return "", fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
if user.Locked {
|
||||
|
||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
return "", ErrAccountLocked
|
||||
}
|
||||
|
||||
accessToken, err := j.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate access token: %w", err)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
|
||||
tokenHash := j.hashToken(refreshTokenString)
|
||||
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("lookup refresh token: %w", err)
|
||||
}
|
||||
|
||||
return j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
}
|
||||
|
||||
func (j *JWTService) RevokeAllRefreshTokens(userID uint) error {
|
||||
return j.refreshRepo.DeleteByUserID(userID)
|
||||
}
|
||||
|
||||
func (j *JWTService) CleanupExpiredTokens() error {
|
||||
return j.refreshRepo.DeleteExpired()
|
||||
}
|
||||
|
||||
func (j *JWTService) generateToken(user *database.User, tokenType string, expiration time.Duration) (string, error) {
|
||||
if user == nil {
|
||||
return "", ErrInvalidCredentials
|
||||
}
|
||||
|
||||
jtiBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(jtiBytes); err != nil {
|
||||
return "", fmt.Errorf("generate token ID: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
claims := TokenClaims{
|
||||
UserID: user.ID,
|
||||
Username: user.Username,
|
||||
SessionVersion: user.SessionVersion,
|
||||
TokenType: tokenType,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: hex.EncodeToString(jtiBytes),
|
||||
Issuer: j.config.Issuer,
|
||||
Audience: []string{j.config.Audience},
|
||||
Subject: fmt.Sprint(user.ID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
||||
},
|
||||
}
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
claims.KeyID = j.config.KeyRotation.KeyID
|
||||
}
|
||||
|
||||
var signingMethod jwt.SigningMethod
|
||||
var key any
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
signingMethod = jwt.SigningMethodHS256
|
||||
key = []byte(j.config.KeyRotation.CurrentKey)
|
||||
} else {
|
||||
signingMethod = jwt.SigningMethodHS256
|
||||
key = []byte(j.config.Secret)
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(signingMethod, claims)
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
token.Header["kid"] = j.config.KeyRotation.KeyID
|
||||
}
|
||||
|
||||
return token.SignedString(key)
|
||||
}
|
||||
|
||||
func (j *JWTService) parseToken(tokenString string) (*TokenClaims, error) {
|
||||
if TrimString(tokenString) == "" {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
parser := jwt.NewParser()
|
||||
unverified, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
headerKid, _ := unverified.Header["kid"].(string)
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
if headerKid != "" && headerKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
|
||||
return nil, ErrInvalidKeyID
|
||||
}
|
||||
} else if headerKid != "" {
|
||||
return nil, ErrInvalidKeyID
|
||||
}
|
||||
|
||||
keys := j.verificationKeys()
|
||||
if len(keys) == 0 {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, candidate := range keys {
|
||||
claims := &TokenClaims{}
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return candidate.key, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
lastErr = ErrInvalidToken
|
||||
continue
|
||||
}
|
||||
|
||||
if claims.ExpiresAt == nil || time.Until(claims.ExpiresAt.Time) < 0 {
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
lastErr = ErrInvalidToken
|
||||
continue
|
||||
}
|
||||
|
||||
if err := j.validateTokenMetadata(token, claims, headerKid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
func (j *JWTService) verificationKeys() []verificationKey {
|
||||
if j.config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
keys := []verificationKey{{key: []byte(j.config.KeyRotation.CurrentKey)}}
|
||||
if j.config.KeyRotation.PreviousKey != "" {
|
||||
keys = append(keys, verificationKey{key: []byte(j.config.KeyRotation.PreviousKey)})
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
return []verificationKey{{key: []byte(j.config.Secret)}}
|
||||
}
|
||||
|
||||
func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims, headerKid string) error {
|
||||
actualKid, _ := token.Header["kid"].(string)
|
||||
if actualKid == "" {
|
||||
actualKid = headerKid
|
||||
}
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
if actualKid == "" {
|
||||
if claims.KeyID != "" {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
} else {
|
||||
if claims.KeyID == "" || claims.KeyID != actualKid {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
}
|
||||
|
||||
if actualKid != "" && actualKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
} else {
|
||||
if actualKid != "" || claims.KeyID != "" {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
}
|
||||
|
||||
if claims.Issuer != j.config.Issuer {
|
||||
return ErrInvalidIssuer
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, j.config.Audience) {
|
||||
return ErrInvalidAudience
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JWTService) hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
Reference in New Issue
Block a user