To gitea and beyond, let's go(-yco)

This commit is contained in:
2025-11-10 19:12:09 +01:00
parent 8f6133392d
commit 71a031342b
245 changed files with 83994 additions and 0 deletions

View File

@@ -0,0 +1,176 @@
package services
import (
"errors"
"fmt"
"log"
"time"
"goyco/internal/database"
"goyco/internal/repositories"
"gorm.io/gorm"
)
type AccountDeletionService struct {
userRepo repositories.UserRepository
postRepo repositories.PostRepository
deletionRepo repositories.AccountDeletionRepository
emailService *EmailService
}
func NewAccountDeletionService(userRepo repositories.UserRepository, postRepo repositories.PostRepository, deletionRepo repositories.AccountDeletionRepository, emailService *EmailService) *AccountDeletionService {
return &AccountDeletionService{
userRepo: userRepo,
postRepo: postRepo,
deletionRepo: deletionRepo,
emailService: emailService,
}
}
func (s *AccountDeletionService) GetUserIDFromDeletionToken(token string) (uint, error) {
trimmed := TrimString(token)
if trimmed == "" {
return 0, ErrInvalidDeletionToken
}
if s.deletionRepo == nil {
return 0, fmt.Errorf("account deletion repository not configured")
}
hashed := HashVerificationToken(trimmed)
req, err := s.deletionRepo.GetByTokenHash(hashed)
if err != nil {
if IsRecordNotFound(err) {
return 0, ErrInvalidDeletionToken
}
return 0, fmt.Errorf("lookup deletion request: %w", err)
}
if time.Now().After(req.ExpiresAt) {
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
log.Printf("Failed to delete expired deletion request %d: %v", req.ID, delErr)
}
return 0, ErrInvalidDeletionToken
}
return req.UserID, nil
}
func (s *AccountDeletionService) RequestAccountDeletion(userID uint) error {
if userID == 0 {
return fmt.Errorf("invalid user identifier")
}
if s.deletionRepo == nil {
return fmt.Errorf("account deletion repository not configured")
}
user, err := s.userRepo.GetByID(userID)
if err != nil {
if IsRecordNotFound(err) {
return ErrUserNotFound
}
return fmt.Errorf("load user: %w", err)
}
if err := s.deletionRepo.DeleteByUserID(userID); err != nil {
return fmt.Errorf("clear existing deletion requests: %w", err)
}
token, hash, err := generateVerificationToken()
if err != nil {
return err
}
req := &database.AccountDeletionRequest{
UserID: userID,
TokenHash: hash,
ExpiresAt: time.Now().Add(time.Duration(deletionTokenExpirationHours) * time.Hour),
}
if err := s.deletionRepo.Create(req); err != nil {
return fmt.Errorf("create deletion request: %w", err)
}
if err := s.emailService.SendAccountDeletionEmail(user, token); err != nil {
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
log.Printf("Failed to cleanup deletion request %d after email failure: %v", req.ID, delErr)
}
return fmt.Errorf("send deletion confirmation email: %w", err)
}
return nil
}
func (s *AccountDeletionService) validateAndGetDeletionRequest(token string) (*database.AccountDeletionRequest, *database.User, error) {
trimmed := TrimString(token)
if trimmed == "" {
return nil, nil, ErrInvalidDeletionToken
}
if s.deletionRepo == nil {
return nil, nil, fmt.Errorf("account deletion repository not configured")
}
hashed := HashVerificationToken(trimmed)
req, err := s.deletionRepo.GetByTokenHash(hashed)
if err != nil {
if IsRecordNotFound(err) {
return nil, nil, ErrInvalidDeletionToken
}
return nil, nil, fmt.Errorf("lookup deletion request: %w", err)
}
if time.Now().After(req.ExpiresAt) {
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
log.Printf("Failed to delete expired deletion request %d: %v", req.ID, delErr)
}
return nil, nil, ErrInvalidDeletionToken
}
user, err := s.userRepo.GetByID(req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
log.Printf("Failed to delete orphaned deletion request %d: %v", req.ID, delErr)
}
return nil, nil, ErrInvalidDeletionToken
}
return nil, nil, fmt.Errorf("load user: %w", err)
}
return req, user, nil
}
func (s *AccountDeletionService) ConfirmAccountDeletion(token string) error {
return s.ConfirmAccountDeletionWithPosts(token, true)
}
func (s *AccountDeletionService) ConfirmAccountDeletionWithPosts(token string, deletePosts bool) error {
req, user, err := s.validateAndGetDeletionRequest(token)
if err != nil {
return err
}
if !deletePosts {
if err := s.userRepo.SoftDeleteWithPosts(user.ID); err != nil {
return fmt.Errorf("soft delete user with posts: %w", err)
}
} else {
if err := s.userRepo.HardDelete(user.ID); err != nil {
return fmt.Errorf("delete user: %w", err)
}
}
if err := s.deletionRepo.DeleteByID(req.ID); err != nil {
return fmt.Errorf("clear deletion request: %w", err)
}
if err := s.emailService.SendAccountDeletionNotificationEmail(user, deletePosts); err != nil {
log.Printf("Failed to send account deletion notification email for user %d: %v", user.ID, err)
return fmt.Errorf("%w: %w", ErrDeletionEmailFailed, err)
}
return nil
}

View File

@@ -0,0 +1,529 @@
package services
import (
"errors"
"testing"
"time"
"goyco/internal/database"
"goyco/internal/testutils"
"gorm.io/gorm"
)
type errorEmailSender struct {
err error
}
func (e *errorEmailSender) Send(to, subject, body string) error {
return e.err
}
type mockAccountDeletionRepository struct {
requests map[uint]*database.AccountDeletionRequest
requestsByTokenHash map[string]*database.AccountDeletionRequest
nextID uint
createErr error
getByTokenHashErr error
deleteByIDErr error
deleteByUserIDErr error
}
func newMockAccountDeletionRepository() *mockAccountDeletionRepository {
return &mockAccountDeletionRepository{
requests: make(map[uint]*database.AccountDeletionRequest),
requestsByTokenHash: make(map[string]*database.AccountDeletionRequest),
nextID: 1,
}
}
func (m *mockAccountDeletionRepository) Create(req *database.AccountDeletionRequest) error {
if m.createErr != nil {
return m.createErr
}
req.ID = m.nextID
m.nextID++
reqCopy := *req
m.requests[req.ID] = &reqCopy
m.requestsByTokenHash[req.TokenHash] = &reqCopy
return nil
}
func (m *mockAccountDeletionRepository) GetByTokenHash(hash string) (*database.AccountDeletionRequest, error) {
if m.getByTokenHashErr != nil {
return nil, m.getByTokenHashErr
}
if req, ok := m.requestsByTokenHash[hash]; ok {
reqCopy := *req
return &reqCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *mockAccountDeletionRepository) DeleteByID(id uint) error {
if m.deleteByIDErr != nil {
return m.deleteByIDErr
}
if req, ok := m.requests[id]; ok {
delete(m.requests, id)
delete(m.requestsByTokenHash, req.TokenHash)
return nil
}
return gorm.ErrRecordNotFound
}
func (m *mockAccountDeletionRepository) DeleteByUserID(userID uint) error {
if m.deleteByUserIDErr != nil {
return m.deleteByUserIDErr
}
for id, req := range m.requests {
if req.UserID == userID {
delete(m.requests, id)
delete(m.requestsByTokenHash, req.TokenHash)
}
}
return nil
}
func TestNewAccountDeletionService(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
deletionRepo := newMockAccountDeletionRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
if service == nil {
t.Fatal("expected service to be created")
}
if service.userRepo != userRepo {
t.Error("expected userRepo to be set")
}
if service.postRepo != postRepo {
t.Error("expected postRepo to be set")
}
if service.deletionRepo != deletionRepo {
t.Error("expected deletionRepo to be set")
}
if service.emailService != emailService {
t.Error("expected emailService to be set")
}
}
func TestAccountDeletionService_GetUserIDFromDeletionToken(t *testing.T) {
tests := []struct {
name string
token string
setupRepo func() *mockAccountDeletionRepository
expectedID uint
expectedError error
}{
{
name: "successful retrieval",
token: "valid-token",
setupRepo: func() *mockAccountDeletionRepository {
repo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 1,
TokenHash: HashVerificationToken("valid-token"),
ExpiresAt: time.Now().Add(time.Hour),
}
repo.Create(req)
return repo
},
expectedID: 1,
expectedError: nil,
},
{
name: "empty token",
token: "",
setupRepo: func() *mockAccountDeletionRepository { return newMockAccountDeletionRepository() },
expectedID: 0,
expectedError: ErrInvalidDeletionToken,
},
{
name: "whitespace only token",
token: " ",
setupRepo: func() *mockAccountDeletionRepository { return newMockAccountDeletionRepository() },
expectedID: 0,
expectedError: ErrInvalidDeletionToken,
},
{
name: "token not found",
token: "invalid-token",
setupRepo: func() *mockAccountDeletionRepository {
return newMockAccountDeletionRepository()
},
expectedID: 0,
expectedError: ErrInvalidDeletionToken,
},
{
name: "expired token",
token: "expired-token",
setupRepo: func() *mockAccountDeletionRepository {
repo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 1,
TokenHash: HashVerificationToken("expired-token"),
ExpiresAt: time.Now().Add(-time.Hour),
}
repo.Create(req)
return repo
},
expectedID: 0,
expectedError: ErrInvalidDeletionToken,
},
{
name: "repository error",
token: "valid-token",
setupRepo: func() *mockAccountDeletionRepository {
repo := newMockAccountDeletionRepository()
repo.getByTokenHashErr = errors.New("database error")
return repo
},
expectedID: 0,
expectedError: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
deletionRepo := tt.setupRepo()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
userID, err := service.GetUserIDFromDeletionToken(tt.token)
if tt.expectedError != nil {
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
} else {
if tt.name == "repository error" || tt.name == "nil repository" {
if err == nil {
t.Error("expected error but got none")
}
} else if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
if userID != tt.expectedID {
t.Errorf("expected userID %d, got %d", tt.expectedID, userID)
}
})
}
}
func TestAccountDeletionService_RequestAccountDeletion(t *testing.T) {
tests := []struct {
name string
userID uint
setupMocks func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender)
expectedError bool
checkToken bool
}{
{
name: "successful request",
userID: 1,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
userRepo.Create(user)
deletionRepo := newMockAccountDeletionRepository()
emailSender := &testutils.MockEmailSender{}
return userRepo, deletionRepo, emailSender
},
expectedError: false,
checkToken: true,
},
{
name: "invalid user ID",
userID: 0,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
},
expectedError: true,
checkToken: false,
},
{
name: "user not found",
userID: 999,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
},
expectedError: true,
checkToken: false,
},
{
name: "email service error",
userID: 1,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
userRepo.Create(user)
deletionRepo := newMockAccountDeletionRepository()
var errorSender errorEmailSender
errorSender.err = errors.New("email service error")
emailSender := &errorSender
return userRepo, deletionRepo, emailSender
},
expectedError: true,
checkToken: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, deletionRepo, emailSender := tt.setupMocks()
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
err := service.RequestAccountDeletion(tt.userID)
if tt.expectedError {
if err == nil {
t.Error("expected error but got none")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if tt.checkToken {
if len(deletionRepo.requests) == 0 {
t.Error("expected deletion request to be created")
}
}
}
})
}
}
func TestAccountDeletionService_ConfirmAccountDeletion(t *testing.T) {
tests := []struct {
name string
token string
setupMocks func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender)
expectedError error
}{
{
name: "successful deletion",
token: "valid-token",
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
userRepo.Create(user)
deletionRepo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 1,
TokenHash: HashVerificationToken("valid-token"),
ExpiresAt: time.Now().Add(time.Hour),
}
deletionRepo.Create(req)
emailSender := &testutils.MockEmailSender{}
return userRepo, deletionRepo, emailSender
},
expectedError: nil,
},
{
name: "empty token",
token: "",
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
},
expectedError: ErrInvalidDeletionToken,
},
{
name: "token not found",
token: "invalid-token",
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
},
expectedError: ErrInvalidDeletionToken,
},
{
name: "expired token",
token: "expired-token",
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
deletionRepo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 1,
TokenHash: HashVerificationToken("expired-token"),
ExpiresAt: time.Now().Add(-time.Hour),
}
deletionRepo.Create(req)
return testutils.NewMockUserRepository(), deletionRepo, &testutils.MockEmailSender{}
},
expectedError: ErrInvalidDeletionToken,
},
{
name: "user not found",
token: "valid-token",
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
deletionRepo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 999,
TokenHash: HashVerificationToken("valid-token"),
ExpiresAt: time.Now().Add(time.Hour),
}
deletionRepo.Create(req)
return testutils.NewMockUserRepository(), deletionRepo, &testutils.MockEmailSender{}
},
expectedError: ErrInvalidDeletionToken,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, deletionRepo, emailSender := tt.setupMocks()
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
err := service.ConfirmAccountDeletion(tt.token)
if tt.expectedError != nil {
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}
func TestAccountDeletionService_ConfirmAccountDeletionWithPosts(t *testing.T) {
tests := []struct {
name string
token string
deletePosts bool
setupMocks func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender)
expectedError error
}{
{
name: "successful deletion without posts",
token: "valid-token",
deletePosts: false,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
userRepo.Create(user)
deletionRepo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 1,
TokenHash: HashVerificationToken("valid-token"),
ExpiresAt: time.Now().Add(time.Hour),
}
deletionRepo.Create(req)
emailSender := &testutils.MockEmailSender{}
return userRepo, deletionRepo, emailSender
},
expectedError: nil,
},
{
name: "successful deletion with posts",
token: "valid-token",
deletePosts: true,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
userRepo.Create(user)
deletionRepo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 1,
TokenHash: HashVerificationToken("valid-token"),
ExpiresAt: time.Now().Add(time.Hour),
}
deletionRepo.Create(req)
emailSender := &testutils.MockEmailSender{}
return userRepo, deletionRepo, emailSender
},
expectedError: nil,
},
{
name: "empty token",
token: "",
deletePosts: false,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
},
expectedError: ErrInvalidDeletionToken,
},
{
name: "expired token",
token: "expired-token",
deletePosts: false,
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
deletionRepo := newMockAccountDeletionRepository()
req := &database.AccountDeletionRequest{
UserID: 1,
TokenHash: HashVerificationToken("expired-token"),
ExpiresAt: time.Now().Add(-time.Hour),
}
deletionRepo.Create(req)
return testutils.NewMockUserRepository(), deletionRepo, &testutils.MockEmailSender{}
},
expectedError: ErrInvalidDeletionToken,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, deletionRepo, emailSender := tt.setupMocks()
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
err := service.ConfirmAccountDeletionWithPosts(tt.token, tt.deletePosts)
if tt.expectedError != nil {
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}
func TestAccountDeletionService_UserHasPosts(t *testing.T) {
}

View File

@@ -0,0 +1,139 @@
package services
import (
"fmt"
"goyco/internal/config"
"goyco/internal/database"
)
type AuthFacade struct {
registrationService *RegistrationService
passwordResetService *PasswordResetService
deletionService *AccountDeletionService
sessionService *SessionService
userManagementService *UserManagementService
config *config.Config
}
func NewAuthFacade(
registrationService *RegistrationService,
passwordResetService *PasswordResetService,
deletionService *AccountDeletionService,
sessionService *SessionService,
userManagementService *UserManagementService,
config *config.Config,
) *AuthFacade {
return &AuthFacade{
registrationService: registrationService,
passwordResetService: passwordResetService,
deletionService: deletionService,
sessionService: sessionService,
userManagementService: userManagementService,
config: config,
}
}
func (f *AuthFacade) Register(username, email, password string) (*RegistrationResult, error) {
return f.registrationService.Register(username, email, password)
}
func (f *AuthFacade) Login(username, password string) (*AuthResult, error) {
return f.sessionService.Login(username, password)
}
func (f *AuthFacade) VerifyToken(tokenString string) (uint, error) {
return f.sessionService.VerifyToken(tokenString)
}
func (f *AuthFacade) ConfirmEmail(token string) (*database.User, error) {
return f.registrationService.ConfirmEmail(token)
}
func (f *AuthFacade) ResendVerificationEmail(email string) error {
return f.registrationService.ResendVerificationEmail(email)
}
func (f *AuthFacade) RequestPasswordReset(usernameOrEmail string) error {
return f.passwordResetService.RequestPasswordReset(usernameOrEmail)
}
func (f *AuthFacade) ResetPassword(token, newPassword string) error {
user, err := f.passwordResetService.GetUserByResetToken(token)
if err != nil {
return err
}
if err := f.passwordResetService.ResetPassword(token, newPassword); err != nil {
return err
}
if err := f.sessionService.InvalidateAllSessions(user.ID); err != nil {
return fmt.Errorf("invalidate sessions: %w", err)
}
return nil
}
func (f *AuthFacade) UpdateUsername(userID uint, username string) (*database.User, error) {
return f.userManagementService.UpdateUsername(userID, username)
}
func (f *AuthFacade) UpdateEmail(userID uint, email string) (*database.User, error) {
return f.userManagementService.UpdateEmail(userID, email)
}
func (f *AuthFacade) UpdatePassword(userID uint, currentPassword, newPassword string) (*database.User, error) {
user, err := f.userManagementService.UpdatePassword(userID, currentPassword, newPassword)
if err != nil {
return nil, err
}
if err := f.sessionService.InvalidateAllSessions(userID); err != nil {
return nil, fmt.Errorf("invalidate sessions: %w", err)
}
return user, nil
}
func (f *AuthFacade) RequestAccountDeletion(userID uint) error {
return f.deletionService.RequestAccountDeletion(userID)
}
func (f *AuthFacade) ConfirmAccountDeletion(token string) error {
return f.deletionService.ConfirmAccountDeletion(token)
}
func (f *AuthFacade) ConfirmAccountDeletionWithPosts(token string, deletePosts bool) error {
return f.deletionService.ConfirmAccountDeletionWithPosts(token, deletePosts)
}
func (f *AuthFacade) GetUserIDFromDeletionToken(token string) (uint, error) {
return f.deletionService.GetUserIDFromDeletionToken(token)
}
func (f *AuthFacade) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
return f.sessionService.RefreshAccessToken(refreshToken)
}
func (f *AuthFacade) RevokeRefreshToken(refreshToken string) error {
return f.sessionService.RevokeRefreshToken(refreshToken)
}
func (f *AuthFacade) RevokeAllUserTokens(userID uint) error {
return f.sessionService.RevokeAllUserTokens(userID)
}
func (f *AuthFacade) InvalidateAllSessions(userID uint) error {
return f.sessionService.InvalidateAllSessions(userID)
}
func (f *AuthFacade) CleanupExpiredTokens() error {
return f.sessionService.CleanupExpiredTokens()
}
func (f *AuthFacade) GetAdminEmail() string {
return f.config.App.AdminEmail
}
func (f *AuthFacade) UserHasPosts(userID uint) (bool, int64, error) {
return f.userManagementService.UserHasPosts(userID)
}

View File

@@ -0,0 +1,672 @@
package services
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/testutils"
)
func testHashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func setupAuthService(t *testing.T) (*AuthFacade, *testutils.ServiceSuite) {
t.Helper()
suite := testutils.NewServiceSuite(t)
authService, err := NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
return authService, suite
}
func TestAuthService_Unit_Register(t *testing.T) {
t.Run("Successful_Registration", func(t *testing.T) {
authService, _ := setupAuthService(t)
result, err := authService.Register("testuser", "test@example.com", "SecurePass123!")
if err != nil {
t.Fatalf("Expected successful registration, got error: %v", err)
}
if result.User.Username != "testuser" {
t.Errorf("Expected username 'testuser', got '%s'", result.User.Username)
}
if result.User.Email != "test@example.com" {
t.Errorf("Expected email 'test@example.com', got '%s'", result.User.Email)
}
if result.User.EmailVerified {
t.Error("Expected email to be unverified initially")
}
if !result.VerificationSent {
t.Error("Expected verification to be sent")
}
})
t.Run("Duplicate_Username", func(t *testing.T) {
authService, suite := setupAuthService(t)
existingUser := &database.User{
Username: "existinguser",
Email: "existing@example.com",
Password: "hashed",
EmailVerified: true,
}
if err := suite.UserRepo.Create(existingUser); err != nil {
t.Fatalf("Failed to create existing user: %v", err)
}
_, err := authService.Register("existinguser", "test@example.com", "SecurePass123!")
if err == nil {
t.Error("Expected error for duplicate username")
}
if !strings.Contains(err.Error(), "username already exists") {
t.Errorf("Expected username conflict error, got: %v", err)
}
})
t.Run("Duplicate_Email", func(t *testing.T) {
authService, suite := setupAuthService(t)
existingUser := &database.User{
Username: "existinguser",
Email: "existing@example.com",
Password: "hashed",
EmailVerified: true,
}
if err := suite.UserRepo.Create(existingUser); err != nil {
t.Fatalf("Failed to create existing user: %v", err)
}
_, err := authService.Register("newuser", "existing@example.com", "SecurePass123!")
if err == nil {
t.Error("Expected error for duplicate email")
}
if !strings.Contains(err.Error(), "email already exists") {
t.Errorf("Expected email conflict error, got: %v", err)
}
})
t.Run("Weak_Password", func(t *testing.T) {
authService, _ := setupAuthService(t)
_, err := authService.Register("testuser", "test@example.com", "123")
if err == nil {
t.Error("Expected error for weak password")
}
if !strings.Contains(strings.ToLower(err.Error()), "password") {
t.Errorf("Expected password validation error, got: %v", err)
}
})
t.Run("Invalid_Email", func(t *testing.T) {
authService, _ := setupAuthService(t)
_, err := authService.Register("testuser", "invalid-email", "SecurePass123!")
if err == nil {
t.Error("Expected error for invalid email")
}
if !strings.Contains(err.Error(), "email") {
t.Errorf("Expected email validation error, got: %v", err)
}
})
}
func TestAuthService_Unit_Login(t *testing.T) {
t.Run("Successful_Login", func(t *testing.T) {
authService, suite := setupAuthService(t)
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("SecurePass123!"), bcrypt.DefaultCost)
testUser := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: string(hashedPassword),
EmailVerified: true,
}
if err := suite.UserRepo.Create(testUser); err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
result, err := authService.Login("testuser", "SecurePass123!")
if err != nil {
t.Fatalf("Expected successful login, got error: %v", err)
}
if result.User.Username != "testuser" {
t.Errorf("Expected username 'testuser', got '%s'", result.User.Username)
}
if result.AccessToken == "" {
t.Error("Expected access token to be generated")
}
if result.RefreshToken == "" {
t.Error("Expected refresh token to be generated")
}
token, err := jwt.Parse(result.AccessToken, func(token *jwt.Token) (any, error) {
return []byte(testutils.AppTestConfig.JWT.Secret), nil
})
if err != nil {
t.Fatalf("Failed to parse token: %v", err)
}
if !token.Valid {
t.Error("Expected valid JWT token")
}
})
t.Run("Invalid_Credentials", func(t *testing.T) {
authService, _ := setupAuthService(t)
_, err := authService.Login("nonexistent", "SecurePass123!")
if err == nil {
t.Error("Expected error for invalid credentials")
}
if !strings.Contains(err.Error(), "invalid credentials") {
t.Errorf("Expected invalid credentials error, got: %v", err)
}
})
t.Run("Wrong_Password", func(t *testing.T) {
authService, suite := setupAuthService(t)
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("correctpassword"), bcrypt.DefaultCost)
testUser := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: string(hashedPassword),
EmailVerified: true,
}
if err := suite.UserRepo.Create(testUser); err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
_, err := authService.Login("testuser", "wrongpassword")
if err == nil {
t.Error("Expected error for wrong password")
}
if !strings.Contains(err.Error(), "invalid credentials") {
t.Errorf("Expected invalid credentials error, got: %v", err)
}
})
t.Run("Unverified_Email", func(t *testing.T) {
authService, suite := setupAuthService(t)
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("SecurePass123!"), bcrypt.DefaultCost)
testUser := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: string(hashedPassword),
EmailVerified: false,
}
if err := suite.UserRepo.Create(testUser); err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
_, err := authService.Login("testuser", "SecurePass123!")
if err == nil {
t.Error("Expected error for unverified email")
}
if !strings.Contains(err.Error(), "email not verified") {
t.Errorf("Expected email verification error, got: %v", err)
}
})
}
func TestAuthService_Unit_ConfirmEmail(t *testing.T) {
t.Run("Successful_Email_Confirmation", func(t *testing.T) {
authService, suite := setupAuthService(t)
rawToken := "valid-token"
hashedToken := testHashVerificationToken(rawToken)
testUser := &database.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashed",
EmailVerified: false,
EmailVerificationToken: hashedToken,
}
if err := suite.UserRepo.Create(testUser); err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
result, err := authService.ConfirmEmail("valid-token")
if err != nil {
t.Fatalf("Expected successful email confirmation, got error: %v", err)
}
if !result.EmailVerified {
t.Error("Expected email to be verified")
}
if result.EmailVerificationToken != "" {
t.Error("Expected verification token to be cleared")
}
})
t.Run("Invalid_Token", func(t *testing.T) {
authService, _ := setupAuthService(t)
_, err := authService.ConfirmEmail("invalid-token")
if err == nil {
t.Error("Expected error for invalid token")
}
if !strings.Contains(err.Error(), "invalid verification token") {
t.Errorf("Expected invalid verification token error, got: %v", err)
}
})
}
func TestAuthService_Integration_Complete_Workflow(t *testing.T) {
suite := testutils.NewServiceSuite(t)
authService, err := NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
t.Run("Complete_User_Lifecycle", func(t *testing.T) {
suite.EmailSender.Reset()
registerResult, err := authService.Register("lifecycle_user", "lifecycle@example.com", "SecurePass123!")
if err != nil {
t.Fatalf("Failed to register user: %v", err)
}
if registerResult.User.Username != "lifecycle_user" {
t.Errorf("Expected username 'lifecycle_user', got '%s'", registerResult.User.Username)
}
if registerResult.User.EmailVerified {
t.Error("Expected email to be unverified initially")
}
verificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "lifecycle_user")
confirmResult, err := authService.ConfirmEmail(verificationToken)
if err != nil {
t.Fatalf("Failed to confirm email: %v", err)
}
if !confirmResult.EmailVerified {
t.Error("Expected email to be verified after confirmation")
}
loginResult, err := authService.Login("lifecycle_user", "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
if loginResult.User.Username != "lifecycle_user" {
t.Errorf("Expected username 'lifecycle_user', got '%s'", loginResult.User.Username)
}
if loginResult.AccessToken == "" {
t.Error("Expected access token to be generated")
}
if loginResult.RefreshToken == "" {
t.Error("Expected refresh token to be generated")
}
updateResult, err := authService.UpdateUsername(loginResult.User.ID, "updated_lifecycle_user")
if err != nil {
t.Fatalf("Failed to update username: %v", err)
}
if updateResult.Username != "updated_lifecycle_user" {
t.Errorf("Expected updated username, got '%s'", updateResult.Username)
}
suite.EmailSender.Reset()
emailResult, err := authService.UpdateEmail(loginResult.User.ID, "updated@example.com")
if err != nil {
t.Fatalf("Failed to update email: %v", err)
}
if emailResult.Email != "updated@example.com" {
t.Errorf("Expected updated email, got '%s'", emailResult.Email)
}
newVerificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "updated_lifecycle_user")
_, err = authService.ConfirmEmail(newVerificationToken)
if err != nil {
t.Fatalf("Failed to confirm updated email: %v", err)
}
_, err = authService.UpdatePassword(loginResult.User.ID, "SecurePass123!", "NewSecurePass123!")
if err != nil {
t.Fatalf("Failed to update password: %v", err)
}
_, err = authService.Login("updated_lifecycle_user", "NewSecurePass123!")
if err != nil {
t.Fatalf("Failed to login with new password: %v", err)
}
})
t.Run("Account_Deletion_Workflow", func(t *testing.T) {
suite.EmailSender.Reset()
registerResult, err := authService.Register("deletion_user", "deletion@example.com", "SecurePass123!")
if err != nil {
t.Fatalf("Failed to register user: %v", err)
}
verificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "deletion_user")
_, err = authService.ConfirmEmail(verificationToken)
if err != nil {
t.Fatalf("Failed to confirm email: %v", err)
}
err = authService.RequestAccountDeletion(registerResult.User.ID)
if err != nil {
t.Fatalf("Failed to request account deletion: %v", err)
}
deletionToken := setupDeletionTokenForTest(t, suite.EmailSender, suite.DeletionRepo, registerResult.User.ID)
err = authService.ConfirmAccountDeletion(deletionToken)
if err != nil {
t.Fatalf("Failed to confirm account deletion: %v", err)
}
if err := authService.ConfirmAccountDeletion(deletionToken); !errors.Is(err, ErrInvalidDeletionToken) {
t.Fatalf("Expected token reuse to fail with ErrInvalidDeletionToken, got %v", err)
}
})
t.Run("Password_Reset_Workflow", func(t *testing.T) {
suite.EmailSender.Reset()
_, err := authService.Register("reset_user", "reset@example.com", "SecurePass123!")
if err != nil {
t.Fatalf("Failed to register user: %v", err)
}
verificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "reset_user")
_, err = authService.ConfirmEmail(verificationToken)
if err != nil {
t.Fatalf("Failed to confirm email: %v", err)
}
err = authService.RequestPasswordReset("reset@example.com")
if err != nil {
t.Fatalf("Failed to request password reset: %v", err)
}
resetToken := setupPasswordResetTokenForTest(t, suite.EmailSender, suite.UserRepo, "reset@example.com")
err = authService.ResetPassword(resetToken, "NewSecurePass123!")
if err != nil {
t.Fatalf("Failed to reset password: %v", err)
}
if err := authService.ResetPassword(resetToken, "AnotherPass123!"); err == nil {
t.Fatal("expected reset token reuse to fail")
}
_, err = authService.Login("reset_user", "NewSecurePass123!")
if err != nil {
t.Fatalf("Failed to login with new password: %v", err)
}
})
}
func TestAuthService_Integration_Error_Handling(t *testing.T) {
suite := testutils.NewServiceSuite(t)
authService, err := NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
t.Run("Validation_Errors", func(t *testing.T) {
_, err := authService.Register("weak_user", "weak@example.com", "123")
if err == nil {
t.Error("Expected error for weak password")
}
_, err = authService.Register("invalid_user", "not-an-email", "SecurePass123!")
if err == nil {
t.Error("Expected error for invalid email")
}
_, err = authService.Register("", "test@example.com", "SecurePass123!")
if err == nil {
t.Error("Expected error for empty username")
}
})
t.Run("Duplicate_Constraints", func(t *testing.T) {
_, err := authService.Register("duplicate_user", "duplicate1@example.com", "SecurePass123!")
if err != nil {
t.Fatalf("Failed to register first user: %v", err)
}
_, err = authService.Register("duplicate_user", "duplicate2@example.com", "SecurePass123!")
if err == nil {
t.Error("Expected error for duplicate username")
}
_, err = authService.Register("another_user", "duplicate1@example.com", "SecurePass123!")
if err == nil {
t.Error("Expected error for duplicate email")
}
})
t.Run("Duplicate_LongUsername", func(t *testing.T) {
longUsername := strings.Repeat("x", 50)
user := &database.User{
Username: longUsername,
Email: "longuser@example.com",
Password: testutils.HashPassword("SecurePass123!"),
EmailVerified: true,
}
if err := suite.UserRepo.Create(user); err != nil {
t.Fatalf("Failed to create user: %v", err)
}
_, err := authService.Register(longUsername, "longuser2@example.com", "SecurePass123!")
if !errors.Is(err, ErrUsernameTaken) {
t.Fatalf("expected ErrUsernameTaken for duplicate long username, got %v", err)
}
})
t.Run("Authentication_Errors", func(t *testing.T) {
_, err := authService.Login("nonexistent", "password")
if err == nil {
t.Error("Expected error for non-existent user")
}
_, err = authService.Register("auth_user", "auth@example.com", "SecurePass123!")
if err != nil {
t.Fatalf("Failed to register user: %v", err)
}
verificationToken := suite.EmailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Expected verification token to be generated")
}
_, err = authService.ConfirmEmail(verificationToken)
if err != nil {
t.Fatalf("Failed to confirm email: %v", err)
}
_, err = authService.Login("auth_user", "wrongpassword")
if err == nil {
t.Error("Expected error for wrong password")
}
})
t.Run("Pre_Verified_User_Operations", func(t *testing.T) {
user := createTestUserWithAuth(authService, suite.EmailSender, suite.UserRepo, "preverified_user", "preverified@example.com")
if user.Username != "preverified_user" {
t.Errorf("Expected username 'preverified_user', got '%s'", user.Username)
}
if user.Email != "preverified@example.com" {
t.Errorf("Expected email 'preverified@example.com', got '%s'", user.Email)
}
if !user.EmailVerified {
t.Error("Expected user to be email verified")
}
loginResult, err := authService.Login("preverified_user", "SecurePass123!")
if err != nil {
t.Fatalf("Expected successful login for pre-verified user, got error: %v", err)
}
if loginResult.User.ID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, loginResult.User.ID)
}
updateResult, err := authService.UpdateUsername(user.ID, "updated_preverified_user")
if err != nil {
t.Fatalf("Expected successful username update, got error: %v", err)
}
if updateResult.Username != "updated_preverified_user" {
t.Errorf("Expected updated username, got '%s'", updateResult.Username)
}
})
}
func createTestUserWithAuth(authService interface {
Register(username, email, password string) (*RegistrationResult, error)
ConfirmEmail(token string) (*database.User, error)
}, emailSender interface {
Reset()
VerificationToken() string
}, userRepo repositories.UserRepository, username, email string) *database.User {
emailSender.Reset()
_, err := authService.Register(username, email, "SecurePass123!")
if err != nil {
panic(fmt.Sprintf("Failed to register user: %v", err))
}
verificationToken := emailSender.VerificationToken()
if verificationToken == "" {
panic("Failed to capture verification token during test setup")
}
hashedToken := testutils.HashVerificationToken(verificationToken)
user, err := userRepo.GetByUsername(username)
if err != nil {
panic(fmt.Sprintf("Failed to get user: %v", err))
}
user.EmailVerificationToken = hashedToken
if err := userRepo.Update(user); err != nil {
panic(fmt.Sprintf("Failed to update user with hashed token: %v", err))
}
confirmResult, err := authService.ConfirmEmail(verificationToken)
if err != nil {
panic(fmt.Sprintf("Failed to confirm email: %v", err))
}
return confirmResult
}
func setupVerificationTokenForTest(t *testing.T, emailSender *testutils.MockEmailSender, userRepo repositories.UserRepository, username string) string {
t.Helper()
verificationToken := emailSender.VerificationToken()
if verificationToken == "" {
t.Fatal("Expected verification token to be generated")
}
hashedToken := testutils.HashVerificationToken(verificationToken)
user, err := userRepo.GetByUsername(username)
if err != nil {
t.Fatalf("Failed to get user: %v", err)
}
user.EmailVerificationToken = hashedToken
if err := userRepo.Update(user); err != nil {
t.Fatalf("Failed to update user with hashed token: %v", err)
}
return verificationToken
}
func setupDeletionTokenForTest(t *testing.T, emailSender *testutils.MockEmailSender, deletionRepo repositories.AccountDeletionRepository, userID uint) string {
t.Helper()
deletionToken := emailSender.DeletionToken()
if deletionToken == "" {
t.Fatal("Expected deletion token to be generated")
}
hashedToken := testutils.HashVerificationToken(deletionToken)
if err := deletionRepo.DeleteByUserID(userID); err != nil {
t.Fatalf("Cannot delete user %d", userID)
}
req := &database.AccountDeletionRequest{
UserID: userID,
TokenHash: hashedToken,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
if err := deletionRepo.Create(req); err != nil {
t.Fatalf("Failed to create account deletion request: %v", err)
}
return deletionToken
}
func setupPasswordResetTokenForTest(t *testing.T, emailSender *testutils.MockEmailSender, userRepo repositories.UserRepository, email string) string {
t.Helper()
resetToken := emailSender.PasswordResetToken()
if resetToken == "" {
t.Fatal("Expected password reset token to be generated")
}
hashedToken := testutils.HashVerificationToken(resetToken)
user, err := userRepo.GetByEmail(email)
if err != nil {
t.Fatalf("Failed to get user: %v", err)
}
user.PasswordResetToken = hashedToken
if err := userRepo.Update(user); err != nil {
t.Fatalf("Failed to update user with hashed reset token: %v", err)
}
return resetToken
}

View File

@@ -0,0 +1,35 @@
package services
import (
"errors"
"goyco/internal/database"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}

View File

@@ -0,0 +1,59 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/mail"
"strings"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}

136
internal/services/common.go Normal file
View File

@@ -0,0 +1,136 @@
package services
import (
"errors"
"fmt"
"strings"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/repositories"
)
func TrimString(s string) string {
return strings.TrimSpace(s)
}
const (
DefaultBcryptCost = 10
)
func HashPassword(password string, cost int) (string, error) {
if cost <= 0 {
cost = DefaultBcryptCost
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), cost)
if err != nil {
return "", fmt.Errorf("hash password: %w", err)
}
return string(hashedPassword), nil
}
func IsRecordNotFound(err error) bool {
return errors.Is(err, gorm.ErrRecordNotFound)
}
func HandleUniqueConstraintError(err error) error {
if err == nil {
return nil
}
var pqErr *pq.Error
if !errors.As(err, &pqErr) || pqErr.Code != "23505" {
return err
}
constraintLower := strings.ToLower(pqErr.Constraint)
errMsgLower := strings.ToLower(pqErr.Message)
if strings.Contains(constraintLower, "username") ||
strings.Contains(errMsgLower, "username") ||
strings.Contains(errMsgLower, "users_username_key") ||
strings.Contains(errMsgLower, "users.username") {
return ErrUsernameTaken
}
if strings.Contains(constraintLower, "email") ||
strings.Contains(errMsgLower, "email") ||
strings.Contains(errMsgLower, "users_email_key") ||
strings.Contains(errMsgLower, "users.email") {
return ErrEmailTaken
}
return ErrUsernameTaken
}
func HandleUniqueConstraintErrorWithMessage(err error) error {
if err == nil {
return nil
}
if handled := HandleUniqueConstraintError(err); handled != err {
return handled
}
errMsg := err.Error()
errMsgLower := strings.ToLower(errMsg)
isUniqueError := strings.Contains(errMsgLower, "duplicate key") ||
strings.Contains(errMsgLower, "unique constraint") ||
strings.Contains(errMsgLower, "violates unique constraint") ||
strings.Contains(errMsgLower, "unique constraint failed") ||
strings.Contains(errMsgLower, "constraint failed") ||
(strings.Contains(errMsgLower, "constraint") && strings.Contains(errMsgLower, "unique"))
if !isUniqueError {
return err
}
if strings.Contains(errMsgLower, "username") ||
strings.Contains(errMsgLower, "users_username_key") ||
strings.Contains(errMsgLower, "users.username") ||
strings.Contains(errMsg, "username") ||
strings.Contains(errMsg, "users_username_key") ||
strings.Contains(errMsg, "users.username") {
return ErrUsernameTaken
}
if strings.Contains(errMsgLower, "email") ||
strings.Contains(errMsgLower, "users_email_key") ||
strings.Contains(errMsgLower, "users.email") ||
strings.Contains(errMsg, "email") ||
strings.Contains(errMsg, "users_email_key") ||
strings.Contains(errMsg, "users.email") {
return ErrEmailTaken
}
return ErrUsernameTaken
}
func NewAuthFacadeForTest(cfg *config.Config, userRepo repositories.UserRepository, postRepo repositories.PostRepository, deletionRepo repositories.AccountDeletionRepository, refreshRepo repositories.RefreshTokenRepositoryInterface, emailSender EmailSender) (*AuthFacade, error) {
emailService, err := NewEmailService(cfg, emailSender)
if err != nil {
return nil, fmt.Errorf("create email service: %w", err)
}
jwtService := NewJWTService(&cfg.JWT, userRepo, refreshRepo)
registrationService := NewRegistrationService(userRepo, emailService, cfg)
passwordResetService := NewPasswordResetService(userRepo, emailService)
deletionService := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
sessionService := NewSessionService(jwtService, userRepo)
userManagementService := NewUserManagementService(userRepo, postRepo, emailService)
authFacade := NewAuthFacade(
registrationService,
passwordResetService,
deletionService,
sessionService,
userManagementService,
cfg,
)
return authFacade, nil
}

View File

@@ -0,0 +1,822 @@
package services
import (
"errors"
"testing"
"goyco/internal/config"
"goyco/internal/repositories"
"goyco/internal/testutils"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
func TestTrimString(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no_whitespace",
input: "test",
expected: "test",
},
{
name: "leading_spaces",
input: " test",
expected: "test",
},
{
name: "trailing_spaces",
input: "test ",
expected: "test",
},
{
name: "leading_and_trailing_spaces",
input: " test ",
expected: "test",
},
{
name: "only_spaces",
input: " ",
expected: "",
},
{
name: "empty_string",
input: "",
expected: "",
},
{
name: "tabs_and_newlines",
input: "\t\n test \n\t",
expected: "test",
},
{
name: "internal_spaces_preserved",
input: " test string ",
expected: "test string",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := TrimString(tt.input)
if result != tt.expected {
t.Errorf("TrimString(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestHashPassword(t *testing.T) {
t.Run("successful_hash", func(t *testing.T) {
password := "testpassword123"
hashed, err := HashPassword(password, DefaultBcryptCost)
if err != nil {
t.Fatalf("HashPassword() error = %v, want no error", err)
}
if hashed == "" {
t.Error("HashPassword() returned empty string")
}
if hashed == password {
t.Error("HashPassword() returned plain password")
}
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
if err != nil {
t.Errorf("HashPassword() produced invalid hash: %v", err)
}
})
t.Run("different_passwords_produce_different_hashes", func(t *testing.T) {
password1 := "password1"
password2 := "password2"
hash1, err := HashPassword(password1, DefaultBcryptCost)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
hash2, err := HashPassword(password2, DefaultBcryptCost)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
if hash1 == hash2 {
t.Error("Different passwords produced same hash")
}
})
t.Run("same_password_produces_different_hashes", func(t *testing.T) {
password := "samepassword"
hash1, err := HashPassword(password, DefaultBcryptCost)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
hash2, err := HashPassword(password, DefaultBcryptCost)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
if hash1 == hash2 {
t.Error("Same password produced same hash (should be different due to salt)")
}
if err := bcrypt.CompareHashAndPassword([]byte(hash1), []byte(password)); err != nil {
t.Errorf("First hash doesn't verify: %v", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(hash2), []byte(password)); err != nil {
t.Errorf("Second hash doesn't verify: %v", err)
}
})
t.Run("custom_cost", func(t *testing.T) {
password := "testpassword"
customCost := 12
hashed, err := HashPassword(password, customCost)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
if err != nil {
t.Errorf("HashPassword() with custom cost produced invalid hash: %v", err)
}
})
t.Run("zero_cost_uses_default", func(t *testing.T) {
password := "testpassword"
hashed, err := HashPassword(password, 0)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
if err != nil {
t.Errorf("HashPassword() with zero cost produced invalid hash: %v", err)
}
})
t.Run("negative_cost_uses_default", func(t *testing.T) {
password := "testpassword"
hashed, err := HashPassword(password, -1)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
if err != nil {
t.Errorf("HashPassword() with negative cost produced invalid hash: %v", err)
}
})
t.Run("empty_password", func(t *testing.T) {
hashed, err := HashPassword("", DefaultBcryptCost)
if err != nil {
t.Fatalf("HashPassword() error = %v", err)
}
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(""))
if err != nil {
t.Errorf("HashPassword() with empty password produced invalid hash: %v", err)
}
})
}
func TestIsRecordNotFound(t *testing.T) {
t.Run("gorm_record_not_found", func(t *testing.T) {
err := gorm.ErrRecordNotFound
if !IsRecordNotFound(err) {
t.Error("IsRecordNotFound() = false, want true for gorm.ErrRecordNotFound")
}
})
t.Run("wrapped_gorm_record_not_found", func(t *testing.T) {
err := errors.New("some context")
wrappedErr := errors.Join(err, gorm.ErrRecordNotFound)
if !IsRecordNotFound(wrappedErr) {
t.Error("IsRecordNotFound() = false, want true for wrapped gorm.ErrRecordNotFound")
}
})
t.Run("other_error", func(t *testing.T) {
err := errors.New("some other error")
if IsRecordNotFound(err) {
t.Error("IsRecordNotFound() = true, want false for other error")
}
})
t.Run("nil_error", func(t *testing.T) {
if IsRecordNotFound(nil) {
t.Error("IsRecordNotFound() = true, want false for nil error")
}
})
}
func TestHandleUniqueConstraintError(t *testing.T) {
t.Run("nil_error", func(t *testing.T) {
err := HandleUniqueConstraintError(nil)
if err != nil {
t.Errorf("HandleUniqueConstraintError(nil) = %v, want nil", err)
}
})
t.Run("non_pq_error", func(t *testing.T) {
originalErr := errors.New("some other error")
err := HandleUniqueConstraintError(originalErr)
if err != originalErr {
t.Errorf("HandleUniqueConstraintError() = %v, want %v", err, originalErr)
}
})
t.Run("pq_error_wrong_code", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23503",
Message: "some error",
}
err := HandleUniqueConstraintError(pqErr)
if err != pqErr {
t.Errorf("HandleUniqueConstraintError() = %v, want %v", err, pqErr)
}
})
t.Run("username_constraint_in_constraint_field", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "users_username_key",
Message: "duplicate key value violates unique constraint",
}
err := HandleUniqueConstraintError(pqErr)
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken", err)
}
})
t.Run("username_constraint_in_message", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "some_constraint",
Message: "duplicate key value violates unique constraint users.username",
}
err := HandleUniqueConstraintError(pqErr)
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken", err)
}
})
t.Run("username_constraint_case_insensitive", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "USERS_USERNAME_KEY",
Message: "DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT",
}
err := HandleUniqueConstraintError(pqErr)
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken", err)
}
})
t.Run("email_constraint_in_constraint_field", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "users_email_key",
Message: "duplicate key value violates unique constraint",
}
err := HandleUniqueConstraintError(pqErr)
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintError() = %v, want ErrEmailTaken", err)
}
})
t.Run("email_constraint_in_message", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "some_constraint",
Message: "duplicate key value violates unique constraint users.email",
}
err := HandleUniqueConstraintError(pqErr)
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintError() = %v, want ErrEmailTaken", err)
}
})
t.Run("email_constraint_case_insensitive", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "USERS_EMAIL_KEY",
Message: "DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT",
}
err := HandleUniqueConstraintError(pqErr)
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintError() = %v, want ErrEmailTaken", err)
}
})
t.Run("unknown_unique_constraint_defaults_to_username", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "some_other_constraint",
Message: "duplicate key value violates unique constraint",
}
err := HandleUniqueConstraintError(pqErr)
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken (default)", err)
}
})
}
func TestHandleUniqueConstraintErrorWithMessage(t *testing.T) {
t.Run("nil_error", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(nil)
if err != nil {
t.Errorf("HandleUniqueConstraintErrorWithMessage(nil) = %v, want nil", err)
}
})
t.Run("pq_error_username_handled_by_HandleUniqueConstraintError", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "users_username_key",
Message: "duplicate key value violates unique constraint",
}
err := HandleUniqueConstraintErrorWithMessage(pqErr)
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("pq_error_email_handled_by_HandleUniqueConstraintError", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "users_email_key",
Message: "duplicate key value violates unique constraint",
}
err := HandleUniqueConstraintErrorWithMessage(pqErr)
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
}
})
t.Run("message_based_duplicate_key_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint username"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("message_based_unique_constraint_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("unique constraint failed on username"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("message_based_violates_unique_constraint_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("violates unique constraint users_username_key"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("message_based_unique_constraint_failed_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("unique constraint failed on users.username"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("message_based_constraint_failed_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("constraint failed on username"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("message_based_constraint_and_unique_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("constraint unique username failed"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("message_based_duplicate_key_email", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint email"))
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
}
})
t.Run("message_based_unique_constraint_email", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("unique constraint failed on users_email_key"))
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
}
})
t.Run("message_based_violates_unique_constraint_email", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("violates unique constraint users.email"))
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
}
})
t.Run("message_based_case_insensitive_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT USERNAME"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("message_based_case_insensitive_email", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT EMAIL"))
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
}
})
t.Run("non_unique_error_passed_through", func(t *testing.T) {
originalErr := errors.New("some other database error")
err := HandleUniqueConstraintErrorWithMessage(originalErr)
if err != originalErr {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want %v", err, originalErr)
}
})
t.Run("unique_error_without_username_or_email_defaults_to_username", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken (default)", err)
}
})
t.Run("wrapped_pq_error", func(t *testing.T) {
pqErr := &pq.Error{
Code: "23505",
Constraint: "users_username_key",
Message: "duplicate key value violates unique constraint",
}
wrappedErr := errors.New("context: " + pqErr.Error())
err := HandleUniqueConstraintErrorWithMessage(wrappedErr)
if err == nil {
t.Error("HandleUniqueConstraintErrorWithMessage() returned nil, expected error")
}
})
}
func TestHandleUniqueConstraintErrorWithMessage_EdgeCases(t *testing.T) {
t.Run("mixed_case_username_in_message", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint UserName"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("mixed_case_email_in_message", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint Email"))
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
}
})
t.Run("username_substring_not_matched", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint usernames"))
if !errors.Is(err, ErrUsernameTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
}
})
t.Run("email_substring_not_matched", func(t *testing.T) {
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint emails"))
if !errors.Is(err, ErrEmailTaken) {
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
}
})
}
func BenchmarkTrimString(b *testing.B) {
input := " test string with spaces "
for i := 0; i < b.N; i++ {
_ = TrimString(input)
}
}
func BenchmarkHashPassword(b *testing.B) {
password := "testpassword123"
for i := 0; i < b.N; i++ {
_, _ = HashPassword(password, DefaultBcryptCost)
}
}
func BenchmarkIsRecordNotFound(b *testing.B) {
err := gorm.ErrRecordNotFound
for i := 0; i < b.N; i++ {
_ = IsRecordNotFound(err)
}
}
func TestNewAuthFacadeForTest(t *testing.T) {
t.Run("successful_creation", func(t *testing.T) {
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
sqlDB.Close()
}()
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authFacade, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err != nil {
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
}
if authFacade == nil {
t.Fatal("NewAuthFacadeForTest() returned nil AuthFacade")
}
})
t.Run("successful_creation_with_custom_config", func(t *testing.T) {
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
sqlDB.Close()
}()
customConfig := &config.Config{
JWT: config.JWTConfig{
Secret: "custom-secret-key-for-testing-purposes-only",
Expiration: 48,
},
App: config.AppConfig{
BaseURL: "http://localhost:3000",
BcryptCost: 12,
},
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authFacade, err := NewAuthFacadeForTest(customConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err != nil {
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
}
if authFacade == nil {
t.Fatal("NewAuthFacadeForTest() returned nil AuthFacade")
}
})
t.Run("error_on_empty_base_url", func(t *testing.T) {
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
sqlDB.Close()
}()
invalidConfig := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-key",
Expiration: 24,
},
App: config.AppConfig{
BaseURL: "",
BcryptCost: 10,
},
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authFacade, err := NewAuthFacadeForTest(invalidConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err == nil {
t.Fatal("NewAuthFacadeForTest() expected error for empty base URL, got nil")
}
if authFacade != nil {
t.Fatal("NewAuthFacadeForTest() expected nil AuthFacade on error, got non-nil")
}
if err.Error() != "create email service: APP_BASE_URL is required and must be externally reachable" {
t.Errorf("NewAuthFacadeForTest() error = %v, want error about APP_BASE_URL", err)
}
})
t.Run("error_on_whitespace_only_base_url", func(t *testing.T) {
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
sqlDB.Close()
}()
invalidConfig := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-key",
Expiration: 24,
},
App: config.AppConfig{
BaseURL: " ",
BcryptCost: 10,
},
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authFacade, err := NewAuthFacadeForTest(invalidConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err == nil {
t.Fatal("NewAuthFacadeForTest() expected error for whitespace-only base URL, got nil")
}
if authFacade != nil {
t.Fatal("NewAuthFacadeForTest() expected nil AuthFacade on error, got non-nil")
}
})
t.Run("facade_can_perform_operations", func(t *testing.T) {
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
sqlDB.Close()
}()
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authFacade, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err != nil {
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
}
result, err := authFacade.Register("testuser", "test@example.com", "SecurePass123!")
if err != nil {
t.Fatalf("AuthFacade.Register() error = %v, want no error", err)
}
if result == nil {
t.Fatal("AuthFacade.Register() returned nil result")
}
if result.User.Username != "testuser" {
t.Errorf("AuthFacade.Register() username = %v, want 'testuser'", result.User.Username)
}
if result.User.Email != "test@example.com" {
t.Errorf("AuthFacade.Register() email = %v, want 'test@example.com'", result.User.Email)
}
})
t.Run("facade_can_login", func(t *testing.T) {
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
sqlDB.Close()
}()
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
deletionRepo := repositories.NewAccountDeletionRepository(db)
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
emailSender := &testutils.MockEmailSender{}
authFacade, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
if err != nil {
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
}
password := "SecurePass123!"
_, err = authFacade.Register("loginuser", "login@example.com", password)
if err != nil {
t.Fatalf("AuthFacade.Register() error = %v, want no error", err)
}
user, err := userRepo.GetByUsername("loginuser")
if err != nil {
t.Fatalf("GetByUsername() error = %v", err)
}
user.EmailVerified = true
if err := userRepo.Update(user); err != nil {
t.Fatalf("Update() error = %v", err)
}
authResult, err := authFacade.Login("loginuser", password)
if err != nil {
t.Fatalf("AuthFacade.Login() error = %v, want no error", err)
}
if authResult == nil {
t.Fatal("AuthFacade.Login() returned nil result")
}
if authResult.AccessToken == "" {
t.Error("AuthFacade.Login() returned empty access token")
}
if authResult.RefreshToken == "" {
t.Error("AuthFacade.Login() returned empty refresh token")
}
if authResult.User == nil {
t.Error("AuthFacade.Login() returned nil user")
}
})
t.Run("multiple_facades_independent", func(t *testing.T) {
db1 := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db1.DB()
sqlDB.Close()
}()
db2 := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db2.DB()
sqlDB.Close()
}()
userRepo1 := repositories.NewUserRepository(db1)
postRepo1 := repositories.NewPostRepository(db1)
deletionRepo1 := repositories.NewAccountDeletionRepository(db1)
refreshTokenRepo1 := repositories.NewRefreshTokenRepository(db1)
emailSender1 := &testutils.MockEmailSender{}
userRepo2 := repositories.NewUserRepository(db2)
postRepo2 := repositories.NewPostRepository(db2)
deletionRepo2 := repositories.NewAccountDeletionRepository(db2)
refreshTokenRepo2 := repositories.NewRefreshTokenRepository(db2)
emailSender2 := &testutils.MockEmailSender{}
authFacade1, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo1, postRepo1, deletionRepo1, refreshTokenRepo1, emailSender1)
if err != nil {
t.Fatalf("NewAuthFacadeForTest() error = %v", err)
}
authFacade2, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo2, postRepo2, deletionRepo2, refreshTokenRepo2, emailSender2)
if err != nil {
t.Fatalf("NewAuthFacadeForTest() error = %v", err)
}
if authFacade1 == authFacade2 {
t.Error("NewAuthFacadeForTest() returned same instance for different repositories")
}
_, err = authFacade1.Register("user1", "user1@example.com", "Pass123!")
if err != nil {
t.Fatalf("AuthFacade1.Register() error = %v", err)
}
_, err = authFacade2.Register("user2", "user2@example.com", "Pass123!")
if err != nil {
t.Fatalf("AuthFacade2.Register() error = %v", err)
}
user1, err := userRepo1.GetByUsername("user1")
if err != nil {
t.Fatalf("GetByUsername() error = %v", err)
}
user1.EmailVerified = true
if err := userRepo1.Update(user1); err != nil {
t.Fatalf("Update() error = %v", err)
}
user2, err := userRepo2.GetByUsername("user2")
if err != nil {
t.Fatalf("GetByUsername() error = %v", err)
}
user2.EmailVerified = true
if err := userRepo2.Update(user2); err != nil {
t.Fatalf("Update() error = %v", err)
}
_, err = authFacade1.Login("user1", "Pass123!")
if err != nil {
t.Errorf("AuthFacade1.Login() error = %v, should find user1", err)
}
_, err = authFacade2.Login("user2", "Pass123!")
if err != nil {
t.Errorf("AuthFacade2.Login() error = %v, should find user2", err)
}
})
}

View File

@@ -0,0 +1,99 @@
package services
import (
"fmt"
"net/smtp"
"strings"
"time"
)
type EmailSender interface {
Send(to, subject, body string) error
}
type SMTPSender struct {
host string
port int
username string
password string
from string
timeout time.Duration
}
func NewSMTPSender(cfgHost string, cfgPort int, cfgUsername, cfgPassword, cfgFrom string) *SMTPSender {
return &SMTPSender{
host: cfgHost,
port: cfgPort,
username: cfgUsername,
password: cfgPassword,
from: cfgFrom,
timeout: 30 * time.Second,
}
}
func NewSMTPSenderWithTimeout(cfgHost string, cfgPort int, cfgUsername, cfgPassword, cfgFrom string, timeout time.Duration) *SMTPSender {
return &SMTPSender{
host: cfgHost,
port: cfgPort,
username: cfgUsername,
password: cfgPassword,
from: cfgFrom,
timeout: timeout,
}
}
func (s *SMTPSender) Send(to, subject, body string) error {
if s == nil {
return fmt.Errorf("smtp sender is not configured")
}
to = strings.TrimSpace(to)
if to == "" {
return fmt.Errorf("recipient address is required")
}
address := fmt.Sprintf("%s:%d", s.host, s.port)
headers := []string{
fmt.Sprintf("From: %s", s.from),
fmt.Sprintf("To: %s", to),
fmt.Sprintf("Subject: %s", subject),
"MIME-Version: 1.0",
"Content-Type: text/html; charset=\"UTF-8\"",
}
message := strings.Join(headers, "\r\n") + "\r\n\r\n" + body + "\r\n"
var auth smtp.Auth
if strings.TrimSpace(s.username) != "" {
auth = smtp.PlainAuth("", s.username, s.password, s.host)
}
done := make(chan error, 1)
go func() {
done <- smtp.SendMail(address, auth, s.from, []string{to}, []byte(message))
}()
select {
case err := <-done:
return err
case <-time.After(s.timeout):
return fmt.Errorf("email sending timeout after %v", s.timeout)
}
}
func (s *SMTPSender) SendAsync(to, subject, body string) <-chan error {
result := make(chan error, 1)
go func() {
result <- s.Send(to, subject, body)
}()
return result
}
func (s *SMTPSender) SetTimeout(timeout time.Duration) {
if s == nil {
return
}
s.timeout = timeout
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,574 @@
package services
import (
"fmt"
"html"
"net/url"
"strings"
"goyco/internal/config"
"goyco/internal/database"
)
type EmailService struct {
EmailSender EmailSender
baseURL string
config *config.Config
}
func NewEmailService(cfg *config.Config, sender EmailSender) (*EmailService, error) {
baseURL := strings.TrimRight(strings.TrimSpace(cfg.App.BaseURL), "/")
if baseURL == "" {
return nil, fmt.Errorf("APP_BASE_URL is required and must be externally reachable")
}
return &EmailService{
EmailSender: sender,
baseURL: baseURL,
config: cfg,
}, nil
}
func (s *EmailService) SendVerificationEmail(user *database.User, token string) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
verificationURL := fmt.Sprintf("%s/confirm?token=%s", s.baseURL, url.QueryEscape(token))
subject := fmt.Sprintf("🎉 Welcome to %s! Confirm your email address", s.config.App.Title)
body := s.GenerateVerificationEmailBody(user.Username, verificationURL)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send verification email: %w", err)
}
return nil
}
func (s *EmailService) SendEmailChangeVerificationEmail(user *database.User, token string) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
verificationURL := fmt.Sprintf("%s/confirm?token=%s", s.baseURL, url.QueryEscape(token))
subject := "📧 Confirm your new email address"
body := s.GenerateEmailChangeVerificationEmailBody(user.Username, verificationURL)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send email change verification email: %w", err)
}
return nil
}
func (s *EmailService) SendResendVerificationEmail(user *database.User, token string) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
verificationURL := fmt.Sprintf("%s/confirm?token=%s", s.baseURL, url.QueryEscape(token))
subject := fmt.Sprintf("🔄 Resend: Confirm your %s account", s.config.App.Title)
body := s.GenerateResendVerificationEmailBody(user.Username, verificationURL)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send verification email: %w", err)
}
return nil
}
func (s *EmailService) SendPasswordResetEmail(user *database.User, token string) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
resetURL := fmt.Sprintf("%s/reset-password?token=%s", s.baseURL, url.QueryEscape(token))
subject := fmt.Sprintf("Reset your %s password", s.config.App.Title)
body := s.GeneratePasswordResetEmailBody(user.Username, resetURL)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send password reset email: %w", err)
}
return nil
}
func (s *EmailService) SendAccountDeletionEmail(user *database.User, token string) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
confirmationURL := fmt.Sprintf("%s/settings/delete/confirm?token=%s", s.baseURL, url.QueryEscape(token))
subject := "Confirm Account Deletion"
body := s.GenerateAccountDeletionEmailBody(user.Username, confirmationURL)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send account deletion email: %w", err)
}
return nil
}
func (s *EmailService) SendAccountDeletionNotificationEmail(user *database.User, deletedPosts bool) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
subject, body := GenerateAccountDeletionNotificationEmail(user.Username, s.config.App.AdminEmail, s.baseURL, s.config.App.Title, deletedPosts)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send account deletion notification email: %w", err)
}
return nil
}
func (s *EmailService) SendAccountLockNotificationEmail(user *database.User) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
subject, body := GenerateAccountLockNotificationEmail(user.Username, s.config.App.AdminEmail, s.config.App.Title)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send account lock notification email: %w", err)
}
return nil
}
func (s *EmailService) SendAccountUnlockNotificationEmail(user *database.User) error {
if s.EmailSender == nil {
return ErrEmailSenderUnavailable
}
subject, body := GenerateAccountUnlockNotificationEmail(user.Username, s.config.App.AdminEmail, s.baseURL, s.config.App.Title)
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
return fmt.Errorf("send account unlock notification email: %w", err)
}
return nil
}
func (s *EmailService) GenerateVerificationEmailBody(username, verificationURL string) string {
safeUsername := html.EscapeString(username)
safeURL := html.EscapeString(verificationURL)
siteTitle := s.config.App.Title
return s.generateStyledEmailBody(siteTitle, "🎉 Welcome! Confirm your email address",
fmt.Sprintf("Hello %s,", safeUsername),
fmt.Sprintf("Welcome to %s! Please confirm your email address by clicking the link below:", siteTitle),
safeURL, "Confirm Email Address",
"If the link doesn't work, you can copy and paste it into your browser.\n\nOnce confirmed, you'll be able to:\n- Create and share posts with the community\n- Vote on content you find interesting\n- Connect with other members\n- Customize your profile and preferences\n\nIf you didn't create this account, you can safely ignore this email.")
}
func (s *EmailService) GenerateEmailChangeVerificationEmailBody(username, verificationURL string) string {
safeUsername := html.EscapeString(username)
safeURL := html.EscapeString(verificationURL)
siteTitle := s.config.App.Title
return s.generateStyledEmailBody(siteTitle, "📧 Confirm your new email address",
fmt.Sprintf("Hello %s,", safeUsername),
"You've requested to change your email address. To complete this change, please confirm your new email address by clicking the link below:\n\nIf the link doesn't work, you can copy and paste it into your browser.\n\nOnce confirmed, your new email address will be active and you'll need to use it for future logins.\n\nIf you didn't request this email change, please contact our support team immediately.",
safeURL, "Confirm New Email Address",
"")
}
func (s *EmailService) GenerateResendVerificationEmailBody(username, verificationURL string) string {
safeUsername := html.EscapeString(username)
safeURL := html.EscapeString(verificationURL)
siteTitle := s.config.App.Title
return s.generateStyledEmailBody(siteTitle, fmt.Sprintf("🔄 Resend: Confirm your %s account", siteTitle),
fmt.Sprintf("Hello %s,", safeUsername),
"We've sent you a new verification link.\n\nPlease confirm your email address by clicking the link below:\n\nIf you're having trouble with the verification link:\n- Check your spam/junk folder\n- Make sure you're clicking the most recent email\n- Try copying the link and pasting it in your browser\n- Contact support if the problem persists\n\nIf you didn't request this email, you can safely ignore this message.",
safeURL, "Confirm Email Address",
"")
}
func (s *EmailService) GeneratePasswordResetEmailBody(username, resetURL string) string {
safeUsername := html.EscapeString(username)
safeURL := html.EscapeString(resetURL)
siteTitle := s.config.App.Title
return s.generateStyledEmailBody(siteTitle, fmt.Sprintf("Reset your %s password", siteTitle),
fmt.Sprintf("Hello %s,", safeUsername),
"We received a request to reset your password. To reset it, click the link below:",
safeURL, "Reset Password",
"This link will expire in 24 hours. If you didn't make this request, you can safely ignore this message.")
}
func (s *EmailService) GenerateAccountDeletionEmailBody(username, confirmationURL string) string {
safeUsername := html.EscapeString(username)
safeURL := html.EscapeString(confirmationURL)
siteTitle := s.config.App.Title
return s.generateStyledEmailBody(siteTitle, "Confirm Account Deletion",
fmt.Sprintf("Hello %s,", safeUsername),
fmt.Sprintf("We received a request to delete your %s account.\n\nTo confirm, click the link below:", siteTitle),
safeURL, "Confirm Account Deletion",
"This link will expire in 24 hours.\n\nIf you didn't make this request, you can safely ignore this message.")
}
func (s *EmailService) generateStyledEmailBody(siteTitle, emailTitle, greeting, message, actionURL, actionText, footer string) string {
safeAdminEmail := html.EscapeString(s.config.App.AdminEmail)
var buttonHTML string
if actionURL != "" && actionText != "" {
buttonHTML = fmt.Sprintf(`<div style="text-align: center;">
<a href="%s" class="action-button">%s</a>
</div>`, actionURL, actionText)
}
return fmt.Sprintf(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>%s</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
line-height: 1.6;
color: #333;
max-width: 600px;
margin: 0 auto;
padding: 20px;
background-color: #f8fafc;
}
.email-container {
background: white;
border-radius: 12px;
padding: 40px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
border: 1px solid #e2e8f0;
}
.header {
text-align: center;
margin-bottom: 30px;
}
.logo {
font-size: 28px;
font-weight: 700;
color: #0fb9b1;
margin-bottom: 10px;
}
.title {
font-size: 24px;
font-weight: 600;
color: #1a202c;
margin: 0;
}
.content {
margin-bottom: 30px;
}
.greeting {
font-size: 16px;
margin-bottom: 20px;
color: #2d3748;
}
.message {
font-size: 16px;
margin-bottom: 30px;
color: #4a5568;
white-space: pre-line;
}
.action-button {
display: inline-block;
background: #0fb9b1;
color: white;
text-decoration: none;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
text-align: center;
margin: 20px 0;
transition: background-color 0.2s;
}
.action-button:hover {
background: #0ea5a0;
}
.footer {
font-size: 14px;
color: #718096;
margin-top: 30px;
padding-top: 20px;
border-top: 1px solid #e2e8f0;
white-space: pre-line;
}
.link {
color: #0fb9b1;
text-decoration: none;
}
.link:hover {
text-decoration: underline;
}
@media (max-width: 600px) {
body {
padding: 10px;
}
.email-container {
padding: 20px;
}
.title {
font-size: 20px;
}
}
</style>
</head>
<body>
<div class="email-container">
<div class="header">
<div class="logo">%s</div>
<h1 class="title">%s</h1>
</div>
<div class="content">
<div class="greeting">%s</div>
<div class="message">%s</div>
%s
</div>
<div class="footer">
%s
If you have any questions or concerns, please <a href="mailto:%s" class="link">contact our support team</a>.<br>
Best regards,<br>
The %s Team
</div>
<div class="powered-by" style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #e2e8f0; font-size: 12px; color: #718096;">
Powered with ❤️ by <a href="https://goyco" style="color: #0fb9b1; text-decoration: none;">Goyco</a>
</div>
</div>
</body>
</html>`, emailTitle, siteTitle, emailTitle, greeting, message, buttonHTML, footer, safeAdminEmail, siteTitle)
}
func GenerateAccountDeletionNotificationEmail(username, adminEmail, baseURL, siteTitle string, deletedPosts bool) (string, string) {
subject := "Account Deleted"
safeUsername := html.EscapeString(username)
safeAdminEmail := html.EscapeString(adminEmail)
var body string
if deletedPosts {
body = fmt.Sprintf(`Your %s account has been permanently deleted.
All your posts have also been removed.
If you did not request this change, please <a href="mailto:%s" class="link">contact us</a>.
Thank you for being part of the community.`, siteTitle, safeAdminEmail)
} else {
body = fmt.Sprintf(`Your %s account has been permanently deleted.
Your posts have been preserved and are now anonymous.
If you did not request this change, please <a href="mailto:%s" class="link">contact us</a>.
Thank you for being part of the community.`, siteTitle, safeAdminEmail)
}
mainPageURL := strings.TrimRight(baseURL, "/")
safeMainPageURL := html.EscapeString(mainPageURL)
styledBody := generateStyledEmailBodyStatic(siteTitle, "Account Deleted",
fmt.Sprintf("Hello %s,", safeUsername),
body,
safeMainPageURL, fmt.Sprintf("Visit %s", siteTitle),
"", adminEmail)
return subject, styledBody
}
func GenerateAdminAccountDeletionNotificationEmail(username, adminEmail, baseURL, siteTitle string, deletedPosts bool) (string, string) {
subject := "Account Deleted by Administrator"
safeUsername := html.EscapeString(username)
safeAdminEmail := html.EscapeString(adminEmail)
var message string
if deletedPosts {
message = fmt.Sprintf("Your %s account has been permanently deleted by an administrator.\n\nAll your posts have also been removed.\n\nIf you did not request this change, please <a href=\"mailto:%s\" class=\"link\">contact us</a>.\n\nThank you for being part of the community.", siteTitle, safeAdminEmail)
} else {
message = fmt.Sprintf("Your %s account has been permanently deleted by an administrator.\n\nYour posts have been preserved and are now anonymous.\n\nIf you did not request this change, please <a href=\"mailto:%s\" class=\"link\">contact us</a>.\n\nThank you for being part of the community.", siteTitle, safeAdminEmail)
}
mainPageURL := strings.TrimRight(baseURL, "/")
safeMainPageURL := html.EscapeString(mainPageURL)
body := generateStyledEmailBodyStatic(siteTitle, "Account Deleted by Administrator",
fmt.Sprintf("Hello %s,", safeUsername),
message,
safeMainPageURL, fmt.Sprintf("Visit %s", siteTitle),
"", adminEmail)
return subject, body
}
func GenerateAccountLockNotificationEmail(username, adminEmail, siteTitle string) (string, string) {
subject := "Account Locked"
safeUsername := html.EscapeString(username)
safeAdminEmail := html.EscapeString(adminEmail)
message := fmt.Sprintf("Your %s account has been locked by an administrator.\n\nYou will not be able to log in or access your account until it is unlocked.\n\nIf you believe this is an error or need to discuss your account status, please <a href=\"mailto:%s\" class=\"link\">contact us</a>.\n\nThis action was taken to protect the security and integrity of our platform.", siteTitle, safeAdminEmail)
body := generateStyledEmailBodyStatic(siteTitle, "Account Locked",
fmt.Sprintf("Hello %s,", safeUsername),
message,
"", "",
"", adminEmail)
return subject, body
}
func GenerateAccountUnlockNotificationEmail(username, adminEmail, baseURL, siteTitle string) (string, string) {
subject := "Account Unlocked"
safeUsername := html.EscapeString(username)
message := fmt.Sprintf("Your %s account has been unlocked by an administrator.\n\nYou can now log in and access all your account features normally.\n\nAll your previous data and settings have been preserved.\n\nWelcome back!", siteTitle)
loginURL := fmt.Sprintf("%s/login", strings.TrimRight(baseURL, "/"))
safeLoginURL := html.EscapeString(loginURL)
body := generateStyledEmailBodyStatic(siteTitle, "Account Unlocked",
fmt.Sprintf("Hello %s,", safeUsername),
message,
safeLoginURL, fmt.Sprintf("Login to %s", siteTitle),
"", adminEmail)
return subject, body
}
func generateStyledEmailBodyStatic(siteTitle, emailTitle, greeting, message, actionURL, actionText, footer, adminEmail string) string {
safeAdminEmail := html.EscapeString(adminEmail)
var buttonHTML string
if actionURL != "" && actionText != "" {
buttonHTML = fmt.Sprintf(`<div style="text-align: center;">
<a href="%s" class="action-button">%s</a>
</div>`, actionURL, actionText)
}
return fmt.Sprintf(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>%s</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
line-height: 1.6;
color: #333;
max-width: 600px;
margin: 0 auto;
padding: 20px;
background-color: #f8fafc;
}
.email-container {
background: white;
border-radius: 12px;
padding: 40px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
border: 1px solid #e2e8f0;
}
.header {
text-align: center;
margin-bottom: 30px;
}
.logo {
font-size: 28px;
font-weight: 700;
color: #0fb9b1;
margin-bottom: 10px;
}
.title {
font-size: 24px;
font-weight: 600;
color: #1a202c;
margin: 0;
}
.content {
margin-bottom: 30px;
}
.greeting {
font-size: 16px;
margin-bottom: 20px;
color: #2d3748;
}
.message {
font-size: 16px;
margin-bottom: 30px;
color: #4a5568;
white-space: pre-line;
}
.action-button {
display: inline-block;
background: #0fb9b1;
color: white;
text-decoration: none;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
text-align: center;
margin: 20px 0;
transition: background-color 0.2s;
}
.action-button:hover {
background: #0ea5a0;
}
.footer {
font-size: 14px;
color: #718096;
margin-top: 30px;
padding-top: 20px;
border-top: 1px solid #e2e8f0;
white-space: pre-line;
}
.link {
color: #0fb9b1;
text-decoration: none;
}
.link:hover {
text-decoration: underline;
}
@media (max-width: 600px) {
body {
padding: 10px;
}
.email-container {
padding: 20px;
}
.title {
font-size: 20px;
}
}
</style>
</head>
<body>
<div class="email-container">
<div class="header">
<div class="logo">%s</div>
<h1 class="title">%s</h1>
</div>
<div class="content">
<div class="greeting">%s</div>
<div class="message">%s</div>
%s
</div>
<div class="footer">
%s
If you have any questions or concerns, please <a href="mailto:%s" class="link">contact our support team</a>.<br>
Best regards,<br>
The %s Team
</div>
<div class="powered-by" style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #e2e8f0; font-size: 12px; color: #718096;">
Powered with ❤️ by <a href="https://goyco" style="color: #0fb9b1; text-decoration: none;">Goyco</a>
</div>
</div>
</body>
</html>`, emailTitle, siteTitle, emailTitle, greeting, message, buttonHTML, footer, safeAdminEmail, siteTitle)
}

View File

@@ -0,0 +1,275 @@
package services
import (
"fmt"
"html"
"os"
"strings"
"testing"
"time"
"goyco/internal/config"
"goyco/internal/testutils"
)
func TestEmailService_IntegrationWithRealSMTP(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sender := testutils.GetSMTPSenderFromEnv(t)
recipient := os.Getenv("SMTP_TEST_RECIPIENT")
if strings.TrimSpace(recipient) == "" {
recipient = sender.From
}
config := testutils.NewEmailTestConfig("https://example.com")
service, err := NewEmailService(config, sender)
if err != nil {
t.Skipf("Skipping SMTP integration test: failed to create email service: %v", err)
}
user := testutils.NewEmailTestUser("integrationuser", recipient)
token := "integration-test-token"
t.Run("VerificationEmail_RealSMTP", func(t *testing.T) {
err := service.SendVerificationEmail(user, token)
if err != nil {
t.Errorf("SendVerificationEmail failed: %v", err)
}
body := service.GenerateVerificationEmailBody(user.Username, fmt.Sprintf("%s/confirm?token=%s", config.App.BaseURL, token))
if !strings.Contains(body, "https://example.com/confirm?token=integration-test-token") {
t.Errorf("Expected body to contain verification URL")
}
if !strings.Contains(body, "<!DOCTYPE html>") {
t.Errorf("Expected HTML content in email body")
}
})
t.Run("PasswordResetEmail_RealSMTP", func(t *testing.T) {
err := service.SendPasswordResetEmail(user, token)
if err != nil {
t.Errorf("SendPasswordResetEmail failed: %v", err)
}
body := service.GeneratePasswordResetEmailBody(user.Username, fmt.Sprintf("%s/reset-password?token=%s", config.App.BaseURL, token))
if !strings.Contains(body, "https://example.com/reset-password?token=integration-test-token") {
t.Errorf("Expected body to contain reset URL")
}
})
t.Run("AccountDeletionEmail_RealSMTP", func(t *testing.T) {
err := service.SendAccountDeletionEmail(user, token)
if err != nil {
t.Errorf("SendAccountDeletionEmail failed: %v", err)
}
body := service.GenerateAccountDeletionEmailBody(user.Username, fmt.Sprintf("%s/settings/delete/confirm?token=%s", config.App.BaseURL, token))
if !strings.Contains(body, "https://example.com/settings/delete/confirm?token=integration-test-token") {
t.Errorf("Expected body to contain deletion confirmation URL")
}
})
}
func TestEmailService_Performance(t *testing.T) {
if testing.Short() {
t.Skip("Skipping performance test in short mode")
}
config := testutils.NewEmailTestConfig("https://example.com")
service, err := NewEmailService(config, &testutils.MockEmailSender{})
if err != nil {
t.Fatalf("Failed to create email service: %v", err)
}
user := testutils.NewEmailTestUser("perfuser", "perf@example.com")
service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test")
start := time.Now()
iterations := 1000
for i := 0; i < iterations; i++ {
service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test")
}
duration := time.Since(start)
maxDuration := 500 * time.Millisecond
if duration > maxDuration {
t.Errorf("HTML generation took too long: %v (expected < %v for %d iterations, %.2fms per template)",
duration, maxDuration, iterations, float64(duration.Nanoseconds())/float64(iterations)/1e6)
}
t.Logf("Generated %d HTML emails in %v (%.2fms per template)",
iterations, duration, float64(duration.Nanoseconds())/float64(iterations)/1e6)
}
func TestEmailService_EdgeCases(t *testing.T) {
config := testutils.NewEmailTestConfig("https://example.com")
service, err := NewEmailService(config, &testutils.MockEmailSender{})
if err != nil {
t.Fatalf("Failed to create email service: %v", err)
}
t.Run("EmptyUsername", func(t *testing.T) {
body := service.GenerateVerificationEmailBody("", "https://example.com/confirm?token=test")
if !strings.Contains(body, "Hello ,") {
t.Error("Expected empty username to be handled gracefully")
}
})
t.Run("VeryLongUsername", func(t *testing.T) {
longUsername := strings.Repeat("a", 1000)
body := service.GenerateVerificationEmailBody(longUsername, "https://example.com/confirm?token=test")
if !strings.Contains(body, longUsername) {
t.Error("Expected long username to be included in email")
}
})
t.Run("SpecialCharactersInUsername", func(t *testing.T) {
specialUsername := "user@domain.com & <script>alert('xss')</script>"
body := service.GenerateVerificationEmailBody(specialUsername, "https://example.com/confirm?token=test")
escapedUsername := html.EscapeString(specialUsername)
if !strings.Contains(body, escapedUsername) {
t.Errorf("Expected escaped username %q to be included", escapedUsername)
}
})
t.Run("EmptyToken", func(t *testing.T) {
body := service.GenerateVerificationEmailBody("testuser", "https://example.com/confirm?token=")
if !strings.Contains(body, "https://example.com/confirm?token=") {
t.Error("Expected empty token to be handled")
}
})
t.Run("VeryLongToken", func(t *testing.T) {
longToken := strings.Repeat("a", 1000)
url := fmt.Sprintf("https://example.com/confirm?token=%s", longToken)
body := service.GenerateVerificationEmailBody("testuser", url)
if !strings.Contains(body, url) {
t.Error("Expected long token to be included in email")
}
})
}
func TestNewEmailService(t *testing.T) {
tests := []struct {
name string
config *config.Config
sender EmailSender
expectError bool
errorMsg string
}{
{
name: "Valid configuration",
config: testutils.NewEmailTestConfig("https://example.com"),
sender: &testutils.MockEmailSender{},
expectError: false,
},
{
name: "Empty base URL",
config: testutils.NewEmailTestConfig(""),
sender: &testutils.MockEmailSender{},
expectError: true,
errorMsg: "APP_BASE_URL is required",
},
{
name: "Whitespace base URL",
config: testutils.NewEmailTestConfig(" "),
sender: &testutils.MockEmailSender{},
expectError: true,
errorMsg: "APP_BASE_URL is required",
},
{
name: "Base URL with trailing slash",
config: testutils.NewEmailTestConfig("https://example.com/"),
sender: &testutils.MockEmailSender{},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service, err := NewEmailService(tt.config, tt.sender)
if tt.expectError {
if err == nil {
t.Errorf("Expected error, got nil")
return
}
if !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error to contain '%s', got '%s'", tt.errorMsg, err.Error())
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if service == nil {
t.Error("Expected service, got nil")
return
}
expectedBaseURL := strings.TrimRight(strings.TrimSpace(tt.config.App.BaseURL), "/")
if service.baseURL != expectedBaseURL {
t.Errorf("Expected baseURL '%s', got '%s'", expectedBaseURL, service.baseURL)
}
})
}
}
func TestEmailService_DynamicTitle(t *testing.T) {
const (
placeholderTitle = "My Custom Site"
customTitle = "Custom Community"
)
cfg := &config.Config{
App: config.AppConfig{
Title: customTitle,
BaseURL: "https://example.com",
},
}
sender := &testutils.MockEmailSender{}
service, err := NewEmailService(cfg, sender)
if err != nil {
t.Fatalf("Failed to create email service: %v", err)
}
body := service.GenerateVerificationEmailBody("testuser", "https://example.com/confirm?token=abc123")
if !strings.Contains(body, customTitle) {
t.Error("Expected email body to contain custom site title")
}
if strings.Contains(body, placeholderTitle) {
t.Errorf("Expected email body to not contain placeholder title %q", placeholderTitle)
}
if strings.Contains(body, "The Goyco Team") {
t.Error("Expected email body to not contain default team name when custom title is set")
}
cfgDefault := &config.Config{
App: config.AppConfig{
Title: "Goyco",
BaseURL: "https://example.com",
},
}
serviceDefault, err := NewEmailService(cfgDefault, sender)
if err != nil {
t.Fatalf("Failed to create email service: %v", err)
}
bodyDefault := serviceDefault.GenerateVerificationEmailBody("testuser", "https://example.com/confirm?token=abc123")
if !strings.Contains(bodyDefault, "Goyco") {
t.Error("Expected email body to contain default site title")
}
}

View File

@@ -0,0 +1,360 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"slices"
"time"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
)
const (
TokenTypeAccess = "access"
TokenTypeRefresh = "refresh"
)
var (
ErrInvalidTokenType = errors.New("invalid token type")
ErrTokenExpired = errors.New("token expired")
ErrInvalidIssuer = errors.New("invalid issuer")
ErrInvalidAudience = errors.New("invalid audience")
ErrInvalidKeyID = errors.New("invalid key ID")
ErrRefreshTokenExpired = errors.New("refresh token expired")
ErrRefreshTokenInvalid = errors.New("refresh token invalid")
)
type TokenClaims struct {
UserID uint `json:"sub"`
Username string `json:"username"`
SessionVersion uint `json:"session_version"`
TokenType string `json:"type"`
KeyID string `json:"kid,omitempty"`
jwt.RegisteredClaims
}
type JWTService struct {
config *config.JWTConfig
userRepo UserRepository
refreshRepo repositories.RefreshTokenRepositoryInterface
}
type verificationKey struct {
key []byte
}
type UserRepository interface {
GetByID(id uint) (*database.User, error)
GetByUsername(username string) (*database.User, error)
Update(user *database.User) error
}
func NewJWTService(cfg *config.JWTConfig, userRepo UserRepository, refreshRepo repositories.RefreshTokenRepositoryInterface) *JWTService {
return &JWTService{
config: cfg,
userRepo: userRepo,
refreshRepo: refreshRepo,
}
}
func (j *JWTService) GenerateAccessToken(user *database.User) (string, error) {
return j.generateToken(user, TokenTypeAccess, time.Duration(j.config.Expiration)*time.Hour)
}
func (j *JWTService) GenerateRefreshToken(user *database.User) (string, error) {
if user == nil {
return "", ErrInvalidCredentials
}
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", fmt.Errorf("generate refresh token: %w", err)
}
tokenString := hex.EncodeToString(tokenBytes)
tokenHash := j.hashToken(tokenString)
refreshToken := &database.RefreshToken{
UserID: user.ID,
TokenHash: tokenHash,
ExpiresAt: time.Now().Add(time.Duration(j.config.RefreshExpiration) * time.Hour),
}
if err := j.refreshRepo.Create(refreshToken); err != nil {
return "", fmt.Errorf("store refresh token: %w", err)
}
return tokenString, nil
}
func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
claims, err := j.parseToken(tokenString)
if err != nil {
return 0, err
}
if claims.TokenType != TokenTypeAccess {
return 0, ErrInvalidTokenType
}
user, err := j.userRepo.GetByID(claims.UserID)
if err != nil {
if IsRecordNotFound(err) {
return 0, ErrInvalidToken
}
return 0, fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
return 0, ErrAccountLocked
}
if user.SessionVersion != claims.SessionVersion {
return 0, ErrInvalidToken
}
return claims.UserID, nil
}
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) {
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup refresh token: %w", err)
}
if time.Now().After(refreshToken.ExpiresAt) {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenExpired
}
user, err := j.userRepo.GetByID(refreshToken.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrAccountLocked
}
accessToken, err := j.GenerateAccessToken(user)
if err != nil {
return "", fmt.Errorf("generate access token: %w", err)
}
return accessToken, nil
}
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
return fmt.Errorf("lookup refresh token: %w", err)
}
return j.refreshRepo.DeleteByID(refreshToken.ID)
}
func (j *JWTService) RevokeAllRefreshTokens(userID uint) error {
return j.refreshRepo.DeleteByUserID(userID)
}
func (j *JWTService) CleanupExpiredTokens() error {
return j.refreshRepo.DeleteExpired()
}
func (j *JWTService) generateToken(user *database.User, tokenType string, expiration time.Duration) (string, error) {
if user == nil {
return "", ErrInvalidCredentials
}
jtiBytes := make([]byte, 16)
if _, err := rand.Read(jtiBytes); err != nil {
return "", fmt.Errorf("generate token ID: %w", err)
}
now := time.Now()
claims := TokenClaims{
UserID: user.ID,
Username: user.Username,
SessionVersion: user.SessionVersion,
TokenType: tokenType,
RegisteredClaims: jwt.RegisteredClaims{
ID: hex.EncodeToString(jtiBytes),
Issuer: j.config.Issuer,
Audience: []string{j.config.Audience},
Subject: fmt.Sprint(user.ID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
},
}
if j.config.KeyRotation.Enabled {
claims.KeyID = j.config.KeyRotation.KeyID
}
var signingMethod jwt.SigningMethod
var key any
if j.config.KeyRotation.Enabled {
signingMethod = jwt.SigningMethodHS256
key = []byte(j.config.KeyRotation.CurrentKey)
} else {
signingMethod = jwt.SigningMethodHS256
key = []byte(j.config.Secret)
}
token := jwt.NewWithClaims(signingMethod, claims)
if j.config.KeyRotation.Enabled {
token.Header["kid"] = j.config.KeyRotation.KeyID
}
return token.SignedString(key)
}
func (j *JWTService) parseToken(tokenString string) (*TokenClaims, error) {
if TrimString(tokenString) == "" {
return nil, ErrInvalidToken
}
parser := jwt.NewParser()
unverified, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return nil, ErrInvalidToken
}
headerKid, _ := unverified.Header["kid"].(string)
if j.config.KeyRotation.Enabled {
if headerKid != "" && headerKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
return nil, ErrInvalidKeyID
}
} else if headerKid != "" {
return nil, ErrInvalidKeyID
}
keys := j.verificationKeys()
if len(keys) == 0 {
return nil, ErrInvalidToken
}
var lastErr error
for _, candidate := range keys {
claims := &TokenClaims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return candidate.key, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrTokenExpired
}
lastErr = ErrInvalidToken
continue
}
if claims.ExpiresAt == nil || time.Until(claims.ExpiresAt.Time) < 0 {
return nil, ErrTokenExpired
}
if !token.Valid {
lastErr = ErrInvalidToken
continue
}
if err := j.validateTokenMetadata(token, claims, headerKid); err != nil {
return nil, err
}
return claims, nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, ErrInvalidToken
}
func (j *JWTService) verificationKeys() []verificationKey {
if j.config == nil {
return nil
}
if j.config.KeyRotation.Enabled {
keys := []verificationKey{{key: []byte(j.config.KeyRotation.CurrentKey)}}
if j.config.KeyRotation.PreviousKey != "" {
keys = append(keys, verificationKey{key: []byte(j.config.KeyRotation.PreviousKey)})
}
return keys
}
return []verificationKey{{key: []byte(j.config.Secret)}}
}
func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims, headerKid string) error {
actualKid, _ := token.Header["kid"].(string)
if actualKid == "" {
actualKid = headerKid
}
if j.config.KeyRotation.Enabled {
if actualKid == "" {
if claims.KeyID != "" {
return ErrInvalidKeyID
}
} else {
if claims.KeyID == "" || claims.KeyID != actualKid {
return ErrInvalidKeyID
}
}
if actualKid != "" && actualKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
return ErrInvalidKeyID
}
} else {
if actualKid != "" || claims.KeyID != "" {
return ErrInvalidKeyID
}
}
if claims.Issuer != j.config.Issuer {
return ErrInvalidIssuer
}
if !slices.Contains(claims.Audience, j.config.Audience) {
return ErrInvalidAudience
}
return nil
}
func (j *JWTService) hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}

View File

@@ -0,0 +1,966 @@
package services
import (
"errors"
"fmt"
"slices"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
)
type jwtMockUserRepo struct {
users map[uint]*database.User
}
func (m *jwtMockUserRepo) GetByID(id uint) (*database.User, error) {
if user, exists := m.users[id]; exists {
return user, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *jwtMockUserRepo) GetByUsername(username string) (*database.User, error) {
for _, user := range m.users {
if user.Username == username {
return user, nil
}
}
return nil, gorm.ErrRecordNotFound
}
func (m *jwtMockUserRepo) Update(user *database.User) error {
if _, exists := m.users[user.ID]; !exists {
return gorm.ErrRecordNotFound
}
m.users[user.ID] = user
return nil
}
type jwtMockRefreshTokenRepo struct {
tokens map[string]*database.RefreshToken
nextID uint
}
func (m *jwtMockRefreshTokenRepo) Create(token *database.RefreshToken) error {
if m.tokens == nil {
m.tokens = make(map[string]*database.RefreshToken)
}
m.nextID++
token.ID = m.nextID
m.tokens[token.TokenHash] = token
return nil
}
func (m *jwtMockRefreshTokenRepo) GetByTokenHash(tokenHash string) (*database.RefreshToken, error) {
if token, exists := m.tokens[tokenHash]; exists {
return token, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *jwtMockRefreshTokenRepo) DeleteByUserID(userID uint) error {
for hash, token := range m.tokens {
if token.UserID == userID {
delete(m.tokens, hash)
}
}
return nil
}
func (m *jwtMockRefreshTokenRepo) DeleteExpired() error {
now := time.Now()
for hash, token := range m.tokens {
if token.ExpiresAt.Before(now) {
delete(m.tokens, hash)
}
}
return nil
}
func (m *jwtMockRefreshTokenRepo) DeleteByID(id uint) error {
for hash, token := range m.tokens {
if token.ID == id {
delete(m.tokens, hash)
return nil
}
}
return gorm.ErrRecordNotFound
}
func (m *jwtMockRefreshTokenRepo) GetByUserID(userID uint) ([]database.RefreshToken, error) {
var tokens []database.RefreshToken
for _, token := range m.tokens {
if token.UserID == userID {
tokens = append(tokens, *token)
}
}
return tokens, nil
}
func (m *jwtMockRefreshTokenRepo) CountByUserID(userID uint) (int64, error) {
var count int64
for _, token := range m.tokens {
if token.UserID == userID {
count++
}
}
return count, nil
}
func createTestJWTService() (*JWTService, *jwtMockUserRepo, *jwtMockRefreshTokenRepo) {
cfg := &config.JWTConfig{
Secret: "test-secret-key-that-is-long-enough-for-security",
Expiration: 1,
RefreshExpiration: 24,
Issuer: "test-issuer",
Audience: "test-audience",
KeyRotation: config.KeyRotationConfig{
Enabled: false,
CurrentKey: "",
PreviousKey: "",
KeyID: "",
},
}
userRepo := &jwtMockUserRepo{
users: make(map[uint]*database.User),
}
refreshRepo := &jwtMockRefreshTokenRepo{
tokens: make(map[string]*database.RefreshToken),
}
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
return jwtService, userRepo, refreshRepo
}
func createTestUser() *database.User {
return &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
EmailVerified: true,
Locked: false,
SessionVersion: 1,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
}
func TestJWTService_GenerateAccessToken(t *testing.T) {
jwtService, userRepo, _ := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Successful_Generation", func(t *testing.T) {
token, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Expected successful token generation, got error: %v", err)
}
if token == "" {
t.Error("Expected non-empty token")
}
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
return []byte(jwtService.config.Secret), nil
})
if err != nil {
t.Fatalf("Failed to parse generated token: %v", err)
}
if !parsedToken.Valid {
t.Error("Generated token should be valid")
}
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
if claims["sub"] != float64(user.ID) {
t.Errorf("Expected subject %d, got %v", user.ID, claims["sub"])
}
if claims["username"] != user.Username {
t.Errorf("Expected username %s, got %v", user.Username, claims["username"])
}
if claims["session_version"] != float64(user.SessionVersion) {
t.Errorf("Expected session_version %d, got %v", user.SessionVersion, claims["session_version"])
}
if claims["type"] != TokenTypeAccess {
t.Errorf("Expected type %s, got %v", TokenTypeAccess, claims["type"])
}
if claims["iss"] != jwtService.config.Issuer {
t.Errorf("Expected issuer %s, got %v", jwtService.config.Issuer, claims["iss"])
}
if aud, ok := claims["aud"].([]any); !ok || len(aud) != 1 || aud[0] != jwtService.config.Audience {
t.Errorf("Expected audience [%s], got %v", jwtService.config.Audience, claims["aud"])
}
}
})
t.Run("Nil_User", func(t *testing.T) {
_, err := jwtService.GenerateAccessToken(nil)
if err == nil {
t.Error("Expected error for nil user")
}
if !errors.Is(err, ErrInvalidCredentials) {
t.Errorf("Expected ErrInvalidCredentials, got %v", err)
}
})
}
func TestJWTService_GenerateRefreshToken(t *testing.T) {
jwtService, userRepo, refreshRepo := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Successful_Generation", func(t *testing.T) {
token, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Expected successful refresh token generation, got error: %v", err)
}
if token == "" {
t.Error("Expected non-empty refresh token")
}
tokenHash := jwtService.hashToken(token)
storedToken, err := refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
t.Fatalf("Expected refresh token to be stored in database: %v", err)
}
if storedToken.UserID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, storedToken.UserID)
}
expectedExpiry := time.Now().Add(time.Duration(jwtService.config.RefreshExpiration) * time.Hour)
if storedToken.ExpiresAt.Before(expectedExpiry.Add(-time.Minute)) || storedToken.ExpiresAt.After(expectedExpiry.Add(time.Minute)) {
t.Errorf("Expected expiry around %v, got %v", expectedExpiry, storedToken.ExpiresAt)
}
})
t.Run("Nil_User", func(t *testing.T) {
_, err := jwtService.GenerateRefreshToken(nil)
if err == nil {
t.Error("Expected error for nil user")
}
if !errors.Is(err, ErrInvalidCredentials) {
t.Errorf("Expected ErrInvalidCredentials, got %v", err)
}
})
}
func TestJWTService_VerifyAccessToken(t *testing.T) {
jwtService, userRepo, _ := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Valid_Token", func(t *testing.T) {
token, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
userID, err := jwtService.VerifyAccessToken(token)
if err != nil {
t.Fatalf("Expected successful token verification, got error: %v", err)
}
if userID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
}
})
t.Run("Invalid_Token", func(t *testing.T) {
_, err := jwtService.VerifyAccessToken("invalid-token")
if err == nil {
t.Error("Expected error for invalid token")
}
})
t.Run("Empty_Token", func(t *testing.T) {
_, err := jwtService.VerifyAccessToken("")
if err == nil {
t.Error("Expected error for empty token")
}
if !errors.Is(err, ErrInvalidToken) {
t.Errorf("Expected ErrInvalidToken, got %v", err)
}
})
t.Run("User_Not_Found", func(t *testing.T) {
token, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
delete(userRepo.users, user.ID)
_, err = jwtService.VerifyAccessToken(token)
if err == nil {
t.Error("Expected error for non-existent user")
}
if !errors.Is(err, ErrInvalidToken) {
t.Errorf("Expected ErrInvalidToken, got %v", err)
}
})
t.Run("Locked_User", func(t *testing.T) {
user.Locked = true
userRepo.users[user.ID] = user
token, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
_, err = jwtService.VerifyAccessToken(token)
if err == nil {
t.Error("Expected error for locked user")
}
if !errors.Is(err, ErrAccountLocked) {
t.Errorf("Expected ErrAccountLocked, got %v", err)
}
})
t.Run("Session_Version_Mismatch", func(t *testing.T) {
user.Locked = false
user.SessionVersion = 2
userRepo.users[user.ID] = user
oldUser := *user
oldUser.SessionVersion = 1
token, err := jwtService.GenerateAccessToken(&oldUser)
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
_, err = jwtService.VerifyAccessToken(token)
if err == nil {
t.Error("Expected error for session version mismatch")
}
if !errors.Is(err, ErrInvalidToken) {
t.Errorf("Expected ErrInvalidToken, got %v", err)
}
})
}
func TestJWTService_RefreshAccessToken(t *testing.T) {
jwtService, userRepo, refreshRepo := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Successful_Refresh", func(t *testing.T) {
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate refresh token: %v", err)
}
accessToken, err := jwtService.RefreshAccessToken(refreshToken)
if err != nil {
t.Fatalf("Expected successful token refresh, got error: %v", err)
}
if accessToken == "" {
t.Error("Expected non-empty access token")
}
userID, err := jwtService.VerifyAccessToken(accessToken)
if err != nil {
t.Fatalf("Expected valid access token, got error: %v", err)
}
if userID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
}
})
t.Run("Invalid_Refresh_Token", func(t *testing.T) {
_, err := jwtService.RefreshAccessToken("invalid-refresh-token")
if err == nil {
t.Error("Expected error for invalid refresh token")
}
if !errors.Is(err, ErrRefreshTokenInvalid) {
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
}
})
t.Run("Expired_Refresh_Token", func(t *testing.T) {
refreshToken := &database.RefreshToken{
UserID: user.ID,
TokenHash: "expired-token-hash",
ExpiresAt: time.Now().Add(-time.Hour),
}
refreshRepo.tokens["expired-token-hash"] = refreshToken
testToken := "test-expired-token"
tokenHash := jwtService.hashToken(testToken)
refreshToken.TokenHash = tokenHash
refreshRepo.tokens[tokenHash] = refreshToken
_, err := jwtService.RefreshAccessToken(testToken)
if err == nil {
t.Error("Expected error for expired refresh token")
}
if !errors.Is(err, ErrRefreshTokenExpired) {
t.Errorf("Expected ErrRefreshTokenExpired, got %v", err)
}
})
t.Run("User_Not_Found", func(t *testing.T) {
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate refresh token: %v", err)
}
delete(userRepo.users, user.ID)
_, err = jwtService.RefreshAccessToken(refreshToken)
if err == nil {
t.Error("Expected error for non-existent user")
}
if !errors.Is(err, ErrRefreshTokenInvalid) {
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
}
})
t.Run("Locked_User", func(t *testing.T) {
user.Locked = true
userRepo.users[user.ID] = user
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate refresh token: %v", err)
}
_, err = jwtService.RefreshAccessToken(refreshToken)
if err == nil {
t.Error("Expected error for locked user")
}
if !errors.Is(err, ErrAccountLocked) {
t.Errorf("Expected ErrAccountLocked, got %v", err)
}
})
}
func TestJWTService_RevokeRefreshToken(t *testing.T) {
jwtService, userRepo, refreshRepo := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Successful_Revocation", func(t *testing.T) {
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate refresh token: %v", err)
}
tokenHash := jwtService.hashToken(refreshToken)
_, err = refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
t.Fatalf("Expected refresh token to exist: %v", err)
}
err = jwtService.RevokeRefreshToken(refreshToken)
if err != nil {
t.Fatalf("Expected successful token revocation, got error: %v", err)
}
_, err = refreshRepo.GetByTokenHash(tokenHash)
if err == nil {
t.Error("Expected refresh token to be removed")
}
})
t.Run("Non_Existent_Token", func(t *testing.T) {
err := jwtService.RevokeRefreshToken("non-existent-token")
if err != nil {
t.Errorf("Expected no error for non-existent token, got %v", err)
}
})
}
func TestJWTService_RevokeAllRefreshTokens(t *testing.T) {
jwtService, userRepo, refreshRepo := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Successful_Revocation", func(t *testing.T) {
_, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate first refresh token: %v", err)
}
_, err = jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate second refresh token: %v", err)
}
count, err := refreshRepo.CountByUserID(user.ID)
if err != nil {
t.Fatalf("Failed to count tokens: %v", err)
}
if count != 2 {
t.Errorf("Expected 2 tokens, got %d", count)
}
err = jwtService.RevokeAllRefreshTokens(user.ID)
if err != nil {
t.Fatalf("Expected successful token revocation, got error: %v", err)
}
count, err = refreshRepo.CountByUserID(user.ID)
if err != nil {
t.Fatalf("Failed to count tokens: %v", err)
}
if count != 0 {
t.Errorf("Expected 0 tokens, got %d", count)
}
})
}
func TestJWTService_CleanupExpiredTokens(t *testing.T) {
jwtService, userRepo, refreshRepo := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Successful_Cleanup", func(t *testing.T) {
expiredToken := &database.RefreshToken{
UserID: user.ID,
TokenHash: "expired-token-hash",
ExpiresAt: time.Now().Add(-time.Hour),
}
refreshRepo.tokens["expired-token-hash"] = expiredToken
validToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate valid refresh token: %v", err)
}
if len(refreshRepo.tokens) != 2 {
t.Errorf("Expected 2 tokens, got %d", len(refreshRepo.tokens))
}
err = jwtService.CleanupExpiredTokens()
if err != nil {
t.Fatalf("Expected successful cleanup, got error: %v", err)
}
if len(refreshRepo.tokens) != 1 {
t.Errorf("Expected 1 token after cleanup, got %d", len(refreshRepo.tokens))
}
tokenHash := jwtService.hashToken(validToken)
_, exists := refreshRepo.tokens[tokenHash]
if !exists {
t.Error("Expected valid token to remain after cleanup")
}
})
}
func TestJWTService_KeyRotation(t *testing.T) {
cfg := &config.JWTConfig{
Secret: "old-secret-key-that-is-long-enough-for-security",
Expiration: 1,
RefreshExpiration: 24,
Issuer: "test-issuer",
Audience: "test-audience",
KeyRotation: config.KeyRotationConfig{
Enabled: true,
CurrentKey: "current-key-that-is-long-enough-for-security",
PreviousKey: "previous-key-that-is-long-enough-for-security",
KeyID: "current-key-id",
},
}
userRepo := &jwtMockUserRepo{
users: make(map[uint]*database.User),
}
refreshRepo := &jwtMockRefreshTokenRepo{
tokens: make(map[string]*database.RefreshToken),
}
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Generate_Token_With_Key_Rotation", func(t *testing.T) {
token, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Expected successful token generation with key rotation, got error: %v", err)
}
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
return []byte(cfg.KeyRotation.CurrentKey), nil
})
if err != nil {
t.Fatalf("Failed to parse token with key rotation: %v", err)
}
if !parsedToken.Valid {
t.Error("Generated token should be valid")
}
if kid, ok := parsedToken.Header["kid"].(string); !ok || kid != cfg.KeyRotation.KeyID {
t.Errorf("Expected key ID %s, got %v", cfg.KeyRotation.KeyID, parsedToken.Header["kid"])
}
})
t.Run("Verify_Token_With_Current_Key", func(t *testing.T) {
token, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
userID, err := jwtService.VerifyAccessToken(token)
if err != nil {
t.Fatalf("Expected successful token verification with current key, got error: %v", err)
}
if userID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
}
})
t.Run("Legacy_Token_Without_KID_Remains_Valid", func(t *testing.T) {
legacyCfg := &config.JWTConfig{
Secret: "legacy-secret-key-that-is-long-enough-for-security",
Expiration: 1,
RefreshExpiration: 24,
Issuer: "legacy-issuer",
Audience: "legacy-audience",
KeyRotation: config.KeyRotationConfig{Enabled: false},
}
legacyUserRepo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
legacyRefreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
legacyService := NewJWTService(legacyCfg, legacyUserRepo, legacyRefreshRepo)
legacyToken, err := legacyService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Failed to generate legacy token: %v", err)
}
legacyCfg.KeyRotation.Enabled = true
legacyCfg.KeyRotation.CurrentKey = "rotated-current-key-that-is-long-enough-for-security"
legacyCfg.KeyRotation.PreviousKey = legacyCfg.Secret
legacyCfg.KeyRotation.KeyID = "rotated-key-id"
parsedUserID, err := legacyService.VerifyAccessToken(legacyToken)
if err != nil {
t.Fatalf("Legacy token should remain valid after enabling rotation: %v", err)
}
if parsedUserID != user.ID {
t.Fatalf("Expected user ID %d, got %d", user.ID, parsedUserID)
}
})
t.Run("Legacy_Token_With_Previous_KID_Remains_Valid", func(t *testing.T) {
rotCfg := &config.JWTConfig{
Secret: "unused-secret-key-that-is-long-enough-for-security",
Expiration: 1,
RefreshExpiration: 24,
Issuer: "rotation-issuer",
Audience: "rotation-audience",
KeyRotation: config.KeyRotationConfig{
Enabled: true,
CurrentKey: "rotation-key-v1-that-is-long-enough-for-security",
KeyID: "key-id-v1",
},
}
rotUserRepo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
rotRefreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
rotService := NewJWTService(rotCfg, rotUserRepo, rotRefreshRepo)
tokenV1, err := rotService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Failed to generate v1 token: %v", err)
}
rotCfg.KeyRotation.PreviousKey = rotCfg.KeyRotation.CurrentKey
rotCfg.KeyRotation.CurrentKey = "rotation-key-v2-that-is-long-enough-for-security"
rotCfg.KeyRotation.KeyID = "key-id-v2"
parsedUserID, err := rotService.VerifyAccessToken(tokenV1)
if err != nil {
t.Fatalf("Token signed with previous key should remain valid: %v", err)
}
if parsedUserID != user.ID {
t.Fatalf("Expected user ID %d, got %d", user.ID, parsedUserID)
}
})
t.Run("Unknown_KID_Is_Rejected", func(t *testing.T) {
cfg := &config.JWTConfig{
Secret: "unused-secret-key-that-is-long-enough-for-security",
Expiration: 1,
RefreshExpiration: 24,
Issuer: "issuer",
Audience: "audience",
KeyRotation: config.KeyRotationConfig{
Enabled: true,
CurrentKey: "current-key-for-unknown-kid-test-that-is-long-enough",
KeyID: "expected-key-id",
},
}
repo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
refreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
service := NewJWTService(cfg, repo, refreshRepo)
claims := TokenClaims{
UserID: user.ID,
Username: user.Username,
SessionVersion: user.SessionVersion,
TokenType: TokenTypeAccess,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: cfg.Issuer,
Audience: []string{cfg.Audience},
Subject: fmt.Sprint(user.ID),
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = "unexpected-key-id"
tokenString, err := token.SignedString([]byte(cfg.KeyRotation.CurrentKey))
if err != nil {
t.Fatalf("Failed to sign token: %v", err)
}
_, err = service.VerifyAccessToken(tokenString)
if !errors.Is(err, ErrInvalidKeyID) {
t.Fatalf("Expected ErrInvalidKeyID, got %v", err)
}
})
}
func TestJWTService_ErrorHandling(t *testing.T) {
jwtService, _, _ := createTestJWTService()
t.Run("Invalid_Issuer", func(t *testing.T) {
claims := TokenClaims{
UserID: 1,
Username: "testuser",
SessionVersion: 1,
TokenType: TokenTypeAccess,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "wrong-issuer",
Audience: []string{jwtService.config.Audience},
Subject: "1",
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, err = jwtService.VerifyAccessToken(tokenString)
if err == nil {
t.Error("Expected error for invalid issuer")
}
if !errors.Is(err, ErrInvalidIssuer) {
t.Errorf("Expected ErrInvalidIssuer, got %v", err)
}
})
t.Run("Invalid_Audience", func(t *testing.T) {
claims := TokenClaims{
UserID: 1,
Username: "testuser",
SessionVersion: 1,
TokenType: TokenTypeAccess,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: jwtService.config.Issuer,
Audience: []string{"wrong-audience"},
Subject: "1",
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, err = jwtService.VerifyAccessToken(tokenString)
if err == nil {
t.Error("Expected error for invalid audience")
}
if !errors.Is(err, ErrInvalidAudience) {
t.Errorf("Expected ErrInvalidAudience, got %v", err)
}
})
t.Run("Expired_Token", func(t *testing.T) {
claims := TokenClaims{
UserID: 1,
Username: "testuser",
SessionVersion: 1,
TokenType: TokenTypeAccess,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: jwtService.config.Issuer,
Audience: []string{jwtService.config.Audience},
Subject: "1",
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, err = jwtService.VerifyAccessToken(tokenString)
if err == nil {
t.Error("Expected error for expired token")
}
if !errors.Is(err, ErrTokenExpired) {
t.Errorf("Expected ErrTokenExpired, got %v", err)
}
})
t.Run("Subject_Mismatch", func(t *testing.T) {
claims := TokenClaims{
UserID: 1,
Username: "testuser",
SessionVersion: 1,
TokenType: TokenTypeAccess,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: jwtService.config.Issuer,
Audience: []string{jwtService.config.Audience},
Subject: "999",
IssuedAt: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, err = jwtService.VerifyAccessToken(tokenString)
if !errors.Is(err, ErrInvalidToken) {
t.Fatalf("Expected ErrInvalidToken for subject mismatch, got %v", err)
}
})
}
func TestJWTService_HelperFunctions(t *testing.T) {
jwtService, _, _ := createTestJWTService()
t.Run("HashToken", func(t *testing.T) {
token := "test-token"
hash1 := jwtService.hashToken(token)
hash2 := jwtService.hashToken(token)
if hash1 != hash2 {
t.Error("Hash should be deterministic")
}
if hash1 == token {
t.Error("Hash should be different from original token")
}
hash3 := jwtService.hashToken("different-token")
if hash1 == hash3 {
t.Error("Different tokens should produce different hashes")
}
})
t.Run("Contains", func(t *testing.T) {
slice := []string{"item1", "item2", "item3"}
if !slices.Contains(slice, "item1") {
t.Error("Should contain item1")
}
if !slices.Contains(slice, "item2") {
t.Error("Should contain item2")
}
if slices.Contains(slice, "item4") {
t.Error("Should not contain item4")
}
if slices.Contains(slice, "") {
t.Error("Should not contain empty string")
}
})
}
func TestJWTService_Integration(t *testing.T) {
jwtService, userRepo, _ := createTestJWTService()
user := createTestUser()
userRepo.users[user.ID] = user
t.Run("Complete_Flow", func(t *testing.T) {
accessToken, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("Failed to generate access token: %v", err)
}
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("Failed to generate refresh token: %v", err)
}
userID, err := jwtService.VerifyAccessToken(accessToken)
if err != nil {
t.Fatalf("Failed to verify access token: %v", err)
}
if userID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
}
newAccessToken, err := jwtService.RefreshAccessToken(refreshToken)
if err != nil {
t.Fatalf("Failed to refresh access token: %v", err)
}
userID, err = jwtService.VerifyAccessToken(newAccessToken)
if err != nil {
t.Fatalf("Failed to verify new access token: %v", err)
}
if userID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
}
err = jwtService.RevokeRefreshToken(refreshToken)
if err != nil {
t.Fatalf("Failed to revoke refresh token: %v", err)
}
_, err = jwtService.RefreshAccessToken(refreshToken)
if err == nil {
t.Error("Expected error when using revoked refresh token")
}
if !errors.Is(err, ErrRefreshTokenInvalid) {
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
}
})
}

View File

@@ -0,0 +1,135 @@
package services
import (
"fmt"
"time"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/validation"
)
type PasswordResetService struct {
userRepo repositories.UserRepository
emailService *EmailService
}
func NewPasswordResetService(userRepo repositories.UserRepository, emailService *EmailService) *PasswordResetService {
return &PasswordResetService{
userRepo: userRepo,
emailService: emailService,
}
}
func (s *PasswordResetService) RequestPasswordReset(usernameOrEmail string) error {
trimmed := TrimString(usernameOrEmail)
if trimmed == "" {
return fmt.Errorf("username or email is required")
}
var user *database.User
var err error
normalized, emailErr := normalizeEmail(trimmed)
if emailErr == nil {
user, err = s.userRepo.GetByEmail(normalized)
if err != nil && !IsRecordNotFound(err) {
return fmt.Errorf("lookup user by email: %w", err)
}
}
if user == nil {
user, err = s.userRepo.GetByUsername(trimmed)
if err != nil {
if IsRecordNotFound(err) {
return nil
}
return fmt.Errorf("lookup user by username: %w", err)
}
}
token, hashed, err := generateVerificationToken()
if err != nil {
return err
}
now := time.Now()
expiresAt := now.Add(time.Duration(defaultTokenExpirationHours) * time.Hour)
user.PasswordResetToken = hashed
user.PasswordResetSentAt = &now
user.PasswordResetExpiresAt = &expiresAt
if err := s.userRepo.Update(user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if err := s.emailService.SendPasswordResetEmail(user, token); err != nil {
user.PasswordResetToken = ""
user.PasswordResetSentAt = nil
user.PasswordResetExpiresAt = nil
_ = s.userRepo.Update(user)
return fmt.Errorf("send password reset email: %w", err)
}
return nil
}
func (s *PasswordResetService) GetUserByResetToken(token string) (*database.User, error) {
trimmed := TrimString(token)
if trimmed == "" {
return nil, fmt.Errorf("reset token is required")
}
hashed := HashVerificationToken(trimmed)
user, err := s.userRepo.GetByPasswordResetToken(hashed)
if err != nil {
if IsRecordNotFound(err) {
return nil, fmt.Errorf("invalid or expired reset token")
}
return nil, fmt.Errorf("lookup reset token: %w", err)
}
if user.PasswordResetExpiresAt == nil || time.Now().After(*user.PasswordResetExpiresAt) {
return nil, fmt.Errorf("invalid or expired reset token")
}
return user, nil
}
func (s *PasswordResetService) ResetPassword(token, newPassword string) error {
if err := validation.ValidatePassword(newPassword); err != nil {
return err
}
user, err := s.GetUserByResetToken(token)
if err != nil {
hashed := HashVerificationToken(TrimString(token))
expiredUser, lookupErr := s.userRepo.GetByPasswordResetToken(hashed)
if lookupErr == nil && expiredUser != nil {
if expiredUser.PasswordResetExpiresAt == nil || time.Now().After(*expiredUser.PasswordResetExpiresAt) {
expiredUser.PasswordResetToken = ""
expiredUser.PasswordResetSentAt = nil
expiredUser.PasswordResetExpiresAt = nil
_ = s.userRepo.Update(expiredUser)
}
}
return err
}
hashedPassword, err := HashPassword(newPassword, DefaultBcryptCost)
if err != nil {
return err
}
user.Password = string(hashedPassword)
user.PasswordResetToken = ""
user.PasswordResetSentAt = nil
user.PasswordResetExpiresAt = nil
if err := s.userRepo.Update(user); err != nil {
return fmt.Errorf("update password: %w", err)
}
return nil
}

View File

@@ -0,0 +1,417 @@
package services
import (
"errors"
"testing"
"time"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database"
"goyco/internal/testutils"
)
func TestNewPasswordResetService(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
service := NewPasswordResetService(userRepo, emailService)
if service == nil {
t.Fatal("expected service to be created")
}
if service.userRepo != userRepo {
t.Error("expected userRepo to be set")
}
if service.emailService != emailService {
t.Error("expected emailService to be set")
}
}
func TestPasswordResetService_RequestPasswordReset(t *testing.T) {
tests := []struct {
name string
usernameOrEmail string
setupMocks func() (*testutils.MockUserRepository, EmailSender)
expectedError bool
shouldSendEmail bool
}{
{
name: "successful request by username",
usernameOrEmail: "testuser",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}
userRepo.Create(user)
emailSender := &testutils.MockEmailSender{}
return userRepo, emailSender
},
expectedError: false,
shouldSendEmail: true,
},
{
name: "successful request by email",
usernameOrEmail: "test@example.com",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}
userRepo.Create(user)
emailSender := &testutils.MockEmailSender{}
return userRepo, emailSender
},
expectedError: false,
shouldSendEmail: true,
},
{
name: "empty input",
usernameOrEmail: "",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
},
expectedError: true,
shouldSendEmail: false,
},
{
name: "whitespace only input",
usernameOrEmail: " ",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
},
expectedError: true,
shouldSendEmail: false,
},
{
name: "user not found",
usernameOrEmail: "nonexistent",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
},
expectedError: false,
shouldSendEmail: false,
},
{
name: "email service error",
usernameOrEmail: "testuser",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}
userRepo.Create(user)
var errorSender errorEmailSender
errorSender.err = errors.New("email service error")
emailSender := &errorSender
return userRepo, emailSender
},
expectedError: true,
shouldSendEmail: false,
},
{
name: "prefers email over username",
usernameOrEmail: "test@example.com",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}
userRepo.Create(user)
emailSender := &testutils.MockEmailSender{}
return userRepo, emailSender
},
expectedError: false,
shouldSendEmail: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, emailSender := tt.setupMocks()
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
service := NewPasswordResetService(userRepo, emailService)
err := service.RequestPasswordReset(tt.usernameOrEmail)
if tt.expectedError {
if err == nil {
t.Error("expected error but got none")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if tt.shouldSendEmail {
user, _ := userRepo.GetByUsername("testuser")
if user == nil {
user, _ = userRepo.GetByEmail("test@example.com")
}
if user != nil && user.PasswordResetToken == "" {
t.Error("expected password reset token to be set")
}
}
}
})
}
}
func TestPasswordResetService_ResetPassword(t *testing.T) {
tests := []struct {
name string
token string
newPassword string
setupMocks func() (*testutils.MockUserRepository, EmailSender)
expectedError bool
verifyPassword bool
}{
{
name: "successful password reset",
token: "valid-token",
newPassword: "NewSecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
expiresAt := time.Now().Add(time.Hour)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
PasswordResetToken: HashVerificationToken("valid-token"),
PasswordResetExpiresAt: &expiresAt,
}
userRepo.Create(user)
return userRepo, &testutils.MockEmailSender{}
},
expectedError: false,
verifyPassword: true,
},
{
name: "empty token",
token: "",
newPassword: "NewSecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
},
expectedError: true,
verifyPassword: false,
},
{
name: "whitespace only token",
token: " ",
newPassword: "NewSecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
},
expectedError: true,
verifyPassword: false,
},
{
name: "invalid token",
token: "invalid-token",
newPassword: "NewSecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
},
expectedError: true,
verifyPassword: false,
},
{
name: "expired token",
token: "expired-token",
newPassword: "NewSecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
expiresAt := time.Now().Add(-time.Hour)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
PasswordResetToken: HashVerificationToken("expired-token"),
PasswordResetExpiresAt: &expiresAt,
}
userRepo.Create(user)
return userRepo, &testutils.MockEmailSender{}
},
expectedError: true,
verifyPassword: false,
},
{
name: "nil expiration date",
token: "valid-token",
newPassword: "NewSecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
PasswordResetToken: HashVerificationToken("valid-token"),
}
userRepo.Create(user)
return userRepo, &testutils.MockEmailSender{}
},
expectedError: true,
verifyPassword: false,
},
{
name: "invalid password",
token: "valid-token",
newPassword: "short",
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
userRepo := testutils.NewMockUserRepository()
expiresAt := time.Now().Add(time.Hour)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
PasswordResetToken: HashVerificationToken("valid-token"),
PasswordResetExpiresAt: &expiresAt,
}
userRepo.Create(user)
return userRepo, &testutils.MockEmailSender{}
},
expectedError: true,
verifyPassword: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, emailSender := tt.setupMocks()
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
service := NewPasswordResetService(userRepo, emailService)
err := service.ResetPassword(tt.token, tt.newPassword)
if tt.expectedError {
if err == nil {
t.Error("expected error but got none")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if tt.verifyPassword {
user, _ := userRepo.GetByUsername("testuser")
if user == nil {
t.Fatal("expected user to exist")
}
if user.PasswordResetToken != "" {
t.Error("expected password reset token to be cleared")
}
if user.PasswordResetExpiresAt != nil {
t.Error("expected password reset expiration to be cleared")
}
if user.Password == "" {
t.Error("expected password to be set")
}
err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(tt.newPassword))
if err != nil {
t.Errorf("password hash verification failed: %v", err)
}
}
}
})
}
}
func TestPasswordResetService_ResetPassword_TokenClearedAfterExpiration(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
expiresAt := time.Now().Add(-time.Hour)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
PasswordResetToken: HashVerificationToken("expired-token"),
PasswordResetExpiresAt: &expiresAt,
}
userRepo.Create(user)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
service := NewPasswordResetService(userRepo, emailService)
err := service.ResetPassword("expired-token", "NewSecurePass123!")
if err == nil {
t.Error("expected error for expired token")
}
updatedUser, _ := userRepo.GetByID(1)
if updatedUser == nil {
t.Fatal("expected user to exist")
}
if updatedUser.PasswordResetToken != "" {
t.Error("expected password reset token to be cleared after expiration")
}
if updatedUser.PasswordResetExpiresAt != nil {
t.Error("expected password reset expiration to be cleared after expiration")
}
}
func TestPasswordResetService_RequestPasswordReset_EmailFailureRollback(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}
userRepo.Create(user)
emailSender := &errorEmailSender{err: errors.New("email service error")}
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
service := NewPasswordResetService(userRepo, emailService)
err := service.RequestPasswordReset("testuser")
if err == nil {
t.Error("expected error when email fails")
}
updatedUser, _ := userRepo.GetByID(1)
if updatedUser == nil {
t.Fatal("expected user to exist")
}
if updatedUser.PasswordResetToken != "" {
t.Error("expected password reset token to be rolled back on email failure")
}
if updatedUser.PasswordResetSentAt != nil {
t.Error("expected password reset sent at to be rolled back on email failure")
}
if updatedUser.PasswordResetExpiresAt != nil {
t.Error("expected password reset expiration to be rolled back on email failure")
}
}

View File

@@ -0,0 +1,123 @@
package services
import (
"goyco/internal/database"
"goyco/internal/repositories"
)
type PostQueries struct {
postRepo repositories.PostRepository
voteService *VoteService
}
func NewPostQueries(postRepo repositories.PostRepository, voteService *VoteService) *PostQueries {
return &PostQueries{
postRepo: postRepo,
voteService: voteService,
}
}
type QueryOptions struct {
Limit int
Offset int
Sort string
}
type VoteContext struct {
UserID uint
IPAddress string
UserAgent string
}
func (pq *PostQueries) enrichPostsWithVotes(posts []database.Post, ctx VoteContext) []database.Post {
if pq.voteService == nil {
return posts
}
enriched := make([]database.Post, len(posts))
for i := range posts {
enriched[i] = posts[i]
vote, err := pq.voteService.GetUserVote(ctx.UserID, posts[i].ID, ctx.IPAddress, ctx.UserAgent)
if err == nil && vote != nil {
enriched[i].CurrentVote = vote.Type
}
}
return enriched
}
func (pq *PostQueries) enrichPostWithVote(post *database.Post, ctx VoteContext) *database.Post {
if pq.voteService == nil || post == nil {
return post
}
vote, err := pq.voteService.GetUserVote(ctx.UserID, post.ID, ctx.IPAddress, ctx.UserAgent)
if err == nil && vote != nil {
post.CurrentVote = vote.Type
}
return post
}
func (pq *PostQueries) GetAll(opts QueryOptions, ctx VoteContext) ([]database.Post, error) {
posts, err := pq.postRepo.GetAll(opts.Limit, opts.Offset)
if err != nil {
return nil, err
}
return pq.enrichPostsWithVotes(posts, ctx), nil
}
func (pq *PostQueries) GetTop(limit int, ctx VoteContext) ([]database.Post, error) {
posts, err := pq.postRepo.GetTopPosts(limit)
if err != nil {
return nil, err
}
return pq.enrichPostsWithVotes(posts, ctx), nil
}
func (pq *PostQueries) GetNewest(limit int, ctx VoteContext) ([]database.Post, error) {
posts, err := pq.postRepo.GetNewestPosts(limit)
if err != nil {
return nil, err
}
return pq.enrichPostsWithVotes(posts, ctx), nil
}
func (pq *PostQueries) GetBySort(sort string, limit int, ctx VoteContext) ([]database.Post, error) {
switch sort {
case "new", "newest", "latest":
return pq.GetNewest(limit, ctx)
default:
return pq.GetTop(limit, ctx)
}
}
func (pq *PostQueries) GetSearch(query string, opts QueryOptions, ctx VoteContext) ([]database.Post, error) {
posts, err := pq.postRepo.Search(query, opts.Limit, opts.Offset)
if err != nil {
return nil, err
}
return pq.enrichPostsWithVotes(posts, ctx), nil
}
func (pq *PostQueries) GetByID(postID uint, ctx VoteContext) (*database.Post, error) {
post, err := pq.postRepo.GetByID(postID)
if err != nil {
return nil, err
}
return pq.enrichPostWithVote(post, ctx), nil
}
func (pq *PostQueries) GetByUserID(userID uint, opts QueryOptions, ctx VoteContext) ([]database.Post, error) {
posts, err := pq.postRepo.GetByUserID(userID, opts.Limit, opts.Offset)
if err != nil {
return nil, err
}
return pq.enrichPostsWithVotes(posts, ctx), nil
}

View File

@@ -0,0 +1,609 @@
package services
import (
"errors"
"testing"
"time"
"goyco/internal/database"
"goyco/internal/testutils"
"gorm.io/gorm"
)
func TestNewPostQueries(t *testing.T) {
repo := testutils.NewMockPostRepository()
voteService := NewVoteService(testutils.NewMockVoteRepository(), repo, nil)
postQueries := NewPostQueries(repo, voteService)
if postQueries == nil {
t.Fatal("expected PostQueries to be created")
}
if postQueries.postRepo != repo {
t.Error("expected postRepo to be set")
}
if postQueries.voteService != voteService {
t.Error("expected voteService to be set")
}
}
func TestPostQueries_GetAll(t *testing.T) {
tests := []struct {
name string
setupRepo func() *testutils.MockPostRepository
opts QueryOptions
ctx VoteContext
expectedCount int
expectedError bool
}{
{
name: "success with pagination",
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.Create(&database.Post{ID: 1, Title: "Post 1", Score: 10})
repo.Create(&database.Post{ID: 2, Title: "Post 2", Score: 5})
repo.Create(&database.Post{ID: 3, Title: "Post 3", Score: 15})
return repo
},
opts: QueryOptions{
Limit: 2,
Offset: 0,
},
ctx: VoteContext{UserID: 1},
expectedCount: 2,
expectedError: false,
},
{
name: "success with offset",
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.Create(&database.Post{ID: 1, Title: "Post 1", Score: 10})
repo.Create(&database.Post{ID: 2, Title: "Post 2", Score: 5})
repo.Create(&database.Post{ID: 3, Title: "Post 3", Score: 15})
return repo
},
opts: QueryOptions{
Limit: 2,
Offset: 1,
},
ctx: VoteContext{UserID: 1},
expectedCount: 2,
expectedError: false,
},
{
name: "repository error",
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.GetErr = errors.New("database error")
return repo
},
opts: QueryOptions{
Limit: 10,
Offset: 0,
},
ctx: VoteContext{UserID: 1},
expectedCount: 0,
expectedError: true,
},
{
name: "empty result",
setupRepo: func() *testutils.MockPostRepository {
return testutils.NewMockPostRepository()
},
opts: QueryOptions{
Limit: 10,
Offset: 0,
},
ctx: VoteContext{UserID: 1},
expectedCount: 0,
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := tt.setupRepo()
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
posts, err := postQueries.GetAll(tt.opts, tt.ctx)
if tt.expectedError {
if err == nil {
t.Error("expected error but got none")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(posts) != tt.expectedCount {
t.Errorf("expected %d posts, got %d", tt.expectedCount, len(posts))
}
}
})
}
}
func TestPostQueries_GetAll_WithVoteEnrichment(t *testing.T) {
repo := testutils.NewMockPostRepository()
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 5}
repo.Create(post1)
repo.Create(post2)
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
userID := uint(1)
voteRepo.Create(&database.Vote{
UserID: &userID,
PostID: 1,
Type: database.VoteUp,
})
postQueries := NewPostQueries(repo, voteService)
ctx := VoteContext{
UserID: 1,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
posts, err := postQueries.GetAll(QueryOptions{Limit: 10, Offset: 0}, ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(posts) != 2 {
t.Fatalf("expected 2 posts, got %d", len(posts))
}
if posts[0].CurrentVote != database.VoteUp && posts[0].ID == 1 {
if posts[1].ID == 1 && posts[1].CurrentVote != database.VoteUp {
t.Error("expected post 1 to have CurrentVote set to VoteUp")
}
}
for _, post := range posts {
if post.ID == 2 && post.CurrentVote != "" {
t.Errorf("expected post 2 to have no vote, got %s", post.CurrentVote)
}
}
}
func TestPostQueries_GetTop(t *testing.T) {
repo := testutils.NewMockPostRepository()
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 15}
post3 := &database.Post{ID: 3, Title: "Post 3", Score: 5}
repo.Create(post1)
repo.Create(post2)
repo.Create(post3)
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
posts, err := postQueries.GetTop(2, VoteContext{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(posts) != 2 {
t.Errorf("expected 2 posts, got %d", len(posts))
}
if len(posts) == 0 {
t.Error("expected at least one post")
}
}
func TestPostQueries_GetNewest(t *testing.T) {
repo := testutils.NewMockPostRepository()
now := time.Now()
post1 := &database.Post{ID: 1, Title: "Post 1", CreatedAt: now.Add(-2 * time.Hour)}
post2 := &database.Post{ID: 2, Title: "Post 2", CreatedAt: now.Add(-1 * time.Hour)}
post3 := &database.Post{ID: 3, Title: "Post 3", CreatedAt: now}
repo.Create(post1)
repo.Create(post2)
repo.Create(post3)
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
posts, err := postQueries.GetNewest(2, VoteContext{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(posts) != 2 {
t.Errorf("expected 2 posts, got %d", len(posts))
}
if len(posts) == 0 {
t.Error("expected at least one post")
}
}
func TestPostQueries_GetBySort(t *testing.T) {
repo := testutils.NewMockPostRepository()
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 15}
repo.Create(post1)
repo.Create(post2)
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
tests := []struct {
name string
sort string
expectTop bool
}{
{"new sort", "new", false},
{"newest sort", "newest", false},
{"latest sort", "latest", false},
{"default sort", "", true},
{"invalid sort", "invalid", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
posts, err := postQueries.GetBySort(tt.sort, 10, VoteContext{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(posts) == 0 {
t.Error("expected at least one post")
}
})
}
}
func TestPostQueries_GetSearch(t *testing.T) {
tests := []struct {
name string
query string
setupRepo func() *testutils.MockPostRepository
opts QueryOptions
ctx VoteContext
expectedCount int
expectedError bool
}{
{
name: "successful search",
query: "test",
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.Create(&database.Post{ID: 1, Title: "Test Post", Score: 10})
repo.Create(&database.Post{ID: 2, Title: "Another Post", Score: 5})
return repo
},
opts: QueryOptions{
Limit: 10,
Offset: 0,
},
ctx: VoteContext{},
expectedCount: 1,
expectedError: false,
},
{
name: "search with pagination",
query: "post",
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.Create(&database.Post{ID: 1, Title: "Post 1", Score: 10})
repo.Create(&database.Post{ID: 2, Title: "Post 2", Score: 5})
repo.Create(&database.Post{ID: 3, Title: "Post 3", Score: 15})
return repo
},
opts: QueryOptions{
Limit: 2,
Offset: 0,
},
ctx: VoteContext{},
expectedCount: 2,
expectedError: false,
},
{
name: "search error",
query: "test",
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.SearchErr = errors.New("search error")
return repo
},
opts: QueryOptions{
Limit: 10,
Offset: 0,
},
ctx: VoteContext{},
expectedCount: 0,
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := tt.setupRepo()
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
posts, err := postQueries.GetSearch(tt.query, tt.opts, tt.ctx)
if tt.expectedError {
if err == nil {
t.Error("expected error but got none")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(posts) < tt.expectedCount {
t.Errorf("expected at least %d posts, got %d", tt.expectedCount, len(posts))
}
}
})
}
}
func TestPostQueries_GetByID(t *testing.T) {
tests := []struct {
name string
postID uint
setupRepo func() *testutils.MockPostRepository
ctx VoteContext
expectedError bool
expectedID uint
}{
{
name: "successful retrieval",
postID: 1,
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.Create(&database.Post{ID: 1, Title: "Test Post", Score: 10})
return repo
},
ctx: VoteContext{UserID: 1},
expectedError: false,
expectedID: 1,
},
{
name: "post not found",
postID: 999,
setupRepo: func() *testutils.MockPostRepository {
return testutils.NewMockPostRepository()
},
ctx: VoteContext{UserID: 1},
expectedError: true,
expectedID: 0,
},
{
name: "repository error",
postID: 1,
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.GetErr = errors.New("database error")
return repo
},
ctx: VoteContext{UserID: 1},
expectedError: true,
expectedID: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := tt.setupRepo()
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
post, err := postQueries.GetByID(tt.postID, tt.ctx)
if tt.expectedError {
if err == nil {
t.Error("expected error but got none")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if post == nil {
t.Fatal("expected post to be returned")
}
if post.ID != tt.expectedID {
t.Errorf("expected post ID %d, got %d", tt.expectedID, post.ID)
}
}
})
}
}
func TestPostQueries_GetByID_WithVoteEnrichment(t *testing.T) {
repo := testutils.NewMockPostRepository()
post := &database.Post{ID: 1, Title: "Test Post", Score: 10}
repo.Create(post)
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
userID := uint(1)
voteRepo.Create(&database.Vote{
UserID: &userID,
PostID: 1,
Type: database.VoteDown,
})
postQueries := NewPostQueries(repo, voteService)
ctx := VoteContext{
UserID: 1,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
retrievedPost, err := postQueries.GetByID(1, ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if retrievedPost.CurrentVote != database.VoteDown {
t.Errorf("expected CurrentVote to be VoteDown, got %s", retrievedPost.CurrentVote)
}
}
func TestPostQueries_GetByUserID(t *testing.T) {
tests := []struct {
name string
userID uint
setupRepo func() *testutils.MockPostRepository
opts QueryOptions
ctx VoteContext
expectedCount int
expectedError bool
}{
{
name: "successful retrieval",
userID: 1,
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
authorID1 := uint(1)
authorID2 := uint(2)
repo.Create(&database.Post{ID: 1, Title: "User 1 Post", AuthorID: &authorID1, Score: 10})
repo.Create(&database.Post{ID: 2, Title: "User 2 Post", AuthorID: &authorID2, Score: 5})
repo.Create(&database.Post{ID: 3, Title: "User 1 Post 2", AuthorID: &authorID1, Score: 15})
return repo
},
opts: QueryOptions{
Limit: 10,
Offset: 0,
},
ctx: VoteContext{UserID: 1},
expectedCount: 2,
expectedError: false,
},
{
name: "with pagination",
userID: 1,
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
authorID := uint(1)
repo.Create(&database.Post{ID: 1, Title: "Post 1", AuthorID: &authorID, Score: 10})
repo.Create(&database.Post{ID: 2, Title: "Post 2", AuthorID: &authorID, Score: 5})
repo.Create(&database.Post{ID: 3, Title: "Post 3", AuthorID: &authorID, Score: 15})
return repo
},
opts: QueryOptions{
Limit: 2,
Offset: 0,
},
ctx: VoteContext{UserID: 1},
expectedCount: 2,
expectedError: false,
},
{
name: "repository error",
userID: 1,
setupRepo: func() *testutils.MockPostRepository {
repo := testutils.NewMockPostRepository()
repo.GetErr = errors.New("database error")
return repo
},
opts: QueryOptions{
Limit: 10,
Offset: 0,
},
ctx: VoteContext{UserID: 1},
expectedCount: 0,
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := tt.setupRepo()
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
posts, err := postQueries.GetByUserID(tt.userID, tt.opts, tt.ctx)
if tt.expectedError {
if err == nil {
t.Error("expected error but got none")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(posts) < tt.expectedCount {
t.Errorf("expected at least %d posts, got %d", tt.expectedCount, len(posts))
}
}
})
}
}
func TestPostQueries_WithoutVoteService(t *testing.T) {
repo := testutils.NewMockPostRepository()
repo.Create(&database.Post{ID: 1, Title: "Test Post", Score: 10})
postQueries := NewPostQueries(repo, nil)
ctx := VoteContext{
UserID: 1,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
post, err := postQueries.GetByID(1, ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if post == nil {
t.Fatal("expected post to be returned")
}
if post.CurrentVote != "" {
t.Errorf("expected CurrentVote to be empty when voteService is nil, got %s", post.CurrentVote)
}
}
func TestPostQueries_WithIPBasedVote(t *testing.T) {
repo := testutils.NewMockPostRepository()
post := &database.Post{ID: 1, Title: "Test Post", Score: 10}
repo.Create(post)
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
voteHash := voteService.GenerateVoteHash("127.0.0.1", "test-agent", 1)
voteRepo.Create(&database.Vote{
PostID: 1,
Type: database.VoteUp,
VoteHash: &voteHash,
})
postQueries := NewPostQueries(repo, voteService)
ctx := VoteContext{
UserID: 0,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
retrievedPost, err := postQueries.GetByID(1, ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if retrievedPost.CurrentVote != database.VoteUp {
t.Errorf("expected CurrentVote to be VoteUp for IP-based vote, got %s", retrievedPost.CurrentVote)
}
}
func TestPostQueries_EnrichPostsWithVotes_NoVotes(t *testing.T) {
repo := testutils.NewMockPostRepository()
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 5}
repo.Create(post1)
repo.Create(post2)
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
ctx := VoteContext{
UserID: 1,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
posts, err := postQueries.GetAll(QueryOptions{Limit: 10, Offset: 0}, ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(posts) != 2 {
t.Errorf("expected 2 posts, got %d", len(posts))
}
for _, post := range posts {
if post.CurrentVote != "" {
t.Errorf("expected CurrentVote to be empty when no votes exist, got %s", post.CurrentVote)
}
}
}
func TestPostQueries_GetByID_NotFound(t *testing.T) {
repo := testutils.NewMockPostRepository()
voteRepo := testutils.NewMockVoteRepository()
voteService := NewVoteService(voteRepo, repo, nil)
postQueries := NewPostQueries(repo, voteService)
post, err := postQueries.GetByID(999, VoteContext{})
if err == nil {
t.Fatal("expected error for non-existent post")
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
t.Errorf("expected gorm.ErrRecordNotFound, got %v", err)
}
if post != nil {
t.Error("expected nil post when not found")
}
}

View File

@@ -0,0 +1,178 @@
package services
import (
"fmt"
"time"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/validation"
)
type RegistrationService struct {
userRepo repositories.UserRepository
emailService *EmailService
config *config.Config
}
func NewRegistrationService(userRepo repositories.UserRepository, emailService *EmailService, config *config.Config) *RegistrationService {
return &RegistrationService{
userRepo: userRepo,
emailService: emailService,
config: config,
}
}
func (s *RegistrationService) Register(username, email, password string) (*RegistrationResult, error) {
trimmedUsername := TrimString(username)
if err := validation.ValidateUsername(trimmedUsername); err != nil {
return nil, err
}
if err := validation.ValidatePassword(password); err != nil {
return nil, err
}
normalizedEmail, err := normalizeEmail(email)
if err != nil {
return nil, err
}
userCheck, err := s.userRepo.GetByUsername(trimmedUsername)
if err == nil {
if userCheck != nil {
return nil, ErrUsernameTaken
}
} else if !IsRecordNotFound(err) {
if handled := HandleUniqueConstraintError(err); handled != err {
return nil, handled
}
return nil, fmt.Errorf("lookup user: %w", err)
}
emailCheck, err := s.userRepo.GetByEmail(normalizedEmail)
if err == nil {
if emailCheck != nil {
return nil, ErrEmailTaken
}
} else if !IsRecordNotFound(err) {
if handled := HandleUniqueConstraintError(err); handled != err {
return nil, handled
}
return nil, fmt.Errorf("lookup email: %w", err)
}
hashedPassword, err := HashPassword(password, s.config.App.BcryptCost)
if err != nil {
return nil, err
}
token, hashedToken, err := generateVerificationToken()
if err != nil {
return nil, err
}
now := time.Now()
user := &database.User{
Username: trimmedUsername,
Email: normalizedEmail,
Password: string(hashedPassword),
EmailVerified: false,
EmailVerificationToken: hashedToken,
EmailVerificationSentAt: &now,
}
if err := s.userRepo.Create(user); err != nil {
if handled := HandleUniqueConstraintErrorWithMessage(err); handled != err {
return nil, handled
}
return nil, fmt.Errorf("create user: %w", err)
}
if err := s.emailService.SendVerificationEmail(user, token); err != nil {
if deleteErr := s.userRepo.HardDelete(user.ID); deleteErr != nil {
return nil, fmt.Errorf("verification email failed and user cleanup failed: email=%w, cleanup=%v", err, deleteErr)
}
return nil, fmt.Errorf("verification email failed: %w", err)
}
return &RegistrationResult{
User: sanitizeUser(user),
VerificationSent: true,
}, nil
}
func (s *RegistrationService) ConfirmEmail(token string) (*database.User, error) {
trimmed := TrimString(token)
if trimmed == "" {
return nil, ErrInvalidVerificationToken
}
hashed := HashVerificationToken(trimmed)
user, err := s.userRepo.GetByVerificationToken(hashed)
if err != nil {
if IsRecordNotFound(err) {
return nil, ErrInvalidVerificationToken
}
return nil, fmt.Errorf("lookup verification token: %w", err)
}
if user.EmailVerified {
return sanitizeUser(user), nil
}
now := time.Now()
user.EmailVerified = true
user.EmailVerifiedAt = &now
user.EmailVerificationToken = ""
user.EmailVerificationSentAt = nil
if err := s.userRepo.Update(user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
return sanitizeUser(user), nil
}
func (s *RegistrationService) ResendVerificationEmail(email string) error {
email = TrimString(email)
if err := validation.ValidateEmail(email); err != nil {
return ErrInvalidEmail
}
user, err := s.userRepo.GetByEmail(email)
if err != nil {
if IsRecordNotFound(err) {
return ErrInvalidCredentials
}
return fmt.Errorf("lookup user: %w", err)
}
if user.EmailVerified {
return fmt.Errorf("email already verified")
}
if user.EmailVerificationSentAt != nil && time.Since(*user.EmailVerificationSentAt) < 5*time.Minute {
return fmt.Errorf("verification email sent recently, please wait before requesting another")
}
token, hash, err := generateVerificationToken()
if err != nil {
return err
}
now := time.Now()
user.EmailVerificationToken = hash
user.EmailVerificationSentAt = &now
if err := s.userRepo.Update(user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if err := s.emailService.SendResendVerificationEmail(user, token); err != nil {
return fmt.Errorf("send verification email: %w", err)
}
return nil
}

View File

@@ -0,0 +1,579 @@
package services
import (
"errors"
"testing"
"time"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/testutils"
)
func TestNewRegistrationService(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
cfg := testutils.AppTestConfig
service := NewRegistrationService(userRepo, emailService, cfg)
if service == nil {
t.Fatal("expected service to be created")
}
if service.userRepo != userRepo {
t.Error("expected userRepo to be set")
}
if service.emailService != emailService {
t.Error("expected emailService to be set")
}
if service.config != cfg {
t.Error("expected config to be set")
}
}
func TestRegistrationService_Register(t *testing.T) {
tests := []struct {
name string
username string
email string
password string
setupMocks func() (*testutils.MockUserRepository, *EmailService, *config.Config)
expectedError error
checkResult func(*testing.T, *RegistrationResult)
}{
{
name: "successful registration",
username: "testuser",
email: "test@example.com",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailSender := &testutils.MockEmailSender{}
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: func(t *testing.T, result *RegistrationResult) {
if result == nil {
t.Fatal("expected non-nil result")
}
if result.User == nil {
t.Fatal("expected non-nil user")
}
if result.User.Username != "testuser" {
t.Errorf("expected username 'testuser', got %q", result.User.Username)
}
if result.User.Email != "test@example.com" {
t.Errorf("expected email 'test@example.com', got %q", result.User.Email)
}
if result.User.Password != "" {
t.Error("expected password to be sanitized")
}
if !result.VerificationSent {
t.Error("expected VerificationSent to be true")
}
if result.User.EmailVerified {
t.Error("expected EmailVerified to be false")
}
},
},
{
name: "invalid username",
username: "",
email: "test@example.com",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: nil,
},
{
name: "invalid password",
username: "testuser",
email: "test@example.com",
password: "short",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: nil,
},
{
name: "invalid email",
username: "testuser",
email: "invalid-email",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: nil,
},
{
name: "username already taken",
username: "existinguser",
email: "test@example.com",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
existingUser := &database.User{
ID: 1,
Username: "existinguser",
Email: "existing@example.com",
}
userRepo.Create(existingUser)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: ErrUsernameTaken,
checkResult: nil,
},
{
name: "email already taken",
username: "testuser",
email: "existing@example.com",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
existingUser := &database.User{
ID: 1,
Username: "existinguser",
Email: "existing@example.com",
}
userRepo.Create(existingUser)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: ErrEmailTaken,
checkResult: nil,
},
{
name: "email service error",
username: "testuser",
email: "test@example.com",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
errorSender := &errorEmailSender{err: errors.New("email service error")}
emailService, _ := NewEmailService(testutils.AppTestConfig, errorSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: nil,
},
{
name: "trims username whitespace",
username: " testuser ",
email: "test@example.com",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailSender := &testutils.MockEmailSender{}
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: func(t *testing.T, result *RegistrationResult) {
if result.User.Username != "testuser" {
t.Errorf("expected trimmed username 'testuser', got %q", result.User.Username)
}
},
},
{
name: "normalizes email",
username: "testuser",
email: "TEST@EXAMPLE.COM",
password: "SecurePass123!",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailSender := &testutils.MockEmailSender{}
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: func(t *testing.T, result *RegistrationResult) {
if result.User.Email != "test@example.com" {
t.Errorf("expected normalized email 'test@example.com', got %q", result.User.Email)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, emailService, cfg := tt.setupMocks()
service := NewRegistrationService(userRepo, emailService, cfg)
result, err := service.Register(tt.username, tt.email, tt.password)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.checkResult == nil {
return
}
t.Fatalf("unexpected error: %v", err)
}
if tt.checkResult != nil {
tt.checkResult(t, result)
}
})
}
}
func TestRegistrationService_ConfirmEmail(t *testing.T) {
tests := []struct {
name string
token string
setupMocks func() (*testutils.MockUserRepository, *EmailService, *config.Config)
expectedError error
checkResult func(*testing.T, *database.User)
}{
{
name: "successful confirmation",
token: "valid-token",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
hashedToken := HashVerificationToken("valid-token")
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: false,
EmailVerificationToken: hashedToken,
}
userRepo.Create(user)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user == nil {
t.Fatal("expected non-nil user")
}
if !user.EmailVerified {
t.Error("expected EmailVerified to be true")
}
if user.EmailVerificationToken != "" {
t.Error("expected EmailVerificationToken to be cleared")
}
if user.EmailVerificationSentAt != nil {
t.Error("expected EmailVerificationSentAt to be nil")
}
if user.EmailVerifiedAt == nil {
t.Error("expected EmailVerifiedAt to be set")
}
},
},
{
name: "empty token",
token: "",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: ErrInvalidVerificationToken,
checkResult: nil,
},
{
name: "whitespace only token",
token: " ",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: ErrInvalidVerificationToken,
checkResult: nil,
},
{
name: "invalid token",
token: "invalid-token",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: ErrInvalidVerificationToken,
checkResult: nil,
},
{
name: "already verified",
token: "valid-token",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
hashedToken := HashVerificationToken("valid-token")
now := time.Now()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: true,
EmailVerifiedAt: &now,
EmailVerificationToken: hashedToken,
}
userRepo.Create(user)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user == nil {
t.Fatal("expected non-nil user")
}
if !user.EmailVerified {
t.Error("expected EmailVerified to be true")
}
},
},
{
name: "trims token whitespace",
token: " valid-token ",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
hashedToken := HashVerificationToken("valid-token")
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: false,
EmailVerificationToken: hashedToken,
}
userRepo.Create(user)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if !user.EmailVerified {
t.Error("expected EmailVerified to be true")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, emailService, cfg := tt.setupMocks()
service := NewRegistrationService(userRepo, emailService, cfg)
user, err := service.ConfirmEmail(tt.token)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.checkResult == nil {
return
}
t.Fatalf("unexpected error: %v", err)
}
if tt.checkResult != nil {
tt.checkResult(t, user)
}
})
}
}
func TestRegistrationService_ResendVerificationEmail(t *testing.T) {
tests := []struct {
name string
email string
setupMocks func() (*testutils.MockUserRepository, *EmailService, *config.Config)
expectedError error
}{
{
name: "successful resend",
email: "test@example.com",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
oldTime := time.Now().Add(-10 * time.Minute)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: false,
EmailVerificationSentAt: &oldTime,
}
userRepo.Create(user)
emailSender := &testutils.MockEmailSender{}
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
},
{
name: "invalid email",
email: "invalid-email",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: ErrInvalidEmail,
},
{
name: "user not found",
email: "nonexistent@example.com",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: ErrInvalidCredentials,
},
{
name: "email already verified",
email: "test@example.com",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
now := time.Now()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: true,
EmailVerifiedAt: &now,
}
userRepo.Create(user)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
},
{
name: "email sent too recently",
email: "test@example.com",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
recentTime := time.Now().Add(-2 * time.Minute)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: false,
EmailVerificationSentAt: &recentTime,
}
userRepo.Create(user)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
},
{
name: "email service error",
email: "test@example.com",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
oldTime := time.Now().Add(-10 * time.Minute)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: false,
EmailVerificationSentAt: &oldTime,
}
userRepo.Create(user)
errorSender := &errorEmailSender{err: errors.New("email service error")}
emailService, _ := NewEmailService(testutils.AppTestConfig, errorSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
},
{
name: "trims email whitespace",
email: " test@example.com ",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
oldTime := time.Now().Add(-10 * time.Minute)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: false,
EmailVerificationSentAt: &oldTime,
}
userRepo.Create(user)
emailSender := &testutils.MockEmailSender{}
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
},
{
name: "no previous verification sent",
email: "test@example.com",
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
EmailVerified: false,
}
userRepo.Create(user)
emailSender := &testutils.MockEmailSender{}
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
return userRepo, emailService, testutils.AppTestConfig
},
expectedError: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, emailService, cfg := tt.setupMocks()
service := NewRegistrationService(userRepo, emailService, cfg)
err := service.ResendVerificationEmail(tt.email)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.name == "email already verified" || tt.name == "email sent too recently" || tt.name == "email service error" {
if err.Error() == "" {
t.Fatal("expected error message")
}
return
}
t.Fatalf("unexpected error: %v", err)
}
})
}
}

View File

@@ -0,0 +1,124 @@
package services
import (
"fmt"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database"
"goyco/internal/repositories"
)
type SessionService struct {
jwtService *JWTService
userRepo repositories.UserRepository
}
func NewSessionService(jwtService *JWTService, userRepo repositories.UserRepository) *SessionService {
return &SessionService{
jwtService: jwtService,
userRepo: userRepo,
}
}
func (s *SessionService) Login(username, password string) (*AuthResult, error) {
trimmedUsername := TrimString(username)
if trimmedUsername == "" {
return nil, ErrInvalidCredentials
}
user, err := s.userRepo.GetByUsername(trimmedUsername)
if err != nil {
if IsRecordNotFound(err) {
return nil, ErrInvalidCredentials
}
return nil, fmt.Errorf("lookup user: %w", err)
}
if !user.EmailVerified {
return nil, ErrEmailNotVerified
}
if user.Locked {
return nil, ErrAccountLocked
}
if compareErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); compareErr != nil {
return nil, ErrInvalidCredentials
}
return s.issueAuthResult(user)
}
func (s *SessionService) VerifyToken(tokenString string) (uint, error) {
return s.jwtService.VerifyAccessToken(tokenString)
}
func (s *SessionService) issueAuthResult(user *database.User) (*AuthResult, error) {
accessToken, err := s.jwtService.GenerateAccessToken(user)
if err != nil {
return nil, fmt.Errorf("generate access token: %w", err)
}
refreshToken, err := s.jwtService.GenerateRefreshToken(user)
if err != nil {
return nil, fmt.Errorf("generate refresh token: %w", err)
}
return &AuthResult{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: sanitizeUser(user),
}, nil
}
func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
accessToken, err := s.jwtService.RefreshAccessToken(refreshToken)
if err != nil {
return nil, err
}
userID, err := s.jwtService.VerifyAccessToken(accessToken)
if err != nil {
return nil, fmt.Errorf("verify new access token: %w", err)
}
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, fmt.Errorf("lookup user: %w", err)
}
return &AuthResult{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: sanitizeUser(user),
}, nil
}
func (s *SessionService) RevokeRefreshToken(refreshToken string) error {
return s.jwtService.RevokeRefreshToken(refreshToken)
}
func (s *SessionService) RevokeAllUserTokens(userID uint) error {
return s.jwtService.RevokeAllRefreshTokens(userID)
}
func (s *SessionService) InvalidateAllSessions(userID uint) error {
user, err := s.userRepo.GetByID(userID)
if err != nil {
return fmt.Errorf("load user: %w", err)
}
user.SessionVersion++
if err := s.userRepo.Update(user); err != nil {
return fmt.Errorf("update session version: %w", err)
}
if err := s.jwtService.RevokeAllRefreshTokens(userID); err != nil {
return fmt.Errorf("revoke refresh tokens: %w", err)
}
return nil
}
func (s *SessionService) CleanupExpiredTokens() error {
return s.jwtService.CleanupExpiredTokens()
}

View File

@@ -0,0 +1,563 @@
package services
import (
"errors"
"testing"
"time"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type sessionMockRefreshTokenRepo struct {
tokens map[string]*database.RefreshToken
createErr error
deleteByUserIDErr error
deleteExpiredErr error
getByTokenHashErr error
}
func newSessionMockRefreshTokenRepo() *sessionMockRefreshTokenRepo {
return &sessionMockRefreshTokenRepo{
tokens: make(map[string]*database.RefreshToken),
}
}
func (m *sessionMockRefreshTokenRepo) Create(token *database.RefreshToken) error {
if m.createErr != nil {
return m.createErr
}
if m.tokens == nil {
m.tokens = make(map[string]*database.RefreshToken)
}
m.tokens[token.TokenHash] = token
return nil
}
func (m *sessionMockRefreshTokenRepo) GetByTokenHash(tokenHash string) (*database.RefreshToken, error) {
if m.getByTokenHashErr != nil {
return nil, m.getByTokenHashErr
}
if token, ok := m.tokens[tokenHash]; ok {
return token, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *sessionMockRefreshTokenRepo) DeleteByUserID(userID uint) error {
if m.deleteByUserIDErr != nil {
return m.deleteByUserIDErr
}
for hash, token := range m.tokens {
if token.UserID == userID {
delete(m.tokens, hash)
}
}
return nil
}
func (m *sessionMockRefreshTokenRepo) DeleteExpired() error {
if m.deleteExpiredErr != nil {
return m.deleteExpiredErr
}
now := time.Now()
for hash, token := range m.tokens {
if token.ExpiresAt.Before(now) {
delete(m.tokens, hash)
}
}
return nil
}
func (m *sessionMockRefreshTokenRepo) DeleteByID(id uint) error {
for hash, token := range m.tokens {
if token.ID == id {
delete(m.tokens, hash)
return nil
}
}
return gorm.ErrRecordNotFound
}
func (m *sessionMockRefreshTokenRepo) GetByUserID(userID uint) ([]database.RefreshToken, error) {
var tokens []database.RefreshToken
for _, token := range m.tokens {
if token.UserID == userID {
tokens = append(tokens, *token)
}
}
return tokens, nil
}
func (m *sessionMockRefreshTokenRepo) CountByUserID(userID uint) (int64, error) {
var count int64
for _, token := range m.tokens {
if token.UserID == userID {
count++
}
}
return count, nil
}
func createSessionTestJWTService(userRepo *testutils.MockUserRepository) (*JWTService, *sessionMockRefreshTokenRepo) {
cfg := &config.JWTConfig{
Secret: "test-secret-key-that-is-long-enough-for-security",
Expiration: 1,
RefreshExpiration: 24,
Issuer: "test-issuer",
Audience: "test-audience",
KeyRotation: config.KeyRotationConfig{
Enabled: false,
CurrentKey: "",
PreviousKey: "",
KeyID: "",
},
}
refreshRepo := newSessionMockRefreshTokenRepo()
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
return jwtService, refreshRepo
}
func createTestUserWithPassword(password string) *database.User {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
Password: string(hashedPassword),
EmailVerified: true,
Locked: false,
SessionVersion: 1,
}
}
func TestNewSessionService(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
jwtService, _ := createSessionTestJWTService(userRepo)
service := NewSessionService(jwtService, userRepo)
if service == nil {
t.Fatal("expected service to be created")
}
if service.jwtService != jwtService {
t.Error("expected jwtService to be set")
}
if service.userRepo != userRepo {
t.Error("expected userRepo to be set")
}
}
func TestSessionService_Login(t *testing.T) {
tests := []struct {
name string
username string
password string
setupMocks func() (*JWTService, *testutils.MockUserRepository)
expectedError error
checkResult func(*testing.T, *AuthResult)
}{
{
name: "successful login",
username: "testuser",
password: "SecurePass123!",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: nil,
checkResult: func(t *testing.T, result *AuthResult) {
if result == nil {
t.Fatal("expected non-nil result")
}
if result.AccessToken == "" {
t.Error("expected non-empty access token")
}
if result.RefreshToken == "" {
t.Error("expected non-empty refresh token")
}
if result.User == nil {
t.Fatal("expected non-nil user")
}
if result.User.Username != "testuser" {
t.Errorf("expected username 'testuser', got %q", result.User.Username)
}
if result.User.Password != "" {
t.Error("expected password to be sanitized")
}
},
},
{
name: "empty username",
username: "",
password: "SecurePass123!",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: ErrInvalidCredentials,
checkResult: nil,
},
{
name: "whitespace only username",
username: " ",
password: "SecurePass123!",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: ErrInvalidCredentials,
checkResult: nil,
},
{
name: "user not found",
username: "nonexistent",
password: "SecurePass123!",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: ErrInvalidCredentials,
checkResult: nil,
},
{
name: "email not verified",
username: "testuser",
password: "SecurePass123!",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
user.EmailVerified = false
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: ErrEmailNotVerified,
checkResult: nil,
},
{
name: "account locked",
username: "testuser",
password: "SecurePass123!",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
user.Locked = true
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: ErrAccountLocked,
checkResult: nil,
},
{
name: "invalid password",
username: "testuser",
password: "WrongPassword",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: ErrInvalidCredentials,
checkResult: nil,
},
{
name: "trims username whitespace",
username: " testuser ",
password: "SecurePass123!",
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: nil,
checkResult: func(t *testing.T, result *AuthResult) {
if result.User.Username != "testuser" {
t.Errorf("expected trimmed username 'testuser', got %q", result.User.Username)
}
if result.AccessToken == "" {
t.Error("expected non-empty access token")
}
if result.RefreshToken == "" {
t.Error("expected non-empty refresh token")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
jwtService, userRepo := tt.setupMocks()
service := NewSessionService(jwtService, userRepo)
result, err := service.Login(tt.username, tt.password)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.checkResult == nil {
return
}
t.Fatalf("unexpected error: %v", err)
}
if tt.checkResult != nil {
tt.checkResult(t, result)
}
})
}
}
func TestSessionService_VerifyToken(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
t.Run("successful verification", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
token, err := jwtService.GenerateAccessToken(user)
if err != nil {
t.Fatalf("failed to generate token: %v", err)
}
userID, err := service.VerifyToken(token)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if userID != user.ID {
t.Errorf("expected user ID %d, got %d", user.ID, userID)
}
})
t.Run("invalid token", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
_, err := service.VerifyToken("invalid-token")
if err == nil {
t.Fatal("expected error for invalid token")
}
})
t.Run("empty token", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
_, err := service.VerifyToken("")
if err == nil {
t.Fatal("expected error for empty token")
}
})
}
func TestSessionService_RefreshAccessToken(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
t.Run("successful refresh", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("failed to generate refresh token: %v", err)
}
result, err := service.RefreshAccessToken(refreshToken)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
if result.AccessToken == "" {
t.Error("expected non-empty access token")
}
if result.RefreshToken != refreshToken {
t.Errorf("expected refresh token to remain unchanged")
}
if result.User == nil {
t.Fatal("expected non-nil user")
}
})
t.Run("invalid refresh token", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
_, err := service.RefreshAccessToken("invalid-refresh-token")
if err == nil {
t.Fatal("expected error for invalid refresh token")
}
if !errors.Is(err, ErrRefreshTokenInvalid) {
t.Errorf("expected ErrRefreshTokenInvalid, got %v", err)
}
})
}
func TestSessionService_RevokeRefreshToken(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
t.Run("successful revocation", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("failed to generate refresh token: %v", err)
}
err = service.RevokeRefreshToken(refreshToken)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, err = service.RefreshAccessToken(refreshToken)
if err == nil {
t.Fatal("expected error when using revoked refresh token")
}
})
}
func TestSessionService_RevokeAllUserTokens(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
user := createTestUserWithPassword("SecurePass123!")
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
t.Run("successful revocation", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
refreshToken, err := jwtService.GenerateRefreshToken(user)
if err != nil {
t.Fatalf("failed to generate refresh token: %v", err)
}
err = service.RevokeAllUserTokens(user.ID)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, err = service.RefreshAccessToken(refreshToken)
if err == nil {
t.Fatal("expected error when using revoked refresh token")
}
})
}
func TestSessionService_InvalidateAllSessions(t *testing.T) {
tests := []struct {
name string
userID uint
setupMocks func() (*JWTService, *testutils.MockUserRepository)
expectedError error
checkResult func(*testing.T, *testutils.MockUserRepository)
}{
{
name: "successful invalidation",
userID: 1,
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
SessionVersion: 1,
}
userRepo.Create(user)
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: nil,
checkResult: func(t *testing.T, userRepo *testutils.MockUserRepository) {
user, err := userRepo.GetByID(1)
if err != nil {
t.Fatalf("failed to get user: %v", err)
}
if user.SessionVersion != 2 {
t.Errorf("expected SessionVersion to be 2, got %d", user.SessionVersion)
}
},
},
{
name: "user not found",
userID: 999,
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
userRepo := testutils.NewMockUserRepository()
jwtService, _ := createSessionTestJWTService(userRepo)
return jwtService, userRepo
},
expectedError: nil,
checkResult: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
jwtService, userRepo := tt.setupMocks()
service := NewSessionService(jwtService, userRepo)
err := service.InvalidateAllSessions(tt.userID)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.name == "user not found" {
if err.Error() == "" {
t.Fatal("expected error message")
}
return
}
t.Fatalf("unexpected error: %v", err)
}
if tt.checkResult != nil {
tt.checkResult(t, userRepo)
}
})
}
}
func TestSessionService_CleanupExpiredTokens(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
jwtService, _ := createSessionTestJWTService(userRepo)
t.Run("successful cleanup", func(t *testing.T) {
service := NewSessionService(jwtService, userRepo)
err := service.CleanupExpiredTokens()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
}

View File

@@ -0,0 +1,598 @@
package services
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/html"
)
var (
ErrUnsupportedScheme = errors.New("unsupported URL scheme")
ErrTitleNotFound = errors.New("page title not found")
ErrSSRFBlocked = errors.New("request blocked for security reasons")
ErrTooManyRedirects = errors.New("too many redirects")
)
const (
maxTitleBodyBytes = 512 * 1024
defaultUserAgent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
maxRedirects = 3
requestTimeout = 10 * time.Second
dialTimeout = 5 * time.Second
tlsHandshakeTimeout = 5 * time.Second
responseHeaderTimeout = 5 * time.Second
maxContentLength = 10 * 1024 * 1024
)
type TitleFetcher interface {
FetchTitle(ctx context.Context, rawURL string) (string, error)
}
type DNSResolver interface {
LookupIP(hostname string) ([]net.IP, error)
}
type DefaultDNSResolver struct{}
func (d DefaultDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
return net.LookupIP(hostname)
}
type DNSCache struct {
mu sync.RWMutex
data map[string][]net.IP
}
func NewDNSCache() *DNSCache {
return &DNSCache{
data: make(map[string][]net.IP),
}
}
func (c *DNSCache) Get(hostname string) ([]net.IP, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
ips, exists := c.data[hostname]
return ips, exists
}
func (c *DNSCache) Set(hostname string, ips []net.IP) {
c.mu.Lock()
defer c.mu.Unlock()
c.data[hostname] = ips
}
type CachedDNSResolver struct {
resolver DNSResolver
cache *DNSCache
}
func NewCachedDNSResolver(resolver DNSResolver) *CachedDNSResolver {
return &CachedDNSResolver{
resolver: resolver,
cache: NewDNSCache(),
}
}
func (c *CachedDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
if ips, exists := c.cache.Get(hostname); exists {
return ips, nil
}
ips, err := c.resolver.LookupIP(hostname)
if err != nil {
return nil, err
}
c.cache.Set(hostname, ips)
return ips, nil
}
type CustomDialer struct {
cache *DNSCache
fallback *net.Dialer
}
func NewCustomDialer(cache *DNSCache) *CustomDialer {
return &CustomDialer{
cache: cache,
fallback: &net.Dialer{
Timeout: dialTimeout,
},
}
}
func (d *CustomDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return d.fallback.DialContext(ctx, network, address)
}
if ips, exists := d.cache.Get(host); exists {
for _, ip := range ips {
ipAddr := net.JoinHostPort(ip.String(), port)
if conn, err := d.fallback.DialContext(ctx, network, ipAddr); err == nil {
return conn, nil
}
}
}
return d.fallback.DialContext(ctx, network, address)
}
type URLMetadataService struct {
client *http.Client
resolver DNSResolver
dnsCache *DNSCache
approvedHosts map[string]bool
mu sync.RWMutex
}
func NewURLMetadataService() *URLMetadataService {
dnsCache := NewDNSCache()
cachedResolver := NewCachedDNSResolver(DefaultDNSResolver{})
customDialer := NewCustomDialer(dnsCache)
svc := &URLMetadataService{
resolver: cachedResolver,
dnsCache: dnsCache,
approvedHosts: make(map[string]bool),
}
transport := &http.Transport{
DialContext: customDialer.DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: tlsHandshakeTimeout,
ResponseHeaderTimeout: responseHeaderTimeout,
DisableKeepAlives: false,
}
svc.client = &http.Client{
Timeout: requestTimeout,
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return ErrTooManyRedirects
}
hostname := req.URL.Hostname()
svc.mu.RLock()
approved := svc.approvedHosts[hostname]
svc.mu.RUnlock()
if approved {
return nil
}
if err := svc.validateURLForSSRF(req.URL); err != nil {
return err
}
svc.mu.Lock()
svc.approvedHosts[hostname] = true
svc.mu.Unlock()
return nil
},
}
return svc
}
func (s *URLMetadataService) FetchTitle(ctx context.Context, rawURL string) (string, error) {
if rawURL == "" {
return "", errors.New("empty URL")
}
parsed, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("parse url: %w", err)
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", ErrUnsupportedScheme
}
hostname := parsed.Hostname()
s.mu.RLock()
approved := s.approvedHosts[hostname]
s.mu.RUnlock()
if !approved {
if err := s.validateURLForSSRF(parsed); err != nil {
return "", err
}
s.mu.Lock()
s.approvedHosts[hostname] = true
s.mu.Unlock()
}
request, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return "", fmt.Errorf("build request: %w", err)
}
request.Header.Set("User-Agent", defaultUserAgent)
request.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8")
request.Header.Set("Accept-Language", "en-US,en;q=0.5")
resp, err := s.client.Do(request)
if err != nil {
return "", fmt.Errorf("fetch url: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(strings.ToLower(contentType), "text/html") {
return "", ErrTitleNotFound
}
contentLength := resp.ContentLength
if contentLength > maxContentLength {
return "", ErrTitleNotFound
}
limited := io.LimitReader(resp.Body, maxTitleBodyBytes)
body, err := io.ReadAll(limited)
if err != nil {
return "", fmt.Errorf("read body: %w", err)
}
title := s.ExtractTitleFromHTML(string(body))
if title != "" {
return title, nil
}
return "", ErrTitleNotFound
}
func (s *URLMetadataService) ExtractTitleFromHTML(html string) string {
if title := s.ExtractFromTitleTag(html); title != "" {
return title
}
if title := s.ExtractFromOpenGraph(html); title != "" {
return title
}
if title := s.ExtractFromJSONLD(html); title != "" {
return title
}
if title := s.ExtractFromTwitterCard(html); title != "" {
return title
}
if title := s.extractFromMetaTags(html); title != "" {
return title
}
return ""
}
func (s *URLMetadataService) ExtractFromTitleTag(htmlContent string) string {
tokenizer := html.NewTokenizer(strings.NewReader(htmlContent))
for {
tokenType := tokenizer.Next()
switch tokenType {
case html.ErrorToken:
if errors.Is(tokenizer.Err(), io.EOF) {
return ""
}
return ""
case html.StartTagToken, html.SelfClosingTagToken:
token := tokenizer.Token()
if strings.EqualFold(token.Data, "title") {
textTokenType := tokenizer.Next()
if textTokenType == html.TextToken {
rawTitle := tokenizer.Token().Data
cleaned := s.optimizedTitleClean(rawTitle)
if cleaned != "" {
return cleaned
}
}
}
}
}
}
func (s *URLMetadataService) ExtractFromOpenGraph(htmlContent string) string {
lines := strings.Split(htmlContent, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.Contains(strings.ToLower(line), `property="og:title"`) && strings.Contains(line, `content="`) {
start := strings.Index(line, `content="`)
if start != -1 {
start += 9
end := strings.Index(line[start:], `"`)
if end != -1 {
title := line[start : start+end]
cleaned := s.optimizedTitleClean(title)
if cleaned != "" {
return cleaned
}
}
}
}
}
return ""
}
func (s *URLMetadataService) ExtractFromJSONLD(htmlContent string) string {
lines := strings.Split(htmlContent, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.Contains(line, `"@type":"VideoObject"`) || strings.Contains(line, `"@type":"WebPage"`) {
if strings.Contains(line, `"name":`) {
start := strings.Index(line, `"name":`)
if start != -1 {
start += 7
for i := start; i < len(line); i++ {
if line[i] == '"' {
start = i + 1
break
}
}
end := strings.Index(line[start:], `"`)
if end != -1 {
title := line[start : start+end]
cleaned := s.optimizedTitleClean(title)
if cleaned != "" {
return cleaned
}
}
}
}
}
}
return ""
}
func (s *URLMetadataService) ExtractFromTwitterCard(htmlContent string) string {
lines := strings.Split(htmlContent, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.Contains(strings.ToLower(line), `name="twitter:title"`) && strings.Contains(line, `content="`) {
start := strings.Index(line, `content="`)
if start != -1 {
start += 9
end := strings.Index(line[start:], `"`)
if end != -1 {
title := line[start : start+end]
cleaned := s.optimizedTitleClean(title)
if cleaned != "" {
return cleaned
}
}
}
}
}
return ""
}
func (s *URLMetadataService) extractFromMetaTags(htmlContent string) string {
lines := strings.Split(htmlContent, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.Contains(strings.ToLower(line), `name="title"`) && strings.Contains(line, `content="`) {
start := strings.Index(line, `content="`)
if start != -1 {
start += 9
end := strings.Index(line[start:], `"`)
if end != -1 {
title := line[start : start+end]
cleaned := s.optimizedTitleClean(title)
if cleaned != "" {
return cleaned
}
}
}
}
}
return ""
}
func (s *URLMetadataService) optimizedTitleClean(title string) string {
if title == "" {
return ""
}
var result strings.Builder
result.Grow(len(title))
inWhitespace := false
started := false
for _, r := range title {
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
if started && !inWhitespace {
result.WriteRune(' ')
inWhitespace = true
}
} else {
result.WriteRune(r)
inWhitespace = false
started = true
}
}
cleaned := result.String()
if len(cleaned) > 0 && cleaned[len(cleaned)-1] == ' ' {
cleaned = cleaned[:len(cleaned)-1]
}
return cleaned
}
func (s *URLMetadataService) validateURLForSSRF(u *url.URL) error {
if u == nil {
return ErrSSRFBlocked
}
if u.Scheme != "http" && u.Scheme != "https" {
return ErrSSRFBlocked
}
if u.Host == "" {
return ErrSSRFBlocked
}
hostname := u.Hostname()
if hostname == "" {
return ErrSSRFBlocked
}
if isLocalhost(hostname) {
return ErrSSRFBlocked
}
ips, err := s.resolver.LookupIP(hostname)
if err != nil {
return ErrSSRFBlocked
}
for _, ip := range ips {
if isPrivateOrReservedIP(ip) {
return ErrSSRFBlocked
}
}
return nil
}
func isLocalhost(hostname string) bool {
hostname = strings.ToLower(hostname)
localhostNames := []string{
"localhost",
"127.0.0.1",
"::1",
"0.0.0.0",
"0:0:0:0:0:0:0:1",
"0:0:0:0:0:0:0:0",
}
for _, name := range localhostNames {
if hostname == name {
return true
}
}
return false
}
func isPrivateOrReservedIP(ip net.IP) bool {
if ip == nil {
return true
}
ipv4 := ip.To4()
if ipv4 == nil {
return isPrivateIPv6(ip)
}
privateRanges := []struct {
start, end net.IP
}{
{net.IPv4(10, 0, 0, 0), net.IPv4(10, 255, 255, 255)},
{net.IPv4(172, 16, 0, 0), net.IPv4(172, 31, 255, 255)},
{net.IPv4(192, 168, 0, 0), net.IPv4(192, 168, 255, 255)},
{net.IPv4(127, 0, 0, 0), net.IPv4(127, 255, 255, 255)},
{net.IPv4(169, 254, 0, 0), net.IPv4(169, 254, 255, 255)},
{net.IPv4(224, 0, 0, 0), net.IPv4(239, 255, 255, 255)},
{net.IPv4(240, 0, 0, 0), net.IPv4(255, 255, 255, 255)},
}
for _, r := range privateRanges {
if ipInRange(ipv4, r.start, r.end) {
return true
}
}
return false
}
func isPrivateIPv6(ip net.IP) bool {
privateRanges := []struct {
prefix []byte
length int
}{
{[]byte{0xfc, 0x00}, 7},
{[]byte{0xfe, 0x80}, 10},
{[]byte{0xff, 0x00}, 8},
{[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, 128},
}
for _, r := range privateRanges {
if ipv6InRange(ip, r.prefix, r.length) {
return true
}
}
return false
}
func ipInRange(ip, start, end net.IP) bool {
ipInt := ipToInt(ip)
startInt := ipToInt(start)
endInt := ipToInt(end)
return ipInt >= startInt && ipInt <= endInt
}
func ipToInt(ip net.IP) uint32 {
ipv4 := ip.To4()
if ipv4 == nil {
return 0
}
return uint32(ipv4[0])<<24 + uint32(ipv4[1])<<16 + uint32(ipv4[2])<<8 + uint32(ipv4[3])
}
func ipv6InRange(ip net.IP, prefix []byte, length int) bool {
ipBytes := ip.To16()
if ipBytes == nil {
return false
}
bytesToCompare := length / 8
bitsToCompare := length % 8
for i := 0; i < bytesToCompare && i < len(prefix) && i < len(ipBytes); i++ {
if ipBytes[i] != prefix[i] {
return false
}
}
if bitsToCompare > 0 && bytesToCompare < len(prefix) && bytesToCompare < len(ipBytes) {
mask := byte(0xff) << (8 - bitsToCompare)
if (ipBytes[bytesToCompare] & mask) != (prefix[bytesToCompare] & mask) {
return false
}
}
return true
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,160 @@
package services
import (
"fmt"
"time"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/validation"
)
type UserManagementService struct {
userRepo repositories.UserRepository
postRepo repositories.PostRepository
emailService *EmailService
}
func NewUserManagementService(userRepo repositories.UserRepository, postRepo repositories.PostRepository, emailService *EmailService) *UserManagementService {
return &UserManagementService{
userRepo: userRepo,
postRepo: postRepo,
emailService: emailService,
}
}
func (s *UserManagementService) UpdateUsername(userID uint, newUsername string) (*database.User, error) {
trimmed := TrimString(newUsername)
if err := validation.ValidateUsername(trimmed); err != nil {
return nil, err
}
existing, err := s.userRepo.GetByUsernameIncludingDeleted(trimmed)
if err == nil && existing.ID != userID {
return nil, ErrUsernameTaken
}
if err != nil && !IsRecordNotFound(err) {
return nil, fmt.Errorf("lookup username: %w", err)
}
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, fmt.Errorf("load user: %w", err)
}
if user.Username == trimmed {
return sanitizeUser(user), nil
}
user.Username = trimmed
if err := s.userRepo.Update(user); err != nil {
if handled := HandleUniqueConstraintError(err); handled != err {
return nil, handled
}
return nil, fmt.Errorf("update user: %w", err)
}
return sanitizeUser(user), nil
}
func (s *UserManagementService) UpdatePassword(userID uint, currentPassword, newPassword string) (*database.User, error) {
if err := validation.ValidatePassword(newPassword); err != nil {
return nil, err
}
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, fmt.Errorf("load user: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentPassword)); err != nil {
return nil, fmt.Errorf("current password is incorrect")
}
hashedPassword, err := HashPassword(newPassword, DefaultBcryptCost)
if err != nil {
return nil, err
}
user.Password = string(hashedPassword)
if err := s.userRepo.Update(user); err != nil {
return nil, fmt.Errorf("update password: %w", err)
}
return sanitizeUser(user), nil
}
func (s *UserManagementService) UpdateEmail(userID uint, newEmail string) (*database.User, error) {
normalized, err := normalizeEmail(newEmail)
if err != nil {
return nil, err
}
existing, err := s.userRepo.GetByEmail(normalized)
if err == nil && existing.ID != userID {
return nil, ErrEmailTaken
}
if err != nil && !IsRecordNotFound(err) {
return nil, fmt.Errorf("lookup email: %w", err)
}
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, fmt.Errorf("load user: %w", err)
}
if user.Email == normalized {
return sanitizeUser(user), nil
}
previousEmail := user.Email
previousVerified := user.EmailVerified
previousVerifiedAt := user.EmailVerifiedAt
previousToken := user.EmailVerificationToken
previousSentAt := user.EmailVerificationSentAt
token, hashed, err := generateVerificationToken()
if err != nil {
return nil, err
}
now := time.Now()
user.Email = normalized
user.EmailVerified = false
user.EmailVerifiedAt = nil
user.EmailVerificationToken = hashed
user.EmailVerificationSentAt = &now
if err := s.userRepo.Update(user); err != nil {
if handled := HandleUniqueConstraintError(err); handled != err {
return nil, handled
}
return nil, fmt.Errorf("update user: %w", err)
}
if err := s.emailService.SendEmailChangeVerificationEmail(user, token); err != nil {
user.Email = previousEmail
user.EmailVerified = previousVerified
user.EmailVerifiedAt = previousVerifiedAt
user.EmailVerificationToken = previousToken
user.EmailVerificationSentAt = previousSentAt
_ = s.userRepo.Update(user)
return nil, err
}
return sanitizeUser(user), nil
}
func (s *UserManagementService) UserHasPosts(userID uint) (bool, int64, error) {
if s.postRepo == nil {
return false, 0, fmt.Errorf("post repository not configured")
}
count, err := s.postRepo.CountByUserID(userID)
if err != nil {
return false, 0, fmt.Errorf("count user posts: %w", err)
}
return count > 0, count, nil
}

View File

@@ -0,0 +1,647 @@
package services
import (
"errors"
"testing"
"time"
"goyco/internal/database"
"goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
)
func TestNewUserManagementService(t *testing.T) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
service := NewUserManagementService(userRepo, postRepo, emailService)
if service == nil {
t.Fatal("expected service to be created")
}
if service.userRepo != userRepo {
t.Error("expected userRepo to be set")
}
if service.postRepo != postRepo {
t.Error("expected postRepo to be set")
}
if service.emailService != emailService {
t.Error("expected emailService to be set")
}
}
func TestUserManagementService_UpdateUsername(t *testing.T) {
tests := []struct {
name string
userID uint
newUsername string
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
expectedError error
checkResult func(*testing.T, *database.User)
}{
{
name: "successful update",
userID: 1,
newUsername: "newusername",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "oldusername",
Email: "test@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user == nil {
t.Fatal("expected non-nil user")
}
if user.Username != "newusername" {
t.Errorf("expected username 'newusername', got %q", user.Username)
}
if user.Password != "" {
t.Error("expected password to be sanitized")
}
},
},
{
name: "invalid username",
userID: 1,
newUsername: "",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "oldusername",
Email: "test@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
{
name: "username already taken by different user",
userID: 1,
newUsername: "takenusername",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user1 := &database.User{
ID: 1,
Username: "oldusername",
Email: "test1@example.com",
}
user2 := &database.User{
ID: 2,
Username: "takenusername",
Email: "test2@example.com",
}
userRepo.Create(user1)
userRepo.Create(user2)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: ErrUsernameTaken,
checkResult: nil,
},
{
name: "same username",
userID: 1,
newUsername: "oldusername",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "oldusername",
Email: "test@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user.Username != "oldusername" {
t.Errorf("expected username 'oldusername', got %q", user.Username)
}
},
},
{
name: "user not found",
userID: 999,
newUsername: "newusername",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
{
name: "trims username whitespace",
userID: 1,
newUsername: " newusername ",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "oldusername",
Email: "test@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user.Username != "newusername" {
t.Errorf("expected trimmed username 'newusername', got %q", user.Username)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, postRepo, emailService := tt.setupMocks()
service := NewUserManagementService(userRepo, postRepo, emailService)
result, err := service.UpdateUsername(tt.userID, tt.newUsername)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.checkResult == nil {
return
}
t.Fatalf("unexpected error: %v", err)
}
if tt.checkResult != nil {
tt.checkResult(t, result)
}
})
}
}
func TestUserManagementService_UpdatePassword(t *testing.T) {
tests := []struct {
name string
userID uint
currentPassword string
newPassword string
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
expectedError error
checkResult func(*testing.T, *database.User)
}{
{
name: "successful update",
userID: 1,
currentPassword: "OldPass123!",
newPassword: "NewPass123!",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("OldPass123!"), bcrypt.DefaultCost)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
Password: string(hashedPassword),
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user == nil {
t.Fatal("expected non-nil user")
}
if user.Password != "" {
t.Error("expected password to be sanitized")
}
},
},
{
name: "invalid new password",
userID: 1,
currentPassword: "OldPass123!",
newPassword: "short",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("OldPass123!"), bcrypt.DefaultCost)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
Password: string(hashedPassword),
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
{
name: "incorrect current password",
userID: 1,
currentPassword: "WrongPassword",
newPassword: "NewPass123!",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("OldPass123!"), bcrypt.DefaultCost)
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
Password: string(hashedPassword),
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
{
name: "user not found",
userID: 999,
currentPassword: "OldPass123!",
newPassword: "NewPass123!",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, postRepo, emailService := tt.setupMocks()
service := NewUserManagementService(userRepo, postRepo, emailService)
result, err := service.UpdatePassword(tt.userID, tt.currentPassword, tt.newPassword)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.checkResult == nil {
return
}
t.Fatalf("unexpected error: %v", err)
}
if tt.checkResult != nil {
tt.checkResult(t, result)
}
})
}
}
func TestUserManagementService_UpdateEmail(t *testing.T) {
tests := []struct {
name string
userID uint
newEmail string
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
expectedError error
checkResult func(*testing.T, *database.User)
}{
{
name: "successful update",
userID: 1,
newEmail: "newemail@example.com",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
now := time.Now()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "oldemail@example.com",
EmailVerified: true,
EmailVerifiedAt: &now,
EmailVerificationToken: "",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user == nil {
t.Fatal("expected non-nil user")
}
if user.Email != "newemail@example.com" {
t.Errorf("expected email 'newemail@example.com', got %q", user.Email)
}
if user.EmailVerified {
t.Error("expected EmailVerified to be false")
}
},
},
{
name: "invalid email",
userID: 1,
newEmail: "invalid-email",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "oldemail@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
{
name: "email already taken by different user",
userID: 1,
newEmail: "taken@example.com",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user1 := &database.User{
ID: 1,
Username: "testuser1",
Email: "oldemail@example.com",
}
user2 := &database.User{
ID: 2,
Username: "testuser2",
Email: "taken@example.com",
}
userRepo.Create(user1)
userRepo.Create(user2)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: ErrEmailTaken,
checkResult: nil,
},
{
name: "same email",
userID: 1,
newEmail: "oldemail@example.com",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "oldemail@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user.Email != "oldemail@example.com" {
t.Errorf("expected email 'oldemail@example.com', got %q", user.Email)
}
},
},
{
name: "user not found",
userID: 999,
newEmail: "newemail@example.com",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
{
name: "normalizes email",
userID: 1,
newEmail: "NEWEMAIL@EXAMPLE.COM",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "oldemail@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: func(t *testing.T, user *database.User) {
if user.Email != "newemail@example.com" {
t.Errorf("expected normalized email 'newemail@example.com', got %q", user.Email)
}
},
},
{
name: "email service error rolls back",
userID: 1,
newEmail: "newemail@example.com",
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
now := time.Now()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "oldemail@example.com",
EmailVerified: true,
EmailVerifiedAt: &now,
EmailVerificationToken: "",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
errorSender := &errorEmailSender{err: errors.New("email service error")}
emailService, _ := NewEmailService(testutils.AppTestConfig, errorSender)
return userRepo, postRepo, emailService
},
expectedError: nil,
checkResult: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, postRepo, emailService := tt.setupMocks()
service := NewUserManagementService(userRepo, postRepo, emailService)
result, err := service.UpdateEmail(tt.userID, tt.newEmail)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
if tt.checkResult == nil || tt.name == "email service error rolls back" {
if tt.name == "email service error rolls back" {
user, _ := userRepo.GetByID(1)
if user.Email != "oldemail@example.com" {
t.Error("expected email to be rolled back to original")
}
if !user.EmailVerified {
t.Error("expected EmailVerified to be rolled back")
}
}
return
}
t.Fatalf("unexpected error: %v", err)
}
if tt.checkResult != nil {
tt.checkResult(t, result)
}
})
}
}
func TestUserManagementService_UserHasPosts(t *testing.T) {
tests := []struct {
name string
userID uint
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
expectedHas bool
expectedCount int64
expectedError error
}{
{
name: "user has posts",
userID: 1,
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
userID := uint(1)
post1 := &database.Post{
ID: 1,
AuthorID: &userID,
Title: "Post 1",
URL: "https://example.com/1",
}
post2 := &database.Post{
ID: 2,
AuthorID: &userID,
Title: "Post 2",
URL: "https://example.com/2",
}
postRepo.Create(post1)
postRepo.Create(post2)
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedHas: true,
expectedCount: 2,
expectedError: nil,
},
{
name: "user has no posts",
userID: 1,
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
userRepo := testutils.NewMockUserRepository()
user := &database.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}
userRepo.Create(user)
postRepo := testutils.NewMockPostRepository()
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
return userRepo, postRepo, emailService
},
expectedHas: false,
expectedCount: 0,
expectedError: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
userRepo, postRepo, emailService := tt.setupMocks()
service := NewUserManagementService(userRepo, postRepo, emailService)
hasPosts, count, err := service.UserHasPosts(tt.userID)
if tt.expectedError != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if hasPosts != tt.expectedHas {
t.Errorf("expected hasPosts %v, got %v", tt.expectedHas, hasPosts)
}
if count != tt.expectedCount {
t.Errorf("expected count %d, got %d", tt.expectedCount, count)
}
})
}
}

View File

@@ -0,0 +1,376 @@
package services
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"sync"
"time"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
)
type VoteService struct {
voteRepo repositories.VoteRepository
postRepo repositories.PostRepository
db *gorm.DB
voteMutex sync.RWMutex
}
type VoteRequest struct {
UserID uint `json:"user_id,omitempty"`
PostID uint `json:"post_id"`
Type database.VoteType `json:"type"`
IPAddress string `json:"-"`
UserAgent string `json:"-"`
}
type VoteResponse struct {
PostID uint `json:"post_id"`
Type database.VoteType `json:"type"`
UpVotes int `json:"up_votes"`
DownVotes int `json:"down_votes"`
Score int `json:"score"`
Message string `json:"message"`
IsUnauthenticated bool `json:"is_unauthenticated"`
}
func NewVoteService(voteRepo repositories.VoteRepository, postRepo repositories.PostRepository, db *gorm.DB) *VoteService {
return &VoteService{
voteRepo: voteRepo,
postRepo: postRepo,
db: db,
}
}
func (vs *VoteService) GenerateVoteHash(ipAddress, userAgent string, postID uint) string {
data := fmt.Sprintf("%s:%s:%d", ipAddress, userAgent, postID)
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
func (vs *VoteService) CastVote(req VoteRequest) (*VoteResponse, error) {
if err := vs.validateVoteRequest(req); err != nil {
return nil, err
}
vs.voteMutex.Lock()
defer vs.voteMutex.Unlock()
var response *VoteResponse
if vs.db == nil {
return vs.castVoteWithoutTransaction(req)
}
err := vs.db.Transaction(func(tx *gorm.DB) error {
txVoteRepo := vs.voteRepo.WithTx(tx)
txPostRepo := vs.postRepo.WithTx(tx)
post, err := txPostRepo.GetByID(req.PostID)
if err != nil {
if IsRecordNotFound(err) {
return errors.New("post not found")
}
return fmt.Errorf("failed to get post: %w", err)
}
isUnauthenticated := req.UserID == 0
if req.Type == database.VoteNone {
var existingVote *database.Vote
var err error
if isUnauthenticated {
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
existingVote, err = txVoteRepo.GetByVoteHash(voteHash)
} else {
existingVote, err = txVoteRepo.GetByUserAndPost(req.UserID, req.PostID)
}
if err != nil {
if IsRecordNotFound(err) {
response = vs.buildVoteResponse(post, database.VoteNone, isUnauthenticated)
return nil
}
return fmt.Errorf("failed to get existing vote: %w", err)
}
if err := txVoteRepo.Delete(existingVote.ID); err != nil {
return fmt.Errorf("failed to delete vote: %w", err)
}
} else {
var vote *database.Vote
if isUnauthenticated {
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
vote = &database.Vote{
PostID: req.PostID,
Type: req.Type,
VoteHash: &voteHash,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
} else {
vote = &database.Vote{
UserID: &req.UserID,
PostID: req.PostID,
Type: req.Type,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
}
if err := txVoteRepo.CreateOrUpdate(vote); err != nil {
return fmt.Errorf("failed to create or update vote: %w", err)
}
}
if err := vs.updatePostVoteCountsWithTx(tx, req.PostID); err != nil {
return fmt.Errorf("failed to update post vote counts: %w", err)
}
updatedPost, err := txPostRepo.GetByID(req.PostID)
if err != nil {
return fmt.Errorf("failed to get updated post: %w", err)
}
response = vs.buildVoteResponse(updatedPost, req.Type, isUnauthenticated)
return nil
})
if err != nil {
return nil, err
}
return response, nil
}
func (vs *VoteService) castVoteWithoutTransaction(req VoteRequest) (*VoteResponse, error) {
post, err := vs.postRepo.GetByID(req.PostID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("post not found")
}
return nil, fmt.Errorf("failed to get post: %w", err)
}
isUnauthenticated := req.UserID == 0
if req.Type == database.VoteNone {
var existingVote *database.Vote
var err error
if isUnauthenticated {
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
existingVote, err = vs.voteRepo.GetByVoteHash(voteHash)
} else {
existingVote, err = vs.voteRepo.GetByUserAndPost(req.UserID, req.PostID)
}
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return vs.buildVoteResponse(post, database.VoteNone, isUnauthenticated), nil
}
return nil, fmt.Errorf("failed to get existing vote: %w", err)
}
if err := vs.voteRepo.Delete(existingVote.ID); err != nil {
return nil, fmt.Errorf("failed to delete vote: %w", err)
}
} else {
var vote *database.Vote
if isUnauthenticated {
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
vote = &database.Vote{
PostID: req.PostID,
Type: req.Type,
VoteHash: &voteHash,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
} else {
vote = &database.Vote{
UserID: &req.UserID,
PostID: req.PostID,
Type: req.Type,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
}
if err := vs.voteRepo.CreateOrUpdate(vote); err != nil {
return nil, fmt.Errorf("failed to create or update vote: %w", err)
}
}
if err := vs.updatePostVoteCounts(req.PostID); err != nil {
return nil, fmt.Errorf("failed to update post vote counts: %w", err)
}
updatedPost, err := vs.postRepo.GetByID(req.PostID)
if err != nil {
return nil, fmt.Errorf("failed to get updated post: %w", err)
}
return vs.buildVoteResponse(updatedPost, req.Type, isUnauthenticated), nil
}
func (vs *VoteService) GetUserVote(userID uint, postID uint, ipAddress, userAgent string) (*database.Vote, error) {
if userID > 0 {
vote, err := vs.voteRepo.GetByUserAndPost(userID, postID)
if err == nil && vote != nil {
return vote, nil
}
}
voteHash := vs.GenerateVoteHash(ipAddress, userAgent, postID)
vote, err := vs.voteRepo.GetByVoteHash(voteHash)
if err == nil && vote != nil {
return vote, nil
}
return nil, gorm.ErrRecordNotFound
}
func (vs *VoteService) GetPostVotes(postID uint) ([]database.Vote, error) {
votes, err := vs.voteRepo.GetByPostID(postID)
if err != nil {
return nil, err
}
return votes, nil
}
func (vs *VoteService) DeleteVotesByPostID(postID uint) error {
if vs.db != nil {
if err := vs.db.Unscoped().Where("post_id = ?", postID).Delete(&database.Vote{}).Error; err != nil {
return fmt.Errorf("failed to delete votes for post: %w", err)
}
return nil
}
votes, err := vs.voteRepo.GetByPostID(postID)
if err != nil {
return fmt.Errorf("failed to get votes: %w", err)
}
for _, vote := range votes {
if err := vs.voteRepo.Delete(vote.ID); err != nil {
return fmt.Errorf("failed to delete vote %d: %w", vote.ID, err)
}
}
return nil
}
func (vs *VoteService) validateVoteRequest(req VoteRequest) error {
if req.PostID == 0 {
return errors.New("post ID is required")
}
if req.Type != database.VoteUp && req.Type != database.VoteDown && req.Type != database.VoteNone {
return errors.New("invalid vote type")
}
return nil
}
func (vs *VoteService) buildVoteResponse(post *database.Post, voteType database.VoteType, isUnauthenticated bool) *VoteResponse {
message := "Vote updated successfully"
if voteType == database.VoteNone {
message = "Vote removed successfully"
}
return &VoteResponse{
PostID: post.ID,
Type: voteType,
UpVotes: post.UpVotes,
DownVotes: post.DownVotes,
Score: post.Score,
Message: message,
IsUnauthenticated: isUnauthenticated,
}
}
func (vs *VoteService) updatePostVoteCounts(postID uint) error {
if vs.db == nil {
votes, err := vs.voteRepo.GetByPostID(postID)
if err != nil {
return fmt.Errorf("failed to get votes: %w", err)
}
upVotes, downVotes := vs.countVotes(votes)
score := upVotes - downVotes
post, err := vs.postRepo.GetByID(postID)
if err != nil {
return fmt.Errorf("failed to get post: %w", err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = score
return vs.postRepo.Update(post)
}
return vs.updatePostVoteCountsWithTx(vs.db, postID)
}
func (vs *VoteService) updatePostVoteCountsWithTx(tx *gorm.DB, postID uint) error {
txVoteRepo := vs.voteRepo.WithTx(tx)
txPostRepo := vs.postRepo.WithTx(tx)
votes, err := txVoteRepo.GetByPostID(postID)
if err != nil {
return fmt.Errorf("failed to get votes: %w", err)
}
upVotes, downVotes := vs.countVotes(votes)
score := upVotes - downVotes
post, err := txPostRepo.GetByID(postID)
if err != nil {
return fmt.Errorf("failed to get post: %w", err)
}
post.UpVotes = upVotes
post.DownVotes = downVotes
post.Score = score
return txPostRepo.Update(post)
}
func (vs *VoteService) countVotes(votes []database.Vote) (int, int) {
upVotes := 0
downVotes := 0
for _, vote := range votes {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
return upVotes, downVotes
}
func (vs *VoteService) GetVoteStatistics() (int64, int64, error) {
totalCount, err := vs.voteRepo.Count()
if err != nil {
return 0, 0, fmt.Errorf("failed to get vote count: %w", err)
}
return totalCount, 0, nil
}

View File

@@ -0,0 +1,918 @@
package services
import (
"errors"
"fmt"
"strings"
"sync"
"testing"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/repositories"
)
type mockVoteRepo struct {
votes map[uint]*database.Vote
byUserPost map[string]*database.Vote
byVoteHash map[string]*database.Vote
nextID uint
createErr error
getByUserAndPostErr error
getByVoteHashErr error
getByPostIDErr error
updateErr error
deleteErr error
createCalls int
updateCalls int
deleteCalls int
mu sync.RWMutex
}
func newMockVoteRepo() *mockVoteRepo {
return &mockVoteRepo{
votes: make(map[uint]*database.Vote),
byUserPost: make(map[string]*database.Vote),
byVoteHash: make(map[string]*database.Vote),
nextID: 1,
}
}
func (m *mockVoteRepo) Create(vote *database.Vote) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.createErr != nil {
return m.createErr
}
var key string
if vote.UserID != nil {
key = m.key(*vote.UserID, vote.PostID)
} else if vote.VoteHash != nil {
key = *vote.VoteHash
} else {
return errors.New("vote must have either user_id or vote_hash")
}
if existingVote, exists := m.byUserPost[key]; exists {
existingVote.Type = vote.Type
existingVote.UpdatedAt = vote.UpdatedAt
vote.ID = existingVote.ID
return nil
}
vote.ID = m.nextID
m.nextID++
voteCopy := *vote
m.votes[vote.ID] = &voteCopy
m.byUserPost[key] = &voteCopy
if vote.VoteHash != nil {
m.byVoteHash[*vote.VoteHash] = &voteCopy
}
m.createCalls++
return nil
}
func (m *mockVoteRepo) CreateOrUpdate(vote *database.Vote) error {
return m.Create(vote)
}
func (m *mockVoteRepo) GetByID(id uint) (*database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if vote, ok := m.votes[id]; ok {
voteCopy := *vote
return &voteCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *mockVoteRepo) GetByUserAndPost(userID, postID uint) (*database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByUserAndPostErr != nil {
return nil, m.getByUserAndPostErr
}
key := m.key(userID, postID)
if vote, ok := m.byUserPost[key]; ok {
voteCopy := *vote
return &voteCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *mockVoteRepo) GetByVoteHash(voteHash string) (*database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByVoteHashErr != nil {
return nil, m.getByVoteHashErr
}
if vote, ok := m.byVoteHash[voteHash]; ok {
voteCopy := *vote
return &voteCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *mockVoteRepo) GetByPostID(postID uint) ([]database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getByPostIDErr != nil {
return nil, m.getByPostIDErr
}
var votes []database.Vote
for _, vote := range m.votes {
if vote.PostID == postID {
votes = append(votes, *vote)
}
}
return votes, nil
}
func (m *mockVoteRepo) GetByUserID(userID uint) ([]database.Vote, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var votes []database.Vote
for _, vote := range m.votes {
if vote.UserID != nil && *vote.UserID == userID {
votes = append(votes, *vote)
}
}
return votes, nil
}
func (m *mockVoteRepo) Update(vote *database.Vote) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.updateErr != nil {
return m.updateErr
}
if _, ok := m.votes[vote.ID]; !ok {
return gorm.ErrRecordNotFound
}
voteCopy := *vote
m.votes[vote.ID] = &voteCopy
var key string
if vote.UserID != nil {
key = m.key(*vote.UserID, vote.PostID)
m.byUserPost[key] = &voteCopy
}
if vote.VoteHash != nil {
m.byVoteHash[*vote.VoteHash] = &voteCopy
}
m.updateCalls++
return nil
}
func (m *mockVoteRepo) Delete(id uint) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.deleteErr != nil {
return m.deleteErr
}
vote, ok := m.votes[id]
if !ok {
return gorm.ErrRecordNotFound
}
delete(m.votes, id)
var key string
if vote.UserID != nil {
key = m.key(*vote.UserID, vote.PostID)
delete(m.byUserPost, key)
}
if vote.VoteHash != nil {
delete(m.byVoteHash, *vote.VoteHash)
}
m.deleteCalls++
return nil
}
func (m *mockVoteRepo) Count() (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return int64(len(m.votes)), nil
}
func (m *mockVoteRepo) CountByPostID(postID uint) (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
count := int64(0)
for _, vote := range m.votes {
if vote.PostID == postID {
count++
}
}
return count, nil
}
func (m *mockVoteRepo) CountByUserID(userID uint) (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
count := int64(0)
for _, vote := range m.votes {
if vote.UserID != nil && *vote.UserID == userID {
count++
}
}
return count, nil
}
func (m *mockVoteRepo) WithTx(tx *gorm.DB) repositories.VoteRepository {
return m
}
func (m *mockVoteRepo) key(userID, postID uint) string {
return fmt.Sprintf("%d:%d", userID, postID)
}
type mockPostRepo struct {
posts map[uint]*database.Post
nextID uint
getErr error
updateErr error
mu sync.RWMutex
}
func newMockPostRepo() *mockPostRepo {
return &mockPostRepo{
posts: make(map[uint]*database.Post),
nextID: 1,
}
}
func (m *mockPostRepo) GetByID(id uint) (*database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.getErr != nil {
return nil, m.getErr
}
if post, ok := m.posts[id]; ok {
postCopy := *post
return &postCopy, nil
}
return nil, gorm.ErrRecordNotFound
}
func (m *mockPostRepo) Update(post *database.Post) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.updateErr != nil {
return m.updateErr
}
if _, ok := m.posts[post.ID]; !ok {
return gorm.ErrRecordNotFound
}
postCopy := *post
m.posts[post.ID] = &postCopy
return nil
}
func (m *mockPostRepo) Create(post *database.Post) error {
m.mu.Lock()
defer m.mu.Unlock()
post.ID = m.nextID
m.nextID++
postCopy := *post
m.posts[post.ID] = &postCopy
return nil
}
func (m *mockPostRepo) GetByURL(url string) (*database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, post := range m.posts {
if post.URL == url {
postCopy := *post
return &postCopy, nil
}
}
return nil, gorm.ErrRecordNotFound
}
func (m *mockPostRepo) GetByAuthorID(authorID uint) ([]database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
for _, post := range m.posts {
if post.AuthorID != nil && *post.AuthorID == authorID {
posts = append(posts, *post)
}
}
return posts, nil
}
func (m *mockPostRepo) GetAll(limit, offset int) ([]database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if count >= offset && count < offset+limit {
posts = append(posts, *post)
}
count++
}
return posts, nil
}
func (m *mockPostRepo) Count() (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
return int64(len(m.posts)), nil
}
func (m *mockPostRepo) Delete(id uint) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.posts[id]; !ok {
return gorm.ErrRecordNotFound
}
delete(m.posts, id)
return nil
}
func (m *mockPostRepo) GetByUserID(userID uint, limit, offset int) ([]database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if post.AuthorID != nil && *post.AuthorID == userID {
if count >= offset && count < offset+limit {
posts = append(posts, *post)
}
count++
}
}
return posts, nil
}
func (m *mockPostRepo) CountByUserID(userID uint) (int64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
count := int64(0)
for _, post := range m.posts {
if post.AuthorID != nil && *post.AuthorID == userID {
count++
}
}
return count, nil
}
func (m *mockPostRepo) GetTopPosts(limit int) ([]database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if count < limit {
posts = append(posts, *post)
count++
}
}
return posts, nil
}
func (m *mockPostRepo) GetNewestPosts(limit int) ([]database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if count < limit {
posts = append(posts, *post)
count++
}
}
return posts, nil
}
func (m *mockPostRepo) Search(query string, limit, offset int) ([]database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
count := 0
for _, post := range m.posts {
if strings.Contains(strings.ToLower(post.Title), strings.ToLower(query)) {
if count >= offset && count < offset+limit {
posts = append(posts, *post)
}
count++
}
}
return posts, nil
}
func (m *mockPostRepo) GetPostsByDeletedUsers() ([]database.Post, error) {
m.mu.RLock()
defer m.mu.RUnlock()
var posts []database.Post
for _, post := range m.posts {
if post.AuthorID == nil {
posts = append(posts, *post)
}
}
return posts, nil
}
func (m *mockPostRepo) HardDeletePostsByDeletedUsers() (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
count := int64(0)
for id, post := range m.posts {
if post.AuthorID == nil {
delete(m.posts, id)
count++
}
}
return count, nil
}
func (m *mockPostRepo) HardDeleteAll() (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
count := int64(len(m.posts))
m.posts = make(map[uint]*database.Post)
return count, nil
}
func (m *mockPostRepo) WithTx(tx *gorm.DB) repositories.PostRepository {
return m
}
func TestVoteService_CastVote_Authenticated(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
post := &database.Post{
ID: 1,
Title: "Test Post",
URL: "https://example.com",
AuthorID: &[]uint{1}[0],
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
postRepo.posts[1] = post
req := VoteRequest{
UserID: 1,
PostID: 1,
Type: database.VoteUp,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
result, err := service.CastVote(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result == nil {
t.Fatal("Expected result, got nil")
}
if result.Type != database.VoteUp {
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
}
if result.UpVotes != 1 {
t.Errorf("Expected up votes to be 1, got %d", result.UpVotes)
}
if result.DownVotes != 0 {
t.Errorf("Expected down votes to be 0, got %d", result.DownVotes)
}
if result.Score != 1 {
t.Errorf("Expected score to be 1, got %d", result.Score)
}
if result.IsUnauthenticated {
t.Error("Expected IsUnauthenticated to be false for authenticated vote")
}
}
func TestVoteService_CastVote_Unauthenticated(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
post := &database.Post{
ID: 1,
Title: "Test Post",
URL: "https://example.com",
AuthorID: &[]uint{1}[0],
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
postRepo.posts[1] = post
req := VoteRequest{
UserID: 0,
PostID: 1,
Type: database.VoteUp,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
result, err := service.CastVote(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result == nil {
t.Fatal("Expected result, got nil")
}
if result.Type != database.VoteUp {
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
}
if result.UpVotes != 1 {
t.Errorf("Expected up votes to be 1, got %d", result.UpVotes)
}
if result.DownVotes != 0 {
t.Errorf("Expected down votes to be 0, got %d", result.DownVotes)
}
if result.Score != 1 {
t.Errorf("Expected score to be 1, got %d", result.Score)
}
if !result.IsUnauthenticated {
t.Error("Expected IsUnauthenticated to be true for unauthenticated vote")
}
}
func TestVoteService_CastVote_UpdateExisting(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
post := &database.Post{
ID: 1,
Title: "Test Post",
URL: "https://example.com",
AuthorID: &[]uint{1}[0],
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
postRepo.posts[1] = post
req := VoteRequest{
UserID: 1,
PostID: 1,
Type: database.VoteUp,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
_, err := service.CastVote(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
req.Type = database.VoteDown
result, err := service.CastVote(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result.UpVotes != 0 {
t.Errorf("Expected up votes to be 0, got %d", result.UpVotes)
}
if result.DownVotes != 1 {
t.Errorf("Expected down votes to be 1, got %d", result.DownVotes)
}
if result.Score != -1 {
t.Errorf("Expected score to be -1, got %d", result.Score)
}
}
func TestVoteService_CastVote_RemoveVote(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
post := &database.Post{
ID: 1,
Title: "Test Post",
URL: "https://example.com",
AuthorID: &[]uint{1}[0],
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
postRepo.posts[1] = post
req := VoteRequest{
UserID: 1,
PostID: 1,
Type: database.VoteUp,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
_, err := service.CastVote(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
req.Type = database.VoteNone
result, err := service.CastVote(req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result.UpVotes != 0 {
t.Errorf("Expected up votes to be 0, got %d", result.UpVotes)
}
if result.DownVotes != 0 {
t.Errorf("Expected down votes to be 0, got %d", result.DownVotes)
}
if result.Score != 0 {
t.Errorf("Expected score to be 0, got %d", result.Score)
}
}
func TestVoteService_GetUserVote_Authenticated(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
userID := uint(1)
vote := &database.Vote{
ID: 1,
UserID: &userID,
PostID: 1,
Type: database.VoteUp,
}
voteRepo.votes[1] = vote
voteRepo.byUserPost["1:1"] = vote
result, err := service.GetUserVote(1, 1, "127.0.0.1", "test-agent")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result == nil {
t.Fatal("Expected vote, got nil")
}
if result.Type != database.VoteUp {
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
}
if result.UserID == nil || *result.UserID != 1 {
t.Errorf("Expected user ID 1, got %v", result.UserID)
}
}
func TestVoteService_GetUserVote_Unauthenticated(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
voteHash := service.GenerateVoteHash("127.0.0.1", "test-agent", 1)
vote := &database.Vote{
ID: 1,
UserID: nil,
PostID: 1,
Type: database.VoteUp,
VoteHash: &voteHash,
}
voteRepo.votes[1] = vote
voteRepo.byVoteHash[voteHash] = vote
result, err := service.GetUserVote(0, 1, "127.0.0.1", "test-agent")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result == nil {
t.Fatal("Expected vote, got nil")
}
if result.Type != database.VoteUp {
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
}
if result.UserID != nil {
t.Error("Expected UserID to be nil for unauthenticated vote")
}
if result.VoteHash == nil || *result.VoteHash != voteHash {
t.Errorf("Expected vote hash '%s', got %v", voteHash, result.VoteHash)
}
}
func TestVoteService_GetPostVotes(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
userID1 := uint(1)
userID2 := uint(2)
voteHash := "test-hash"
vote1 := &database.Vote{
ID: 1,
UserID: &userID1,
PostID: 1,
Type: database.VoteUp,
}
vote2 := &database.Vote{
ID: 2,
UserID: &userID2,
PostID: 1,
Type: database.VoteDown,
}
vote3 := &database.Vote{
ID: 3,
UserID: nil,
PostID: 1,
Type: database.VoteUp,
VoteHash: &voteHash,
}
voteRepo.votes[1] = vote1
voteRepo.votes[2] = vote2
voteRepo.votes[3] = vote3
votes, err := service.GetPostVotes(1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(votes) != 3 {
t.Errorf("Expected 3 votes, got %d", len(votes))
}
hasAuthenticated := false
hasUnauthenticated := false
for _, vote := range votes {
if vote.UserID != nil {
hasAuthenticated = true
}
if vote.VoteHash != nil {
hasUnauthenticated = true
}
}
if !hasAuthenticated {
t.Error("Expected to find authenticated votes")
}
if !hasUnauthenticated {
t.Error("Expected to find unauthenticated votes")
}
}
func TestVoteService_GetVoteStatistics(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
userID1 := uint(1)
userID2 := uint(2)
voteHash := "test-hash"
vote1 := &database.Vote{
ID: 1,
UserID: &userID1,
PostID: 1,
Type: database.VoteUp,
}
vote2 := &database.Vote{
ID: 2,
UserID: &userID2,
PostID: 1,
Type: database.VoteDown,
}
vote3 := &database.Vote{
ID: 3,
UserID: nil,
PostID: 1,
Type: database.VoteUp,
VoteHash: &voteHash,
}
voteRepo.votes[1] = vote1
voteRepo.votes[2] = vote2
voteRepo.votes[3] = vote3
authenticatedCount, anonymousCount, err := service.GetVoteStatistics()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if authenticatedCount != 3 {
t.Errorf("Expected total count to be 3, got %d", authenticatedCount)
}
if anonymousCount != 0 {
t.Errorf("Expected unauthenticated count to be 0, got %d", anonymousCount)
}
}
func TestVoteService_Validation(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
req := VoteRequest{
UserID: 1,
PostID: 0,
Type: database.VoteUp,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
_, err := service.CastVote(req)
if err == nil {
t.Error("Expected error for missing post ID")
}
req.PostID = 1
req.Type = "invalid"
_, err = service.CastVote(req)
if err == nil {
t.Error("Expected error for invalid vote type")
}
}
func TestVoteService_PostNotFound(t *testing.T) {
voteRepo := newMockVoteRepo()
postRepo := newMockPostRepo()
service := NewVoteService(voteRepo, postRepo, nil)
req := VoteRequest{
UserID: 1,
PostID: 999,
Type: database.VoteUp,
IPAddress: "127.0.0.1",
UserAgent: "test-agent",
}
_, err := service.CastVote(req)
if err == nil {
t.Error("Expected error for non-existent post")
}
if !strings.Contains(err.Error(), "post not found") {
t.Errorf("Expected 'post not found' error, got %v", err)
}
}