To gitea and beyond, let's go(-yco)

This commit is contained in:
2025-11-10 19:12:09 +01:00
parent 8f6133392d
commit 71a031342b
245 changed files with 83994 additions and 0 deletions

View File

@@ -0,0 +1,360 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"slices"
"time"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
)
const (
TokenTypeAccess = "access"
TokenTypeRefresh = "refresh"
)
var (
ErrInvalidTokenType = errors.New("invalid token type")
ErrTokenExpired = errors.New("token expired")
ErrInvalidIssuer = errors.New("invalid issuer")
ErrInvalidAudience = errors.New("invalid audience")
ErrInvalidKeyID = errors.New("invalid key ID")
ErrRefreshTokenExpired = errors.New("refresh token expired")
ErrRefreshTokenInvalid = errors.New("refresh token invalid")
)
type TokenClaims struct {
UserID uint `json:"sub"`
Username string `json:"username"`
SessionVersion uint `json:"session_version"`
TokenType string `json:"type"`
KeyID string `json:"kid,omitempty"`
jwt.RegisteredClaims
}
type JWTService struct {
config *config.JWTConfig
userRepo UserRepository
refreshRepo repositories.RefreshTokenRepositoryInterface
}
type verificationKey struct {
key []byte
}
type UserRepository interface {
GetByID(id uint) (*database.User, error)
GetByUsername(username string) (*database.User, error)
Update(user *database.User) error
}
func NewJWTService(cfg *config.JWTConfig, userRepo UserRepository, refreshRepo repositories.RefreshTokenRepositoryInterface) *JWTService {
return &JWTService{
config: cfg,
userRepo: userRepo,
refreshRepo: refreshRepo,
}
}
func (j *JWTService) GenerateAccessToken(user *database.User) (string, error) {
return j.generateToken(user, TokenTypeAccess, time.Duration(j.config.Expiration)*time.Hour)
}
func (j *JWTService) GenerateRefreshToken(user *database.User) (string, error) {
if user == nil {
return "", ErrInvalidCredentials
}
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", fmt.Errorf("generate refresh token: %w", err)
}
tokenString := hex.EncodeToString(tokenBytes)
tokenHash := j.hashToken(tokenString)
refreshToken := &database.RefreshToken{
UserID: user.ID,
TokenHash: tokenHash,
ExpiresAt: time.Now().Add(time.Duration(j.config.RefreshExpiration) * time.Hour),
}
if err := j.refreshRepo.Create(refreshToken); err != nil {
return "", fmt.Errorf("store refresh token: %w", err)
}
return tokenString, nil
}
func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
claims, err := j.parseToken(tokenString)
if err != nil {
return 0, err
}
if claims.TokenType != TokenTypeAccess {
return 0, ErrInvalidTokenType
}
user, err := j.userRepo.GetByID(claims.UserID)
if err != nil {
if IsRecordNotFound(err) {
return 0, ErrInvalidToken
}
return 0, fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
return 0, ErrAccountLocked
}
if user.SessionVersion != claims.SessionVersion {
return 0, ErrInvalidToken
}
return claims.UserID, nil
}
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) {
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup refresh token: %w", err)
}
if time.Now().After(refreshToken.ExpiresAt) {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenExpired
}
user, err := j.userRepo.GetByID(refreshToken.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrAccountLocked
}
accessToken, err := j.GenerateAccessToken(user)
if err != nil {
return "", fmt.Errorf("generate access token: %w", err)
}
return accessToken, nil
}
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
return fmt.Errorf("lookup refresh token: %w", err)
}
return j.refreshRepo.DeleteByID(refreshToken.ID)
}
func (j *JWTService) RevokeAllRefreshTokens(userID uint) error {
return j.refreshRepo.DeleteByUserID(userID)
}
func (j *JWTService) CleanupExpiredTokens() error {
return j.refreshRepo.DeleteExpired()
}
func (j *JWTService) generateToken(user *database.User, tokenType string, expiration time.Duration) (string, error) {
if user == nil {
return "", ErrInvalidCredentials
}
jtiBytes := make([]byte, 16)
if _, err := rand.Read(jtiBytes); err != nil {
return "", fmt.Errorf("generate token ID: %w", err)
}
now := time.Now()
claims := TokenClaims{
UserID: user.ID,
Username: user.Username,
SessionVersion: user.SessionVersion,
TokenType: tokenType,
RegisteredClaims: jwt.RegisteredClaims{
ID: hex.EncodeToString(jtiBytes),
Issuer: j.config.Issuer,
Audience: []string{j.config.Audience},
Subject: fmt.Sprint(user.ID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
},
}
if j.config.KeyRotation.Enabled {
claims.KeyID = j.config.KeyRotation.KeyID
}
var signingMethod jwt.SigningMethod
var key any
if j.config.KeyRotation.Enabled {
signingMethod = jwt.SigningMethodHS256
key = []byte(j.config.KeyRotation.CurrentKey)
} else {
signingMethod = jwt.SigningMethodHS256
key = []byte(j.config.Secret)
}
token := jwt.NewWithClaims(signingMethod, claims)
if j.config.KeyRotation.Enabled {
token.Header["kid"] = j.config.KeyRotation.KeyID
}
return token.SignedString(key)
}
func (j *JWTService) parseToken(tokenString string) (*TokenClaims, error) {
if TrimString(tokenString) == "" {
return nil, ErrInvalidToken
}
parser := jwt.NewParser()
unverified, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return nil, ErrInvalidToken
}
headerKid, _ := unverified.Header["kid"].(string)
if j.config.KeyRotation.Enabled {
if headerKid != "" && headerKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
return nil, ErrInvalidKeyID
}
} else if headerKid != "" {
return nil, ErrInvalidKeyID
}
keys := j.verificationKeys()
if len(keys) == 0 {
return nil, ErrInvalidToken
}
var lastErr error
for _, candidate := range keys {
claims := &TokenClaims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return candidate.key, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrTokenExpired
}
lastErr = ErrInvalidToken
continue
}
if claims.ExpiresAt == nil || time.Until(claims.ExpiresAt.Time) < 0 {
return nil, ErrTokenExpired
}
if !token.Valid {
lastErr = ErrInvalidToken
continue
}
if err := j.validateTokenMetadata(token, claims, headerKid); err != nil {
return nil, err
}
return claims, nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, ErrInvalidToken
}
func (j *JWTService) verificationKeys() []verificationKey {
if j.config == nil {
return nil
}
if j.config.KeyRotation.Enabled {
keys := []verificationKey{{key: []byte(j.config.KeyRotation.CurrentKey)}}
if j.config.KeyRotation.PreviousKey != "" {
keys = append(keys, verificationKey{key: []byte(j.config.KeyRotation.PreviousKey)})
}
return keys
}
return []verificationKey{{key: []byte(j.config.Secret)}}
}
func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims, headerKid string) error {
actualKid, _ := token.Header["kid"].(string)
if actualKid == "" {
actualKid = headerKid
}
if j.config.KeyRotation.Enabled {
if actualKid == "" {
if claims.KeyID != "" {
return ErrInvalidKeyID
}
} else {
if claims.KeyID == "" || claims.KeyID != actualKid {
return ErrInvalidKeyID
}
}
if actualKid != "" && actualKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
return ErrInvalidKeyID
}
} else {
if actualKid != "" || claims.KeyID != "" {
return ErrInvalidKeyID
}
}
if claims.Issuer != j.config.Issuer {
return ErrInvalidIssuer
}
if !slices.Contains(claims.Audience, j.config.Audience) {
return ErrInvalidAudience
}
return nil
}
func (j *JWTService) hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}