967 lines
27 KiB
Go
967 lines
27 KiB
Go
package services
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"gorm.io/gorm"
|
|
"goyco/internal/config"
|
|
"goyco/internal/database"
|
|
)
|
|
|
|
type jwtMockUserRepo struct {
|
|
users map[uint]*database.User
|
|
}
|
|
|
|
func (m *jwtMockUserRepo) GetByID(id uint) (*database.User, error) {
|
|
if user, exists := m.users[id]; exists {
|
|
return user, nil
|
|
}
|
|
return nil, gorm.ErrRecordNotFound
|
|
}
|
|
|
|
func (m *jwtMockUserRepo) GetByUsername(username string) (*database.User, error) {
|
|
for _, user := range m.users {
|
|
if user.Username == username {
|
|
return user, nil
|
|
}
|
|
}
|
|
return nil, gorm.ErrRecordNotFound
|
|
}
|
|
|
|
func (m *jwtMockUserRepo) Update(user *database.User) error {
|
|
if _, exists := m.users[user.ID]; !exists {
|
|
return gorm.ErrRecordNotFound
|
|
}
|
|
m.users[user.ID] = user
|
|
return nil
|
|
}
|
|
|
|
type jwtMockRefreshTokenRepo struct {
|
|
tokens map[string]*database.RefreshToken
|
|
nextID uint
|
|
}
|
|
|
|
func (m *jwtMockRefreshTokenRepo) Create(token *database.RefreshToken) error {
|
|
if m.tokens == nil {
|
|
m.tokens = make(map[string]*database.RefreshToken)
|
|
}
|
|
m.nextID++
|
|
token.ID = m.nextID
|
|
m.tokens[token.TokenHash] = token
|
|
return nil
|
|
}
|
|
|
|
func (m *jwtMockRefreshTokenRepo) GetByTokenHash(tokenHash string) (*database.RefreshToken, error) {
|
|
if token, exists := m.tokens[tokenHash]; exists {
|
|
return token, nil
|
|
}
|
|
return nil, gorm.ErrRecordNotFound
|
|
}
|
|
|
|
func (m *jwtMockRefreshTokenRepo) DeleteByUserID(userID uint) error {
|
|
for hash, token := range m.tokens {
|
|
if token.UserID == userID {
|
|
delete(m.tokens, hash)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *jwtMockRefreshTokenRepo) DeleteExpired() error {
|
|
now := time.Now()
|
|
for hash, token := range m.tokens {
|
|
if token.ExpiresAt.Before(now) {
|
|
delete(m.tokens, hash)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *jwtMockRefreshTokenRepo) DeleteByID(id uint) error {
|
|
for hash, token := range m.tokens {
|
|
if token.ID == id {
|
|
delete(m.tokens, hash)
|
|
return nil
|
|
}
|
|
}
|
|
return gorm.ErrRecordNotFound
|
|
}
|
|
|
|
func (m *jwtMockRefreshTokenRepo) GetByUserID(userID uint) ([]database.RefreshToken, error) {
|
|
var tokens []database.RefreshToken
|
|
for _, token := range m.tokens {
|
|
if token.UserID == userID {
|
|
tokens = append(tokens, *token)
|
|
}
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
func (m *jwtMockRefreshTokenRepo) CountByUserID(userID uint) (int64, error) {
|
|
var count int64
|
|
for _, token := range m.tokens {
|
|
if token.UserID == userID {
|
|
count++
|
|
}
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func createTestJWTService() (*JWTService, *jwtMockUserRepo, *jwtMockRefreshTokenRepo) {
|
|
cfg := &config.JWTConfig{
|
|
Secret: "test-secret-key-that-is-long-enough-for-security",
|
|
Expiration: 1,
|
|
RefreshExpiration: 24,
|
|
Issuer: "test-issuer",
|
|
Audience: "test-audience",
|
|
KeyRotation: config.KeyRotationConfig{
|
|
Enabled: false,
|
|
CurrentKey: "",
|
|
PreviousKey: "",
|
|
KeyID: "",
|
|
},
|
|
}
|
|
|
|
userRepo := &jwtMockUserRepo{
|
|
users: make(map[uint]*database.User),
|
|
}
|
|
refreshRepo := &jwtMockRefreshTokenRepo{
|
|
tokens: make(map[string]*database.RefreshToken),
|
|
}
|
|
|
|
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
|
|
return jwtService, userRepo, refreshRepo
|
|
}
|
|
|
|
func createTestUser() *database.User {
|
|
return &database.User{
|
|
ID: 1,
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
Password: "hashedpassword",
|
|
EmailVerified: true,
|
|
Locked: false,
|
|
SessionVersion: 1,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
}
|
|
|
|
func TestJWTService_GenerateAccessToken(t *testing.T) {
|
|
jwtService, userRepo, _ := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Successful_Generation", func(t *testing.T) {
|
|
token, err := jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful token generation, got error: %v", err)
|
|
}
|
|
|
|
if token == "" {
|
|
t.Error("Expected non-empty token")
|
|
}
|
|
|
|
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
|
return []byte(jwtService.config.Secret), nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse generated token: %v", err)
|
|
}
|
|
|
|
if !parsedToken.Valid {
|
|
t.Error("Generated token should be valid")
|
|
}
|
|
|
|
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
|
|
if claims["sub"] != float64(user.ID) {
|
|
t.Errorf("Expected subject %d, got %v", user.ID, claims["sub"])
|
|
}
|
|
if claims["username"] != user.Username {
|
|
t.Errorf("Expected username %s, got %v", user.Username, claims["username"])
|
|
}
|
|
if claims["session_version"] != float64(user.SessionVersion) {
|
|
t.Errorf("Expected session_version %d, got %v", user.SessionVersion, claims["session_version"])
|
|
}
|
|
if claims["type"] != TokenTypeAccess {
|
|
t.Errorf("Expected type %s, got %v", TokenTypeAccess, claims["type"])
|
|
}
|
|
if claims["iss"] != jwtService.config.Issuer {
|
|
t.Errorf("Expected issuer %s, got %v", jwtService.config.Issuer, claims["iss"])
|
|
}
|
|
if aud, ok := claims["aud"].([]any); !ok || len(aud) != 1 || aud[0] != jwtService.config.Audience {
|
|
t.Errorf("Expected audience [%s], got %v", jwtService.config.Audience, claims["aud"])
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("Nil_User", func(t *testing.T) {
|
|
_, err := jwtService.GenerateAccessToken(nil)
|
|
if err == nil {
|
|
t.Error("Expected error for nil user")
|
|
}
|
|
if !errors.Is(err, ErrInvalidCredentials) {
|
|
t.Errorf("Expected ErrInvalidCredentials, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_GenerateRefreshToken(t *testing.T) {
|
|
jwtService, userRepo, refreshRepo := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Successful_Generation", func(t *testing.T) {
|
|
token, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful refresh token generation, got error: %v", err)
|
|
}
|
|
|
|
if token == "" {
|
|
t.Error("Expected non-empty refresh token")
|
|
}
|
|
|
|
tokenHash := jwtService.hashToken(token)
|
|
storedToken, err := refreshRepo.GetByTokenHash(tokenHash)
|
|
if err != nil {
|
|
t.Fatalf("Expected refresh token to be stored in database: %v", err)
|
|
}
|
|
|
|
if storedToken.UserID != user.ID {
|
|
t.Errorf("Expected user ID %d, got %d", user.ID, storedToken.UserID)
|
|
}
|
|
|
|
expectedExpiry := time.Now().Add(time.Duration(jwtService.config.RefreshExpiration) * time.Hour)
|
|
if storedToken.ExpiresAt.Before(expectedExpiry.Add(-time.Minute)) || storedToken.ExpiresAt.After(expectedExpiry.Add(time.Minute)) {
|
|
t.Errorf("Expected expiry around %v, got %v", expectedExpiry, storedToken.ExpiresAt)
|
|
}
|
|
})
|
|
|
|
t.Run("Nil_User", func(t *testing.T) {
|
|
_, err := jwtService.GenerateRefreshToken(nil)
|
|
if err == nil {
|
|
t.Error("Expected error for nil user")
|
|
}
|
|
if !errors.Is(err, ErrInvalidCredentials) {
|
|
t.Errorf("Expected ErrInvalidCredentials, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_VerifyAccessToken(t *testing.T) {
|
|
jwtService, userRepo, _ := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Valid_Token", func(t *testing.T) {
|
|
token, err := jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate token: %v", err)
|
|
}
|
|
|
|
userID, err := jwtService.VerifyAccessToken(token)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful token verification, got error: %v", err)
|
|
}
|
|
|
|
if userID != user.ID {
|
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
|
}
|
|
})
|
|
|
|
t.Run("Invalid_Token", func(t *testing.T) {
|
|
_, err := jwtService.VerifyAccessToken("invalid-token")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid token")
|
|
}
|
|
})
|
|
|
|
t.Run("Empty_Token", func(t *testing.T) {
|
|
_, err := jwtService.VerifyAccessToken("")
|
|
if err == nil {
|
|
t.Error("Expected error for empty token")
|
|
}
|
|
if !errors.Is(err, ErrInvalidToken) {
|
|
t.Errorf("Expected ErrInvalidToken, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("User_Not_Found", func(t *testing.T) {
|
|
token, err := jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate token: %v", err)
|
|
}
|
|
|
|
delete(userRepo.users, user.ID)
|
|
|
|
_, err = jwtService.VerifyAccessToken(token)
|
|
if err == nil {
|
|
t.Error("Expected error for non-existent user")
|
|
}
|
|
if !errors.Is(err, ErrInvalidToken) {
|
|
t.Errorf("Expected ErrInvalidToken, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Locked_User", func(t *testing.T) {
|
|
user.Locked = true
|
|
userRepo.users[user.ID] = user
|
|
|
|
token, err := jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.VerifyAccessToken(token)
|
|
if err == nil {
|
|
t.Error("Expected error for locked user")
|
|
}
|
|
if !errors.Is(err, ErrAccountLocked) {
|
|
t.Errorf("Expected ErrAccountLocked, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Session_Version_Mismatch", func(t *testing.T) {
|
|
user.Locked = false
|
|
user.SessionVersion = 2
|
|
userRepo.users[user.ID] = user
|
|
|
|
oldUser := *user
|
|
oldUser.SessionVersion = 1
|
|
token, err := jwtService.GenerateAccessToken(&oldUser)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.VerifyAccessToken(token)
|
|
if err == nil {
|
|
t.Error("Expected error for session version mismatch")
|
|
}
|
|
if !errors.Is(err, ErrInvalidToken) {
|
|
t.Errorf("Expected ErrInvalidToken, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_RefreshAccessToken(t *testing.T) {
|
|
jwtService, userRepo, refreshRepo := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Successful_Refresh", func(t *testing.T) {
|
|
|
|
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
|
}
|
|
|
|
accessToken, err := jwtService.RefreshAccessToken(refreshToken)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful token refresh, got error: %v", err)
|
|
}
|
|
|
|
if accessToken == "" {
|
|
t.Error("Expected non-empty access token")
|
|
}
|
|
|
|
userID, err := jwtService.VerifyAccessToken(accessToken)
|
|
if err != nil {
|
|
t.Fatalf("Expected valid access token, got error: %v", err)
|
|
}
|
|
|
|
if userID != user.ID {
|
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
|
}
|
|
})
|
|
|
|
t.Run("Invalid_Refresh_Token", func(t *testing.T) {
|
|
_, err := jwtService.RefreshAccessToken("invalid-refresh-token")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid refresh token")
|
|
}
|
|
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
|
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Expired_Refresh_Token", func(t *testing.T) {
|
|
|
|
refreshToken := &database.RefreshToken{
|
|
UserID: user.ID,
|
|
TokenHash: "expired-token-hash",
|
|
ExpiresAt: time.Now().Add(-time.Hour),
|
|
}
|
|
refreshRepo.tokens["expired-token-hash"] = refreshToken
|
|
|
|
testToken := "test-expired-token"
|
|
tokenHash := jwtService.hashToken(testToken)
|
|
refreshToken.TokenHash = tokenHash
|
|
refreshRepo.tokens[tokenHash] = refreshToken
|
|
|
|
_, err := jwtService.RefreshAccessToken(testToken)
|
|
if err == nil {
|
|
t.Error("Expected error for expired refresh token")
|
|
}
|
|
if !errors.Is(err, ErrRefreshTokenExpired) {
|
|
t.Errorf("Expected ErrRefreshTokenExpired, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("User_Not_Found", func(t *testing.T) {
|
|
|
|
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
|
}
|
|
|
|
delete(userRepo.users, user.ID)
|
|
|
|
_, err = jwtService.RefreshAccessToken(refreshToken)
|
|
if err == nil {
|
|
t.Error("Expected error for non-existent user")
|
|
}
|
|
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
|
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Locked_User", func(t *testing.T) {
|
|
user.Locked = true
|
|
userRepo.users[user.ID] = user
|
|
|
|
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.RefreshAccessToken(refreshToken)
|
|
if err == nil {
|
|
t.Error("Expected error for locked user")
|
|
}
|
|
if !errors.Is(err, ErrAccountLocked) {
|
|
t.Errorf("Expected ErrAccountLocked, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_RevokeRefreshToken(t *testing.T) {
|
|
jwtService, userRepo, refreshRepo := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Successful_Revocation", func(t *testing.T) {
|
|
|
|
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
|
}
|
|
|
|
tokenHash := jwtService.hashToken(refreshToken)
|
|
_, err = refreshRepo.GetByTokenHash(tokenHash)
|
|
if err != nil {
|
|
t.Fatalf("Expected refresh token to exist: %v", err)
|
|
}
|
|
|
|
err = jwtService.RevokeRefreshToken(refreshToken)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful token revocation, got error: %v", err)
|
|
}
|
|
|
|
_, err = refreshRepo.GetByTokenHash(tokenHash)
|
|
if err == nil {
|
|
t.Error("Expected refresh token to be removed")
|
|
}
|
|
})
|
|
|
|
t.Run("Non_Existent_Token", func(t *testing.T) {
|
|
err := jwtService.RevokeRefreshToken("non-existent-token")
|
|
if err != nil {
|
|
t.Errorf("Expected no error for non-existent token, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_RevokeAllRefreshTokens(t *testing.T) {
|
|
jwtService, userRepo, refreshRepo := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Successful_Revocation", func(t *testing.T) {
|
|
|
|
_, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate first refresh token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate second refresh token: %v", err)
|
|
}
|
|
|
|
count, err := refreshRepo.CountByUserID(user.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to count tokens: %v", err)
|
|
}
|
|
if count != 2 {
|
|
t.Errorf("Expected 2 tokens, got %d", count)
|
|
}
|
|
|
|
err = jwtService.RevokeAllRefreshTokens(user.ID)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful token revocation, got error: %v", err)
|
|
}
|
|
|
|
count, err = refreshRepo.CountByUserID(user.ID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to count tokens: %v", err)
|
|
}
|
|
if count != 0 {
|
|
t.Errorf("Expected 0 tokens, got %d", count)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_CleanupExpiredTokens(t *testing.T) {
|
|
jwtService, userRepo, refreshRepo := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Successful_Cleanup", func(t *testing.T) {
|
|
|
|
expiredToken := &database.RefreshToken{
|
|
UserID: user.ID,
|
|
TokenHash: "expired-token-hash",
|
|
ExpiresAt: time.Now().Add(-time.Hour),
|
|
}
|
|
refreshRepo.tokens["expired-token-hash"] = expiredToken
|
|
|
|
validToken, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate valid refresh token: %v", err)
|
|
}
|
|
|
|
if len(refreshRepo.tokens) != 2 {
|
|
t.Errorf("Expected 2 tokens, got %d", len(refreshRepo.tokens))
|
|
}
|
|
|
|
err = jwtService.CleanupExpiredTokens()
|
|
if err != nil {
|
|
t.Fatalf("Expected successful cleanup, got error: %v", err)
|
|
}
|
|
|
|
if len(refreshRepo.tokens) != 1 {
|
|
t.Errorf("Expected 1 token after cleanup, got %d", len(refreshRepo.tokens))
|
|
}
|
|
|
|
tokenHash := jwtService.hashToken(validToken)
|
|
_, exists := refreshRepo.tokens[tokenHash]
|
|
if !exists {
|
|
t.Error("Expected valid token to remain after cleanup")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_KeyRotation(t *testing.T) {
|
|
cfg := &config.JWTConfig{
|
|
Secret: "old-secret-key-that-is-long-enough-for-security",
|
|
Expiration: 1,
|
|
RefreshExpiration: 24,
|
|
Issuer: "test-issuer",
|
|
Audience: "test-audience",
|
|
KeyRotation: config.KeyRotationConfig{
|
|
Enabled: true,
|
|
CurrentKey: "current-key-that-is-long-enough-for-security",
|
|
PreviousKey: "previous-key-that-is-long-enough-for-security",
|
|
KeyID: "current-key-id",
|
|
},
|
|
}
|
|
|
|
userRepo := &jwtMockUserRepo{
|
|
users: make(map[uint]*database.User),
|
|
}
|
|
refreshRepo := &jwtMockRefreshTokenRepo{
|
|
tokens: make(map[string]*database.RefreshToken),
|
|
}
|
|
|
|
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Generate_Token_With_Key_Rotation", func(t *testing.T) {
|
|
token, err := jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful token generation with key rotation, got error: %v", err)
|
|
}
|
|
|
|
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
|
return []byte(cfg.KeyRotation.CurrentKey), nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse token with key rotation: %v", err)
|
|
}
|
|
|
|
if !parsedToken.Valid {
|
|
t.Error("Generated token should be valid")
|
|
}
|
|
|
|
if kid, ok := parsedToken.Header["kid"].(string); !ok || kid != cfg.KeyRotation.KeyID {
|
|
t.Errorf("Expected key ID %s, got %v", cfg.KeyRotation.KeyID, parsedToken.Header["kid"])
|
|
}
|
|
})
|
|
|
|
t.Run("Verify_Token_With_Current_Key", func(t *testing.T) {
|
|
token, err := jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate token: %v", err)
|
|
}
|
|
|
|
userID, err := jwtService.VerifyAccessToken(token)
|
|
if err != nil {
|
|
t.Fatalf("Expected successful token verification with current key, got error: %v", err)
|
|
}
|
|
|
|
if userID != user.ID {
|
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
|
}
|
|
})
|
|
|
|
t.Run("Legacy_Token_Without_KID_Remains_Valid", func(t *testing.T) {
|
|
legacyCfg := &config.JWTConfig{
|
|
Secret: "legacy-secret-key-that-is-long-enough-for-security",
|
|
Expiration: 1,
|
|
RefreshExpiration: 24,
|
|
Issuer: "legacy-issuer",
|
|
Audience: "legacy-audience",
|
|
KeyRotation: config.KeyRotationConfig{Enabled: false},
|
|
}
|
|
|
|
legacyUserRepo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
|
|
legacyRefreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
|
|
legacyService := NewJWTService(legacyCfg, legacyUserRepo, legacyRefreshRepo)
|
|
|
|
legacyToken, err := legacyService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate legacy token: %v", err)
|
|
}
|
|
|
|
legacyCfg.KeyRotation.Enabled = true
|
|
legacyCfg.KeyRotation.CurrentKey = "rotated-current-key-that-is-long-enough-for-security"
|
|
legacyCfg.KeyRotation.PreviousKey = legacyCfg.Secret
|
|
legacyCfg.KeyRotation.KeyID = "rotated-key-id"
|
|
|
|
parsedUserID, err := legacyService.VerifyAccessToken(legacyToken)
|
|
if err != nil {
|
|
t.Fatalf("Legacy token should remain valid after enabling rotation: %v", err)
|
|
}
|
|
if parsedUserID != user.ID {
|
|
t.Fatalf("Expected user ID %d, got %d", user.ID, parsedUserID)
|
|
}
|
|
})
|
|
|
|
t.Run("Legacy_Token_With_Previous_KID_Remains_Valid", func(t *testing.T) {
|
|
rotCfg := &config.JWTConfig{
|
|
Secret: "unused-secret-key-that-is-long-enough-for-security",
|
|
Expiration: 1,
|
|
RefreshExpiration: 24,
|
|
Issuer: "rotation-issuer",
|
|
Audience: "rotation-audience",
|
|
KeyRotation: config.KeyRotationConfig{
|
|
Enabled: true,
|
|
CurrentKey: "rotation-key-v1-that-is-long-enough-for-security",
|
|
KeyID: "key-id-v1",
|
|
},
|
|
}
|
|
|
|
rotUserRepo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
|
|
rotRefreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
|
|
rotService := NewJWTService(rotCfg, rotUserRepo, rotRefreshRepo)
|
|
|
|
tokenV1, err := rotService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate v1 token: %v", err)
|
|
}
|
|
|
|
rotCfg.KeyRotation.PreviousKey = rotCfg.KeyRotation.CurrentKey
|
|
rotCfg.KeyRotation.CurrentKey = "rotation-key-v2-that-is-long-enough-for-security"
|
|
rotCfg.KeyRotation.KeyID = "key-id-v2"
|
|
|
|
parsedUserID, err := rotService.VerifyAccessToken(tokenV1)
|
|
if err != nil {
|
|
t.Fatalf("Token signed with previous key should remain valid: %v", err)
|
|
}
|
|
if parsedUserID != user.ID {
|
|
t.Fatalf("Expected user ID %d, got %d", user.ID, parsedUserID)
|
|
}
|
|
})
|
|
|
|
t.Run("Unknown_KID_Is_Rejected", func(t *testing.T) {
|
|
cfg := &config.JWTConfig{
|
|
Secret: "unused-secret-key-that-is-long-enough-for-security",
|
|
Expiration: 1,
|
|
RefreshExpiration: 24,
|
|
Issuer: "issuer",
|
|
Audience: "audience",
|
|
KeyRotation: config.KeyRotationConfig{
|
|
Enabled: true,
|
|
CurrentKey: "current-key-for-unknown-kid-test-that-is-long-enough",
|
|
KeyID: "expected-key-id",
|
|
},
|
|
}
|
|
|
|
repo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
|
|
refreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
|
|
service := NewJWTService(cfg, repo, refreshRepo)
|
|
|
|
claims := TokenClaims{
|
|
UserID: user.ID,
|
|
Username: user.Username,
|
|
SessionVersion: user.SessionVersion,
|
|
TokenType: TokenTypeAccess,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: cfg.Issuer,
|
|
Audience: []string{cfg.Audience},
|
|
Subject: fmt.Sprint(user.ID),
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
token.Header["kid"] = "unexpected-key-id"
|
|
tokenString, err := token.SignedString([]byte(cfg.KeyRotation.CurrentKey))
|
|
if err != nil {
|
|
t.Fatalf("Failed to sign token: %v", err)
|
|
}
|
|
|
|
_, err = service.VerifyAccessToken(tokenString)
|
|
if !errors.Is(err, ErrInvalidKeyID) {
|
|
t.Fatalf("Expected ErrInvalidKeyID, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_ErrorHandling(t *testing.T) {
|
|
jwtService, _, _ := createTestJWTService()
|
|
|
|
t.Run("Invalid_Issuer", func(t *testing.T) {
|
|
|
|
claims := TokenClaims{
|
|
UserID: 1,
|
|
Username: "testuser",
|
|
SessionVersion: 1,
|
|
TokenType: TokenTypeAccess,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: "wrong-issuer",
|
|
Audience: []string{jwtService.config.Audience},
|
|
Subject: "1",
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.VerifyAccessToken(tokenString)
|
|
if err == nil {
|
|
t.Error("Expected error for invalid issuer")
|
|
}
|
|
if !errors.Is(err, ErrInvalidIssuer) {
|
|
t.Errorf("Expected ErrInvalidIssuer, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Invalid_Audience", func(t *testing.T) {
|
|
|
|
claims := TokenClaims{
|
|
UserID: 1,
|
|
Username: "testuser",
|
|
SessionVersion: 1,
|
|
TokenType: TokenTypeAccess,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: jwtService.config.Issuer,
|
|
Audience: []string{"wrong-audience"},
|
|
Subject: "1",
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.VerifyAccessToken(tokenString)
|
|
if err == nil {
|
|
t.Error("Expected error for invalid audience")
|
|
}
|
|
if !errors.Is(err, ErrInvalidAudience) {
|
|
t.Errorf("Expected ErrInvalidAudience, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Expired_Token", func(t *testing.T) {
|
|
|
|
claims := TokenClaims{
|
|
UserID: 1,
|
|
Username: "testuser",
|
|
SessionVersion: 1,
|
|
TokenType: TokenTypeAccess,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: jwtService.config.Issuer,
|
|
Audience: []string{jwtService.config.Audience},
|
|
Subject: "1",
|
|
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.VerifyAccessToken(tokenString)
|
|
if err == nil {
|
|
t.Error("Expected error for expired token")
|
|
}
|
|
if !errors.Is(err, ErrTokenExpired) {
|
|
t.Errorf("Expected ErrTokenExpired, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Subject_Mismatch", func(t *testing.T) {
|
|
claims := TokenClaims{
|
|
UserID: 1,
|
|
Username: "testuser",
|
|
SessionVersion: 1,
|
|
TokenType: TokenTypeAccess,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: jwtService.config.Issuer,
|
|
Audience: []string{jwtService.config.Audience},
|
|
Subject: "999",
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.VerifyAccessToken(tokenString)
|
|
if !errors.Is(err, ErrInvalidToken) {
|
|
t.Fatalf("Expected ErrInvalidToken for subject mismatch, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_HelperFunctions(t *testing.T) {
|
|
jwtService, _, _ := createTestJWTService()
|
|
|
|
t.Run("HashToken", func(t *testing.T) {
|
|
token := "test-token"
|
|
hash1 := jwtService.hashToken(token)
|
|
hash2 := jwtService.hashToken(token)
|
|
|
|
if hash1 != hash2 {
|
|
t.Error("Hash should be deterministic")
|
|
}
|
|
|
|
if hash1 == token {
|
|
t.Error("Hash should be different from original token")
|
|
}
|
|
|
|
hash3 := jwtService.hashToken("different-token")
|
|
if hash1 == hash3 {
|
|
t.Error("Different tokens should produce different hashes")
|
|
}
|
|
})
|
|
|
|
t.Run("Contains", func(t *testing.T) {
|
|
slice := []string{"item1", "item2", "item3"}
|
|
|
|
if !slices.Contains(slice, "item1") {
|
|
t.Error("Should contain item1")
|
|
}
|
|
|
|
if !slices.Contains(slice, "item2") {
|
|
t.Error("Should contain item2")
|
|
}
|
|
|
|
if slices.Contains(slice, "item4") {
|
|
t.Error("Should not contain item4")
|
|
}
|
|
|
|
if slices.Contains(slice, "") {
|
|
t.Error("Should not contain empty string")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestJWTService_Integration(t *testing.T) {
|
|
jwtService, userRepo, _ := createTestJWTService()
|
|
user := createTestUser()
|
|
userRepo.users[user.ID] = user
|
|
|
|
t.Run("Complete_Flow", func(t *testing.T) {
|
|
|
|
accessToken, err := jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate access token: %v", err)
|
|
}
|
|
|
|
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate refresh token: %v", err)
|
|
}
|
|
|
|
userID, err := jwtService.VerifyAccessToken(accessToken)
|
|
if err != nil {
|
|
t.Fatalf("Failed to verify access token: %v", err)
|
|
}
|
|
if userID != user.ID {
|
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
|
}
|
|
|
|
newAccessToken, err := jwtService.RefreshAccessToken(refreshToken)
|
|
if err != nil {
|
|
t.Fatalf("Failed to refresh access token: %v", err)
|
|
}
|
|
|
|
userID, err = jwtService.VerifyAccessToken(newAccessToken)
|
|
if err != nil {
|
|
t.Fatalf("Failed to verify new access token: %v", err)
|
|
}
|
|
if userID != user.ID {
|
|
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
|
}
|
|
|
|
err = jwtService.RevokeRefreshToken(refreshToken)
|
|
if err != nil {
|
|
t.Fatalf("Failed to revoke refresh token: %v", err)
|
|
}
|
|
|
|
_, err = jwtService.RefreshAccessToken(refreshToken)
|
|
if err == nil {
|
|
t.Error("Expected error when using revoked refresh token")
|
|
}
|
|
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
|
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
|
|
}
|
|
})
|
|
}
|