Files
goyco/internal/services/session_service_test.go

564 lines
15 KiB
Go

package services
import (
"errors"
"testing"
"time"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// sessionMockRefreshTokenRepo is an in-memory refresh-token repository used by
// the session service tests. Stored tokens are keyed by their TokenHash.
//
// Each *Err field, when non-nil, is returned by the matching method instead of
// performing the operation, so tests can inject repository failures.
type sessionMockRefreshTokenRepo struct {
	// tokens maps a token's TokenHash to the stored token.
	tokens map[string]*database.RefreshToken

	// Injected failures; nil means the operation proceeds normally.
	createErr         error // returned by Create
	deleteByUserIDErr error // returned by DeleteByUserID
	deleteExpiredErr  error // returned by DeleteExpired
	getByTokenHashErr error // returned by GetByTokenHash
}
// newSessionMockRefreshTokenRepo returns an empty mock repository whose token
// map is already allocated and which has no injected errors.
func newSessionMockRefreshTokenRepo() *sessionMockRefreshTokenRepo {
	repo := new(sessionMockRefreshTokenRepo)
	repo.tokens = map[string]*database.RefreshToken{}
	return repo
}
// Create stores token under its TokenHash, replacing any existing entry with
// the same hash. It returns the injected createErr, if set, without storing.
func (m *sessionMockRefreshTokenRepo) Create(token *database.RefreshToken) error {
	if err := m.createErr; err != nil {
		return err
	}
	// Tolerate a zero-value mock that was not built via the constructor.
	if m.tokens == nil {
		m.tokens = map[string]*database.RefreshToken{}
	}
	m.tokens[token.TokenHash] = token
	return nil
}
// GetByTokenHash looks up a stored token by its hash. It returns the injected
// getByTokenHashErr if set, and gorm.ErrRecordNotFound when no token matches.
func (m *sessionMockRefreshTokenRepo) GetByTokenHash(tokenHash string) (*database.RefreshToken, error) {
	if err := m.getByTokenHashErr; err != nil {
		return nil, err
	}
	found, ok := m.tokens[tokenHash]
	if !ok {
		return nil, gorm.ErrRecordNotFound
	}
	return found, nil
}
// DeleteByUserID removes every stored token owned by userID, or returns the
// injected deleteByUserIDErr without deleting anything.
func (m *sessionMockRefreshTokenRepo) DeleteByUserID(userID uint) error {
	if err := m.deleteByUserIDErr; err != nil {
		return err
	}
	// Deleting map entries while ranging over the same map is safe in Go.
	for hash, tok := range m.tokens {
		if tok.UserID != userID {
			continue
		}
		delete(m.tokens, hash)
	}
	return nil
}
// DeleteExpired removes tokens whose ExpiresAt is strictly before the current
// time, or returns the injected deleteExpiredErr without deleting anything.
func (m *sessionMockRefreshTokenRepo) DeleteExpired() error {
	if err := m.deleteExpiredErr; err != nil {
		return err
	}
	cutoff := time.Now()
	for hash, tok := range m.tokens {
		if !tok.ExpiresAt.Before(cutoff) {
			continue
		}
		delete(m.tokens, hash)
	}
	return nil
}
// DeleteByID removes the first stored token found whose ID equals id (map
// iteration order is unspecified). It returns gorm.ErrRecordNotFound when no
// token has that ID.
func (m *sessionMockRefreshTokenRepo) DeleteByID(id uint) error {
	for hash, tok := range m.tokens {
		if tok.ID != id {
			continue
		}
		delete(m.tokens, hash)
		return nil
	}
	return gorm.ErrRecordNotFound
}
// GetByUserID returns value copies of all tokens owned by userID, in
// unspecified order. The slice is nil when the user has no tokens; the error
// is always nil.
func (m *sessionMockRefreshTokenRepo) GetByUserID(userID uint) ([]database.RefreshToken, error) {
	var owned []database.RefreshToken
	for _, tok := range m.tokens {
		if tok.UserID != userID {
			continue
		}
		owned = append(owned, *tok)
	}
	return owned, nil
}
// CountByUserID reports how many stored tokens belong to userID. The error is
// always nil.
func (m *sessionMockRefreshTokenRepo) CountByUserID(userID uint) (int64, error) {
	n := int64(0)
	for _, tok := range m.tokens {
		if tok.UserID == userID {
			n++
		}
	}
	return n, nil
}
// createSessionTestJWTService builds a JWTService for session tests. It is
// backed by userRepo and a fresh in-memory refresh-token mock. The mock is also
// returned so tests can inspect or pre-populate stored refresh tokens.
func createSessionTestJWTService(userRepo *testutils.MockUserRepository) (*JWTService, *sessionMockRefreshTokenRepo) {
	refreshRepo := newSessionMockRefreshTokenRepo()
	cfg := config.JWTConfig{
		Secret:            "test-secret-key-that-is-long-enough-for-security",
		Expiration:        1,
		RefreshExpiration: 24,
		Issuer:            "test-issuer",
		Audience:          "test-audience",
		// Key rotation is explicitly disabled; all rotation keys stay empty.
		KeyRotation: config.KeyRotationConfig{Enabled: false},
	}
	return NewJWTService(&cfg, userRepo, refreshRepo), refreshRepo
}
// createTestUserWithPassword returns an unlocked, email-verified user with
// these properties:
//   - ID 1, username "testuser", SessionVersion 1
//   - Password holds a bcrypt hash of password
//
// The user is not persisted; callers add it to a repository themselves.
//
// It panics if hashing fails, for example when newer bcrypt versions reject
// passwords longer than 72 bytes. Otherwise an empty hash would be stored
// silently, and later login assertions would fail with a confusing,
// unrelated error.
func createTestUserWithPassword(password string) *database.User {
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		panic("createTestUserWithPassword: hashing password: " + err.Error())
	}
	return &database.User{
		ID:             1,
		Username:       "testuser",
		Email:          "test@example.com",
		Password:       string(hashedPassword),
		EmailVerified:  true,
		Locked:         false,
		SessionVersion: 1,
	}
}
// TestNewSessionService checks that the constructor returns a non-nil service
// holding exactly the JWT service and user repository it was given.
func TestNewSessionService(t *testing.T) {
	userRepo := testutils.NewMockUserRepository()
	jwtService, _ := createSessionTestJWTService(userRepo)

	svc := NewSessionService(jwtService, userRepo)
	if svc == nil {
		t.Fatal("expected service to be created")
	}

	// Identity comparisons: the service must keep the very instances passed in.
	if got := svc.jwtService; got != jwtService {
		t.Error("expected jwtService to be set")
	}
	if got := svc.userRepo; got != userRepo {
		t.Error("expected userRepo to be set")
	}
}
// TestSessionService_Login drives SessionService.Login through a table of
// credential scenarios:
//   - successful login
//   - blank or whitespace-only usernames
//   - unknown users
//   - unverified or locked accounts
//   - wrong passwords
//   - usernames with surrounding whitespace
//
// Cases with a non-nil expectedError must fail with an error matching it via
// errors.Is. Every other case must succeed: any error fails the subtest,
// whether or not checkResult is set. When checkResult is set, it then inspects
// the returned *AuthResult.
func TestSessionService_Login(t *testing.T) {
	// emptyRepoMocks builds services backed by a user repository with no users.
	emptyRepoMocks := func() (*JWTService, *testutils.MockUserRepository) {
		userRepo := testutils.NewMockUserRepository()
		jwtService, _ := createSessionTestJWTService(userRepo)
		return jwtService, userRepo
	}
	// seededMocks returns a setup func whose repository holds the standard test
	// user (password "SecurePass123!"). It first applies mutate, if non-nil, so
	// a case can tweak flags such as EmailVerified or Locked before insertion.
	seededMocks := func(mutate func(*database.User)) func() (*JWTService, *testutils.MockUserRepository) {
		return func() (*JWTService, *testutils.MockUserRepository) {
			userRepo := testutils.NewMockUserRepository()
			user := createTestUserWithPassword("SecurePass123!")
			if mutate != nil {
				mutate(user)
			}
			userRepo.Create(user)
			jwtService, _ := createSessionTestJWTService(userRepo)
			return jwtService, userRepo
		}
	}

	tests := []struct {
		name          string
		username      string
		password      string
		setupMocks    func() (*JWTService, *testutils.MockUserRepository)
		expectedError error
		checkResult   func(*testing.T, *AuthResult)
	}{
		{
			name:          "successful login",
			username:      "testuser",
			password:      "SecurePass123!",
			setupMocks:    seededMocks(nil),
			expectedError: nil,
			checkResult: func(t *testing.T, result *AuthResult) {
				if result == nil {
					t.Fatal("expected non-nil result")
				}
				if result.AccessToken == "" {
					t.Error("expected non-empty access token")
				}
				if result.RefreshToken == "" {
					t.Error("expected non-empty refresh token")
				}
				if result.User == nil {
					t.Fatal("expected non-nil user")
				}
				if result.User.Username != "testuser" {
					t.Errorf("expected username 'testuser', got %q", result.User.Username)
				}
				if result.User.Password != "" {
					t.Error("expected password to be sanitized")
				}
			},
		},
		{
			name:          "empty username",
			username:      "",
			password:      "SecurePass123!",
			setupMocks:    emptyRepoMocks,
			expectedError: ErrInvalidCredentials,
		},
		{
			name:          "whitespace only username",
			username:      " ",
			password:      "SecurePass123!",
			setupMocks:    emptyRepoMocks,
			expectedError: ErrInvalidCredentials,
		},
		{
			name:          "user not found",
			username:      "nonexistent",
			password:      "SecurePass123!",
			setupMocks:    emptyRepoMocks,
			expectedError: ErrInvalidCredentials,
		},
		{
			name:          "email not verified",
			username:      "testuser",
			password:      "SecurePass123!",
			setupMocks:    seededMocks(func(u *database.User) { u.EmailVerified = false }),
			expectedError: ErrEmailNotVerified,
		},
		{
			name:          "account locked",
			username:      "testuser",
			password:      "SecurePass123!",
			setupMocks:    seededMocks(func(u *database.User) { u.Locked = true }),
			expectedError: ErrAccountLocked,
		},
		{
			name:          "invalid password",
			username:      "testuser",
			password:      "WrongPassword",
			setupMocks:    seededMocks(nil),
			expectedError: ErrInvalidCredentials,
		},
		{
			name:          "trims username whitespace",
			username:      " testuser ",
			password:      "SecurePass123!",
			setupMocks:    seededMocks(nil),
			expectedError: nil,
			checkResult: func(t *testing.T, result *AuthResult) {
				// Guard before dereferencing so a bad result fails cleanly
				// instead of panicking.
				if result == nil {
					t.Fatal("expected non-nil result")
				}
				if result.User == nil {
					t.Fatal("expected non-nil user")
				}
				if result.User.Username != "testuser" {
					t.Errorf("expected trimmed username 'testuser', got %q", result.User.Username)
				}
				if result.AccessToken == "" {
					t.Error("expected non-empty access token")
				}
				if result.RefreshToken == "" {
					t.Error("expected non-empty refresh token")
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			jwtService, userRepo := tt.setupMocks()
			service := NewSessionService(jwtService, userRepo)

			result, err := service.Login(tt.username, tt.password)

			if tt.expectedError != nil {
				if err == nil {
					t.Fatal("expected error, got nil")
				}
				if !errors.Is(err, tt.expectedError) {
					t.Errorf("expected error %v, got %v", tt.expectedError, err)
				}
				return
			}
			// A case without expectedError must succeed; an error is always
			// fatal, independent of whether checkResult is set.
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if tt.checkResult != nil {
				tt.checkResult(t, result)
			}
		})
	}
}
// TestSessionService_VerifyToken checks that VerifyToken maps a freshly issued
// access token back to its user's ID, and rejects malformed or empty tokens.
func TestSessionService_VerifyToken(t *testing.T) {
	userRepo := testutils.NewMockUserRepository()
	user := createTestUserWithPassword("SecurePass123!")
	userRepo.Create(user)
	jwtService, _ := createSessionTestJWTService(userRepo)

	t.Run("successful verification", func(t *testing.T) {
		svc := NewSessionService(jwtService, userRepo)
		accessToken, err := jwtService.GenerateAccessToken(user)
		if err != nil {
			t.Fatalf("failed to generate token: %v", err)
		}
		gotID, err := svc.VerifyToken(accessToken)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if gotID != user.ID {
			t.Errorf("expected user ID %d, got %d", user.ID, gotID)
		}
	})

	// Tokens that must never verify, each with its own subtest and message.
	rejected := []struct {
		name    string
		token   string
		failMsg string
	}{
		{name: "invalid token", token: "invalid-token", failMsg: "expected error for invalid token"},
		{name: "empty token", token: "", failMsg: "expected error for empty token"},
	}
	for _, tc := range rejected {
		t.Run(tc.name, func(t *testing.T) {
			svc := NewSessionService(jwtService, userRepo)
			if _, err := svc.VerifyToken(tc.token); err == nil {
				t.Fatal(tc.failMsg)
			}
		})
	}
}
// TestSessionService_RefreshAccessToken covers two behaviors:
//   - a valid refresh token yields a new access token, and the refresh token
//     itself is returned unchanged;
//   - a malformed token is rejected with ErrRefreshTokenInvalid.
func TestSessionService_RefreshAccessToken(t *testing.T) {
	userRepo := testutils.NewMockUserRepository()
	user := createTestUserWithPassword("SecurePass123!")
	userRepo.Create(user)
	jwtService, _ := createSessionTestJWTService(userRepo)

	t.Run("successful refresh", func(t *testing.T) {
		svc := NewSessionService(jwtService, userRepo)
		issued, err := jwtService.GenerateRefreshToken(user)
		if err != nil {
			t.Fatalf("failed to generate refresh token: %v", err)
		}

		res, err := svc.RefreshAccessToken(issued)
		switch {
		case err != nil:
			t.Fatalf("unexpected error: %v", err)
		case res == nil:
			t.Fatal("expected non-nil result")
		}

		if res.AccessToken == "" {
			t.Error("expected non-empty access token")
		}
		if res.RefreshToken != issued {
			t.Error("expected refresh token to remain unchanged")
		}
		if res.User == nil {
			t.Fatal("expected non-nil user")
		}
	})

	t.Run("invalid refresh token", func(t *testing.T) {
		svc := NewSessionService(jwtService, userRepo)
		_, err := svc.RefreshAccessToken("invalid-refresh-token")
		switch {
		case err == nil:
			t.Fatal("expected error for invalid refresh token")
		case !errors.Is(err, ErrRefreshTokenInvalid):
			t.Errorf("expected ErrRefreshTokenInvalid, got %v", err)
		}
	})
}
// TestSessionService_RevokeRefreshToken checks that once a refresh token has
// been revoked, RefreshAccessToken no longer accepts it.
func TestSessionService_RevokeRefreshToken(t *testing.T) {
	userRepo := testutils.NewMockUserRepository()
	user := createTestUserWithPassword("SecurePass123!")
	userRepo.Create(user)
	jwtService, _ := createSessionTestJWTService(userRepo)

	t.Run("successful revocation", func(t *testing.T) {
		svc := NewSessionService(jwtService, userRepo)
		issued, err := jwtService.GenerateRefreshToken(user)
		if err != nil {
			t.Fatalf("failed to generate refresh token: %v", err)
		}
		if err := svc.RevokeRefreshToken(issued); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// The revoked token must be rejected on its next use.
		if _, err := svc.RefreshAccessToken(issued); err == nil {
			t.Fatal("expected error when using revoked refresh token")
		}
	})
}
// TestSessionService_RevokeAllUserTokens checks that revoking all of a user's
// tokens also invalidates a refresh token issued beforehand.
func TestSessionService_RevokeAllUserTokens(t *testing.T) {
	userRepo := testutils.NewMockUserRepository()
	user := createTestUserWithPassword("SecurePass123!")
	userRepo.Create(user)
	jwtService, _ := createSessionTestJWTService(userRepo)

	t.Run("successful revocation", func(t *testing.T) {
		svc := NewSessionService(jwtService, userRepo)
		issued, err := jwtService.GenerateRefreshToken(user)
		if err != nil {
			t.Fatalf("failed to generate refresh token: %v", err)
		}
		if err := svc.RevokeAllUserTokens(user.ID); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// A token issued before the bulk revocation must no longer refresh.
		if _, err := svc.RefreshAccessToken(issued); err == nil {
			t.Fatal("expected error when using revoked refresh token")
		}
	})
}
// TestSessionService_InvalidateAllSessions checks that invalidating a user's
// sessions bumps the stored SessionVersion from 1 to 2. It also checks that
// the call on an unknown user ID completes without a panic.
//
// Cases with allowErr set have a deliberately unspecified outcome. The call
// may succeed or fail, but a returned error must carry a non-empty message.
// This tolerance is declared in the table rather than keyed off the case
// name, so renaming a case cannot silently change what it asserts.
func TestSessionService_InvalidateAllSessions(t *testing.T) {
	tests := []struct {
		name          string
		userID        uint
		setupMocks    func() (*JWTService, *testutils.MockUserRepository)
		expectedError error
		allowErr      bool
		checkResult   func(*testing.T, *testutils.MockUserRepository)
	}{
		{
			name:   "successful invalidation",
			userID: 1,
			setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
				userRepo := testutils.NewMockUserRepository()
				user := &database.User{
					ID:             1,
					Username:       "testuser",
					Email:          "test@example.com",
					SessionVersion: 1,
				}
				userRepo.Create(user)
				jwtService, _ := createSessionTestJWTService(userRepo)
				return jwtService, userRepo
			},
			expectedError: nil,
			checkResult: func(t *testing.T, userRepo *testutils.MockUserRepository) {
				user, err := userRepo.GetByID(1)
				if err != nil {
					t.Fatalf("failed to get user: %v", err)
				}
				if user.SessionVersion != 2 {
					t.Errorf("expected SessionVersion to be 2, got %d", user.SessionVersion)
				}
			},
		},
		{
			name:   "user not found",
			userID: 999,
			setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
				userRepo := testutils.NewMockUserRepository()
				jwtService, _ := createSessionTestJWTService(userRepo)
				return jwtService, userRepo
			},
			expectedError: nil,
			// Either a nil error or a descriptive error is acceptable here.
			allowErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			jwtService, userRepo := tt.setupMocks()
			service := NewSessionService(jwtService, userRepo)

			err := service.InvalidateAllSessions(tt.userID)

			if tt.expectedError != nil {
				if err == nil {
					t.Fatal("expected error, got nil")
				}
				if !errors.Is(err, tt.expectedError) {
					t.Errorf("expected error %v, got %v", tt.expectedError, err)
				}
				return
			}
			if err != nil {
				if !tt.allowErr {
					t.Fatalf("unexpected error: %v", err)
				}
				if err.Error() == "" {
					t.Fatal("expected error message")
				}
				return
			}
			if tt.checkResult != nil {
				tt.checkResult(t, userRepo)
			}
		})
	}
}
// TestSessionService_CleanupExpiredTokens checks that cleanup completes without
// error against a freshly created, empty refresh-token mock.
func TestSessionService_CleanupExpiredTokens(t *testing.T) {
	userRepo := testutils.NewMockUserRepository()
	jwtService, _ := createSessionTestJWTService(userRepo)

	t.Run("successful cleanup", func(t *testing.T) {
		if err := NewSessionService(jwtService, userRepo).CleanupExpiredTokens(); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})
}