To gitea and beyond, let's go(-yco)
This commit is contained in:
176
internal/services/account_deletion_service.go
Normal file
176
internal/services/account_deletion_service.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AccountDeletionService struct {
|
||||
userRepo repositories.UserRepository
|
||||
postRepo repositories.PostRepository
|
||||
deletionRepo repositories.AccountDeletionRepository
|
||||
emailService *EmailService
|
||||
}
|
||||
|
||||
func NewAccountDeletionService(userRepo repositories.UserRepository, postRepo repositories.PostRepository, deletionRepo repositories.AccountDeletionRepository, emailService *EmailService) *AccountDeletionService {
|
||||
return &AccountDeletionService{
|
||||
userRepo: userRepo,
|
||||
postRepo: postRepo,
|
||||
deletionRepo: deletionRepo,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountDeletionService) GetUserIDFromDeletionToken(token string) (uint, error) {
|
||||
trimmed := TrimString(token)
|
||||
if trimmed == "" {
|
||||
return 0, ErrInvalidDeletionToken
|
||||
}
|
||||
|
||||
if s.deletionRepo == nil {
|
||||
return 0, fmt.Errorf("account deletion repository not configured")
|
||||
}
|
||||
|
||||
hashed := HashVerificationToken(trimmed)
|
||||
req, err := s.deletionRepo.GetByTokenHash(hashed)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return 0, ErrInvalidDeletionToken
|
||||
}
|
||||
return 0, fmt.Errorf("lookup deletion request: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(req.ExpiresAt) {
|
||||
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
|
||||
log.Printf("Failed to delete expired deletion request %d: %v", req.ID, delErr)
|
||||
}
|
||||
return 0, ErrInvalidDeletionToken
|
||||
}
|
||||
|
||||
return req.UserID, nil
|
||||
}
|
||||
|
||||
func (s *AccountDeletionService) RequestAccountDeletion(userID uint) error {
|
||||
if userID == 0 {
|
||||
return fmt.Errorf("invalid user identifier")
|
||||
}
|
||||
|
||||
if s.deletionRepo == nil {
|
||||
return fmt.Errorf("account deletion repository not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return fmt.Errorf("load user: %w", err)
|
||||
}
|
||||
|
||||
if err := s.deletionRepo.DeleteByUserID(userID); err != nil {
|
||||
return fmt.Errorf("clear existing deletion requests: %w", err)
|
||||
}
|
||||
|
||||
token, hash, err := generateVerificationToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: userID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: time.Now().Add(time.Duration(deletionTokenExpirationHours) * time.Hour),
|
||||
}
|
||||
|
||||
if err := s.deletionRepo.Create(req); err != nil {
|
||||
return fmt.Errorf("create deletion request: %w", err)
|
||||
}
|
||||
|
||||
if err := s.emailService.SendAccountDeletionEmail(user, token); err != nil {
|
||||
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
|
||||
log.Printf("Failed to cleanup deletion request %d after email failure: %v", req.ID, delErr)
|
||||
}
|
||||
return fmt.Errorf("send deletion confirmation email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountDeletionService) validateAndGetDeletionRequest(token string) (*database.AccountDeletionRequest, *database.User, error) {
|
||||
trimmed := TrimString(token)
|
||||
if trimmed == "" {
|
||||
return nil, nil, ErrInvalidDeletionToken
|
||||
}
|
||||
|
||||
if s.deletionRepo == nil {
|
||||
return nil, nil, fmt.Errorf("account deletion repository not configured")
|
||||
}
|
||||
|
||||
hashed := HashVerificationToken(trimmed)
|
||||
req, err := s.deletionRepo.GetByTokenHash(hashed)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return nil, nil, ErrInvalidDeletionToken
|
||||
}
|
||||
return nil, nil, fmt.Errorf("lookup deletion request: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(req.ExpiresAt) {
|
||||
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
|
||||
log.Printf("Failed to delete expired deletion request %d: %v", req.ID, delErr)
|
||||
}
|
||||
return nil, nil, ErrInvalidDeletionToken
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(req.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if delErr := s.deletionRepo.DeleteByID(req.ID); delErr != nil {
|
||||
log.Printf("Failed to delete orphaned deletion request %d: %v", req.ID, delErr)
|
||||
}
|
||||
return nil, nil, ErrInvalidDeletionToken
|
||||
}
|
||||
return nil, nil, fmt.Errorf("load user: %w", err)
|
||||
}
|
||||
|
||||
return req, user, nil
|
||||
}
|
||||
|
||||
func (s *AccountDeletionService) ConfirmAccountDeletion(token string) error {
|
||||
return s.ConfirmAccountDeletionWithPosts(token, true)
|
||||
}
|
||||
|
||||
func (s *AccountDeletionService) ConfirmAccountDeletionWithPosts(token string, deletePosts bool) error {
|
||||
req, user, err := s.validateAndGetDeletionRequest(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !deletePosts {
|
||||
if err := s.userRepo.SoftDeleteWithPosts(user.ID); err != nil {
|
||||
return fmt.Errorf("soft delete user with posts: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := s.userRepo.HardDelete(user.ID); err != nil {
|
||||
return fmt.Errorf("delete user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.deletionRepo.DeleteByID(req.ID); err != nil {
|
||||
return fmt.Errorf("clear deletion request: %w", err)
|
||||
}
|
||||
|
||||
if err := s.emailService.SendAccountDeletionNotificationEmail(user, deletePosts); err != nil {
|
||||
log.Printf("Failed to send account deletion notification email for user %d: %v", user.ID, err)
|
||||
return fmt.Errorf("%w: %w", ErrDeletionEmailFailed, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
529
internal/services/account_deletion_service_test.go
Normal file
529
internal/services/account_deletion_service_test.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/testutils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type errorEmailSender struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *errorEmailSender) Send(to, subject, body string) error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type mockAccountDeletionRepository struct {
|
||||
requests map[uint]*database.AccountDeletionRequest
|
||||
requestsByTokenHash map[string]*database.AccountDeletionRequest
|
||||
nextID uint
|
||||
createErr error
|
||||
getByTokenHashErr error
|
||||
deleteByIDErr error
|
||||
deleteByUserIDErr error
|
||||
}
|
||||
|
||||
func newMockAccountDeletionRepository() *mockAccountDeletionRepository {
|
||||
return &mockAccountDeletionRepository{
|
||||
requests: make(map[uint]*database.AccountDeletionRequest),
|
||||
requestsByTokenHash: make(map[string]*database.AccountDeletionRequest),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAccountDeletionRepository) Create(req *database.AccountDeletionRequest) error {
|
||||
if m.createErr != nil {
|
||||
return m.createErr
|
||||
}
|
||||
|
||||
req.ID = m.nextID
|
||||
m.nextID++
|
||||
|
||||
reqCopy := *req
|
||||
m.requests[req.ID] = &reqCopy
|
||||
m.requestsByTokenHash[req.TokenHash] = &reqCopy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAccountDeletionRepository) GetByTokenHash(hash string) (*database.AccountDeletionRequest, error) {
|
||||
if m.getByTokenHashErr != nil {
|
||||
return nil, m.getByTokenHashErr
|
||||
}
|
||||
|
||||
if req, ok := m.requestsByTokenHash[hash]; ok {
|
||||
reqCopy := *req
|
||||
return &reqCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *mockAccountDeletionRepository) DeleteByID(id uint) error {
|
||||
if m.deleteByIDErr != nil {
|
||||
return m.deleteByIDErr
|
||||
}
|
||||
|
||||
if req, ok := m.requests[id]; ok {
|
||||
delete(m.requests, id)
|
||||
delete(m.requestsByTokenHash, req.TokenHash)
|
||||
return nil
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *mockAccountDeletionRepository) DeleteByUserID(userID uint) error {
|
||||
if m.deleteByUserIDErr != nil {
|
||||
return m.deleteByUserIDErr
|
||||
}
|
||||
|
||||
for id, req := range m.requests {
|
||||
if req.UserID == userID {
|
||||
delete(m.requests, id)
|
||||
delete(m.requestsByTokenHash, req.TokenHash)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewAccountDeletionService(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
|
||||
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("expected service to be created")
|
||||
}
|
||||
|
||||
if service.userRepo != userRepo {
|
||||
t.Error("expected userRepo to be set")
|
||||
}
|
||||
|
||||
if service.postRepo != postRepo {
|
||||
t.Error("expected postRepo to be set")
|
||||
}
|
||||
|
||||
if service.deletionRepo != deletionRepo {
|
||||
t.Error("expected deletionRepo to be set")
|
||||
}
|
||||
|
||||
if service.emailService != emailService {
|
||||
t.Error("expected emailService to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeletionService_GetUserIDFromDeletionToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
setupRepo func() *mockAccountDeletionRepository
|
||||
expectedID uint
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
token: "valid-token",
|
||||
setupRepo: func() *mockAccountDeletionRepository {
|
||||
repo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 1,
|
||||
TokenHash: HashVerificationToken("valid-token"),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
repo.Create(req)
|
||||
return repo
|
||||
},
|
||||
expectedID: 1,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
setupRepo: func() *mockAccountDeletionRepository { return newMockAccountDeletionRepository() },
|
||||
expectedID: 0,
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "whitespace only token",
|
||||
token: " ",
|
||||
setupRepo: func() *mockAccountDeletionRepository { return newMockAccountDeletionRepository() },
|
||||
expectedID: 0,
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
token: "invalid-token",
|
||||
setupRepo: func() *mockAccountDeletionRepository {
|
||||
return newMockAccountDeletionRepository()
|
||||
},
|
||||
expectedID: 0,
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
token: "expired-token",
|
||||
setupRepo: func() *mockAccountDeletionRepository {
|
||||
repo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 1,
|
||||
TokenHash: HashVerificationToken("expired-token"),
|
||||
ExpiresAt: time.Now().Add(-time.Hour),
|
||||
}
|
||||
repo.Create(req)
|
||||
return repo
|
||||
},
|
||||
expectedID: 0,
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "repository error",
|
||||
token: "valid-token",
|
||||
setupRepo: func() *mockAccountDeletionRepository {
|
||||
repo := newMockAccountDeletionRepository()
|
||||
repo.getByTokenHashErr = errors.New("database error")
|
||||
return repo
|
||||
},
|
||||
expectedID: 0,
|
||||
expectedError: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
deletionRepo := tt.setupRepo()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
|
||||
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
|
||||
|
||||
userID, err := service.GetUserIDFromDeletionToken(tt.token)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
} else {
|
||||
if tt.name == "repository error" || tt.name == "nil repository" {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if userID != tt.expectedID {
|
||||
t.Errorf("expected userID %d, got %d", tt.expectedID, userID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeletionService_RequestAccountDeletion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
setupMocks func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender)
|
||||
expectedError bool
|
||||
checkToken bool
|
||||
}{
|
||||
{
|
||||
name: "successful request",
|
||||
userID: 1,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
|
||||
userRepo.Create(user)
|
||||
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
return userRepo, deletionRepo, emailSender
|
||||
},
|
||||
expectedError: false,
|
||||
checkToken: true,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: 0,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
checkToken: false,
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
userID: 999,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
checkToken: false,
|
||||
},
|
||||
{
|
||||
name: "email service error",
|
||||
userID: 1,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
|
||||
userRepo.Create(user)
|
||||
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
var errorSender errorEmailSender
|
||||
errorSender.err = errors.New("email service error")
|
||||
emailSender := &errorSender
|
||||
|
||||
return userRepo, deletionRepo, emailSender
|
||||
},
|
||||
expectedError: true,
|
||||
checkToken: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, deletionRepo, emailSender := tt.setupMocks()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
|
||||
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
|
||||
|
||||
err := service.RequestAccountDeletion(tt.userID)
|
||||
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkToken {
|
||||
if len(deletionRepo.requests) == 0 {
|
||||
t.Error("expected deletion request to be created")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeletionService_ConfirmAccountDeletion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
setupMocks func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender)
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "successful deletion",
|
||||
token: "valid-token",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
|
||||
userRepo.Create(user)
|
||||
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 1,
|
||||
TokenHash: HashVerificationToken("valid-token"),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
deletionRepo.Create(req)
|
||||
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
return userRepo, deletionRepo, emailSender
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "token not found",
|
||||
token: "invalid-token",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
token: "expired-token",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 1,
|
||||
TokenHash: HashVerificationToken("expired-token"),
|
||||
ExpiresAt: time.Now().Add(-time.Hour),
|
||||
}
|
||||
deletionRepo.Create(req)
|
||||
|
||||
return testutils.NewMockUserRepository(), deletionRepo, &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
token: "valid-token",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 999,
|
||||
TokenHash: HashVerificationToken("valid-token"),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
deletionRepo.Create(req)
|
||||
|
||||
return testutils.NewMockUserRepository(), deletionRepo, &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, deletionRepo, emailSender := tt.setupMocks()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
|
||||
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
|
||||
|
||||
err := service.ConfirmAccountDeletion(tt.token)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeletionService_ConfirmAccountDeletionWithPosts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
deletePosts bool
|
||||
setupMocks func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender)
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "successful deletion without posts",
|
||||
token: "valid-token",
|
||||
deletePosts: false,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
|
||||
userRepo.Create(user)
|
||||
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 1,
|
||||
TokenHash: HashVerificationToken("valid-token"),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
deletionRepo.Create(req)
|
||||
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
return userRepo, deletionRepo, emailSender
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "successful deletion with posts",
|
||||
token: "valid-token",
|
||||
deletePosts: true,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{ID: 1, Username: "testuser", Email: "test@example.com"}
|
||||
userRepo.Create(user)
|
||||
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 1,
|
||||
TokenHash: HashVerificationToken("valid-token"),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
deletionRepo.Create(req)
|
||||
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
return userRepo, deletionRepo, emailSender
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
deletePosts: false,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), newMockAccountDeletionRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
token: "expired-token",
|
||||
deletePosts: false,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *mockAccountDeletionRepository, EmailSender) {
|
||||
deletionRepo := newMockAccountDeletionRepository()
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: 1,
|
||||
TokenHash: HashVerificationToken("expired-token"),
|
||||
ExpiresAt: time.Now().Add(-time.Hour),
|
||||
}
|
||||
deletionRepo.Create(req)
|
||||
|
||||
return testutils.NewMockUserRepository(), deletionRepo, &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: ErrInvalidDeletionToken,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, deletionRepo, emailSender := tt.setupMocks()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
|
||||
service := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
|
||||
|
||||
err := service.ConfirmAccountDeletionWithPosts(tt.token, tt.deletePosts)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountDeletionService_UserHasPosts(t *testing.T) {
|
||||
}
|
||||
139
internal/services/auth_facade.go
Normal file
139
internal/services/auth_facade.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
type AuthFacade struct {
|
||||
registrationService *RegistrationService
|
||||
passwordResetService *PasswordResetService
|
||||
deletionService *AccountDeletionService
|
||||
sessionService *SessionService
|
||||
userManagementService *UserManagementService
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewAuthFacade(
|
||||
registrationService *RegistrationService,
|
||||
passwordResetService *PasswordResetService,
|
||||
deletionService *AccountDeletionService,
|
||||
sessionService *SessionService,
|
||||
userManagementService *UserManagementService,
|
||||
config *config.Config,
|
||||
) *AuthFacade {
|
||||
return &AuthFacade{
|
||||
registrationService: registrationService,
|
||||
passwordResetService: passwordResetService,
|
||||
deletionService: deletionService,
|
||||
sessionService: sessionService,
|
||||
userManagementService: userManagementService,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *AuthFacade) Register(username, email, password string) (*RegistrationResult, error) {
|
||||
return f.registrationService.Register(username, email, password)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) Login(username, password string) (*AuthResult, error) {
|
||||
return f.sessionService.Login(username, password)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) VerifyToken(tokenString string) (uint, error) {
|
||||
return f.sessionService.VerifyToken(tokenString)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) ConfirmEmail(token string) (*database.User, error) {
|
||||
return f.registrationService.ConfirmEmail(token)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) ResendVerificationEmail(email string) error {
|
||||
return f.registrationService.ResendVerificationEmail(email)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) RequestPasswordReset(usernameOrEmail string) error {
|
||||
return f.passwordResetService.RequestPasswordReset(usernameOrEmail)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) ResetPassword(token, newPassword string) error {
|
||||
user, err := f.passwordResetService.GetUserByResetToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f.passwordResetService.ResetPassword(token, newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f.sessionService.InvalidateAllSessions(user.ID); err != nil {
|
||||
return fmt.Errorf("invalidate sessions: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *AuthFacade) UpdateUsername(userID uint, username string) (*database.User, error) {
|
||||
return f.userManagementService.UpdateUsername(userID, username)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) UpdateEmail(userID uint, email string) (*database.User, error) {
|
||||
return f.userManagementService.UpdateEmail(userID, email)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) UpdatePassword(userID uint, currentPassword, newPassword string) (*database.User, error) {
|
||||
user, err := f.userManagementService.UpdatePassword(userID, currentPassword, newPassword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := f.sessionService.InvalidateAllSessions(userID); err != nil {
|
||||
return nil, fmt.Errorf("invalidate sessions: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (f *AuthFacade) RequestAccountDeletion(userID uint) error {
|
||||
return f.deletionService.RequestAccountDeletion(userID)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) ConfirmAccountDeletion(token string) error {
|
||||
return f.deletionService.ConfirmAccountDeletion(token)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) ConfirmAccountDeletionWithPosts(token string, deletePosts bool) error {
|
||||
return f.deletionService.ConfirmAccountDeletionWithPosts(token, deletePosts)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) GetUserIDFromDeletionToken(token string) (uint, error) {
|
||||
return f.deletionService.GetUserIDFromDeletionToken(token)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
|
||||
return f.sessionService.RefreshAccessToken(refreshToken)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) RevokeRefreshToken(refreshToken string) error {
|
||||
return f.sessionService.RevokeRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) RevokeAllUserTokens(userID uint) error {
|
||||
return f.sessionService.RevokeAllUserTokens(userID)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) InvalidateAllSessions(userID uint) error {
|
||||
return f.sessionService.InvalidateAllSessions(userID)
|
||||
}
|
||||
|
||||
func (f *AuthFacade) CleanupExpiredTokens() error {
|
||||
return f.sessionService.CleanupExpiredTokens()
|
||||
}
|
||||
|
||||
func (f *AuthFacade) GetAdminEmail() string {
|
||||
return f.config.App.AdminEmail
|
||||
}
|
||||
|
||||
func (f *AuthFacade) UserHasPosts(userID uint) (bool, int64, error) {
|
||||
return f.userManagementService.UserHasPosts(userID)
|
||||
}
|
||||
672
internal/services/auth_service_test.go
Normal file
672
internal/services/auth_service_test.go
Normal file
@@ -0,0 +1,672 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func testHashVerificationToken(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func setupAuthService(t *testing.T) (*AuthFacade, *testutils.ServiceSuite) {
|
||||
t.Helper()
|
||||
suite := testutils.NewServiceSuite(t)
|
||||
authService, err := NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auth service: %v", err)
|
||||
}
|
||||
return authService, suite
|
||||
}
|
||||
|
||||
func TestAuthService_Unit_Register(t *testing.T) {
|
||||
t.Run("Successful_Registration", func(t *testing.T) {
|
||||
authService, _ := setupAuthService(t)
|
||||
|
||||
result, err := authService.Register("testuser", "test@example.com", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful registration, got error: %v", err)
|
||||
}
|
||||
|
||||
if result.User.Username != "testuser" {
|
||||
t.Errorf("Expected username 'testuser', got '%s'", result.User.Username)
|
||||
}
|
||||
|
||||
if result.User.Email != "test@example.com" {
|
||||
t.Errorf("Expected email 'test@example.com', got '%s'", result.User.Email)
|
||||
}
|
||||
|
||||
if result.User.EmailVerified {
|
||||
t.Error("Expected email to be unverified initially")
|
||||
}
|
||||
|
||||
if !result.VerificationSent {
|
||||
t.Error("Expected verification to be sent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Duplicate_Username", func(t *testing.T) {
|
||||
authService, suite := setupAuthService(t)
|
||||
|
||||
existingUser := &database.User{
|
||||
Username: "existinguser",
|
||||
Email: "existing@example.com",
|
||||
Password: "hashed",
|
||||
EmailVerified: true,
|
||||
}
|
||||
if err := suite.UserRepo.Create(existingUser); err != nil {
|
||||
t.Fatalf("Failed to create existing user: %v", err)
|
||||
}
|
||||
|
||||
_, err := authService.Register("existinguser", "test@example.com", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate username")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "username already exists") {
|
||||
t.Errorf("Expected username conflict error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Duplicate_Email", func(t *testing.T) {
|
||||
authService, suite := setupAuthService(t)
|
||||
|
||||
existingUser := &database.User{
|
||||
Username: "existinguser",
|
||||
Email: "existing@example.com",
|
||||
Password: "hashed",
|
||||
EmailVerified: true,
|
||||
}
|
||||
if err := suite.UserRepo.Create(existingUser); err != nil {
|
||||
t.Fatalf("Failed to create existing user: %v", err)
|
||||
}
|
||||
|
||||
_, err := authService.Register("newuser", "existing@example.com", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate email")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "email already exists") {
|
||||
t.Errorf("Expected email conflict error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Weak_Password", func(t *testing.T) {
|
||||
authService, _ := setupAuthService(t)
|
||||
|
||||
_, err := authService.Register("testuser", "test@example.com", "123")
|
||||
if err == nil {
|
||||
t.Error("Expected error for weak password")
|
||||
}
|
||||
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "password") {
|
||||
t.Errorf("Expected password validation error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid_Email", func(t *testing.T) {
|
||||
authService, _ := setupAuthService(t)
|
||||
|
||||
_, err := authService.Register("testuser", "invalid-email", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid email")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "email") {
|
||||
t.Errorf("Expected email validation error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService_Unit_Login(t *testing.T) {
|
||||
t.Run("Successful_Login", func(t *testing.T) {
|
||||
authService, suite := setupAuthService(t)
|
||||
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("SecurePass123!"), bcrypt.DefaultCost)
|
||||
testUser := &database.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: string(hashedPassword),
|
||||
EmailVerified: true,
|
||||
}
|
||||
if err := suite.UserRepo.Create(testUser); err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
result, err := authService.Login("testuser", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful login, got error: %v", err)
|
||||
}
|
||||
|
||||
if result.User.Username != "testuser" {
|
||||
t.Errorf("Expected username 'testuser', got '%s'", result.User.Username)
|
||||
}
|
||||
|
||||
if result.AccessToken == "" {
|
||||
t.Error("Expected access token to be generated")
|
||||
}
|
||||
|
||||
if result.RefreshToken == "" {
|
||||
t.Error("Expected refresh token to be generated")
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(result.AccessToken, func(token *jwt.Token) (any, error) {
|
||||
return []byte(testutils.AppTestConfig.JWT.Secret), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse token: %v", err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
t.Error("Expected valid JWT token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid_Credentials", func(t *testing.T) {
|
||||
authService, _ := setupAuthService(t)
|
||||
|
||||
_, err := authService.Login("nonexistent", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid credentials")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid credentials") {
|
||||
t.Errorf("Expected invalid credentials error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Wrong_Password", func(t *testing.T) {
|
||||
authService, suite := setupAuthService(t)
|
||||
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("correctpassword"), bcrypt.DefaultCost)
|
||||
testUser := &database.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: string(hashedPassword),
|
||||
EmailVerified: true,
|
||||
}
|
||||
if err := suite.UserRepo.Create(testUser); err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
_, err := authService.Login("testuser", "wrongpassword")
|
||||
if err == nil {
|
||||
t.Error("Expected error for wrong password")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid credentials") {
|
||||
t.Errorf("Expected invalid credentials error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unverified_Email", func(t *testing.T) {
|
||||
authService, suite := setupAuthService(t)
|
||||
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("SecurePass123!"), bcrypt.DefaultCost)
|
||||
testUser := &database.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: string(hashedPassword),
|
||||
EmailVerified: false,
|
||||
}
|
||||
if err := suite.UserRepo.Create(testUser); err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
_, err := authService.Login("testuser", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for unverified email")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "email not verified") {
|
||||
t.Errorf("Expected email verification error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService_Unit_ConfirmEmail(t *testing.T) {
|
||||
t.Run("Successful_Email_Confirmation", func(t *testing.T) {
|
||||
authService, suite := setupAuthService(t)
|
||||
|
||||
rawToken := "valid-token"
|
||||
hashedToken := testHashVerificationToken(rawToken)
|
||||
|
||||
testUser := &database.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashed",
|
||||
EmailVerified: false,
|
||||
EmailVerificationToken: hashedToken,
|
||||
}
|
||||
if err := suite.UserRepo.Create(testUser); err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
result, err := authService.ConfirmEmail("valid-token")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful email confirmation, got error: %v", err)
|
||||
}
|
||||
|
||||
if !result.EmailVerified {
|
||||
t.Error("Expected email to be verified")
|
||||
}
|
||||
|
||||
if result.EmailVerificationToken != "" {
|
||||
t.Error("Expected verification token to be cleared")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid_Token", func(t *testing.T) {
|
||||
authService, _ := setupAuthService(t)
|
||||
|
||||
_, err := authService.ConfirmEmail("invalid-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid token")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid verification token") {
|
||||
t.Errorf("Expected invalid verification token error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService_Integration_Complete_Workflow(t *testing.T) {
|
||||
suite := testutils.NewServiceSuite(t)
|
||||
|
||||
authService, err := NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auth service: %v", err)
|
||||
}
|
||||
|
||||
t.Run("Complete_User_Lifecycle", func(t *testing.T) {
|
||||
suite.EmailSender.Reset()
|
||||
registerResult, err := authService.Register("lifecycle_user", "lifecycle@example.com", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register user: %v", err)
|
||||
}
|
||||
|
||||
if registerResult.User.Username != "lifecycle_user" {
|
||||
t.Errorf("Expected username 'lifecycle_user', got '%s'", registerResult.User.Username)
|
||||
}
|
||||
|
||||
if registerResult.User.EmailVerified {
|
||||
t.Error("Expected email to be unverified initially")
|
||||
}
|
||||
|
||||
verificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "lifecycle_user")
|
||||
|
||||
confirmResult, err := authService.ConfirmEmail(verificationToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to confirm email: %v", err)
|
||||
}
|
||||
|
||||
if !confirmResult.EmailVerified {
|
||||
t.Error("Expected email to be verified after confirmation")
|
||||
}
|
||||
|
||||
loginResult, err := authService.Login("lifecycle_user", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to login: %v", err)
|
||||
}
|
||||
|
||||
if loginResult.User.Username != "lifecycle_user" {
|
||||
t.Errorf("Expected username 'lifecycle_user', got '%s'", loginResult.User.Username)
|
||||
}
|
||||
|
||||
if loginResult.AccessToken == "" {
|
||||
t.Error("Expected access token to be generated")
|
||||
}
|
||||
|
||||
if loginResult.RefreshToken == "" {
|
||||
t.Error("Expected refresh token to be generated")
|
||||
}
|
||||
|
||||
updateResult, err := authService.UpdateUsername(loginResult.User.ID, "updated_lifecycle_user")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update username: %v", err)
|
||||
}
|
||||
|
||||
if updateResult.Username != "updated_lifecycle_user" {
|
||||
t.Errorf("Expected updated username, got '%s'", updateResult.Username)
|
||||
}
|
||||
|
||||
suite.EmailSender.Reset()
|
||||
emailResult, err := authService.UpdateEmail(loginResult.User.ID, "updated@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update email: %v", err)
|
||||
}
|
||||
|
||||
if emailResult.Email != "updated@example.com" {
|
||||
t.Errorf("Expected updated email, got '%s'", emailResult.Email)
|
||||
}
|
||||
|
||||
newVerificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "updated_lifecycle_user")
|
||||
|
||||
_, err = authService.ConfirmEmail(newVerificationToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to confirm updated email: %v", err)
|
||||
}
|
||||
|
||||
_, err = authService.UpdatePassword(loginResult.User.ID, "SecurePass123!", "NewSecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update password: %v", err)
|
||||
}
|
||||
|
||||
_, err = authService.Login("updated_lifecycle_user", "NewSecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to login with new password: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Account_Deletion_Workflow", func(t *testing.T) {
|
||||
suite.EmailSender.Reset()
|
||||
|
||||
registerResult, err := authService.Register("deletion_user", "deletion@example.com", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register user: %v", err)
|
||||
}
|
||||
|
||||
verificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "deletion_user")
|
||||
|
||||
_, err = authService.ConfirmEmail(verificationToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to confirm email: %v", err)
|
||||
}
|
||||
|
||||
err = authService.RequestAccountDeletion(registerResult.User.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to request account deletion: %v", err)
|
||||
}
|
||||
|
||||
deletionToken := setupDeletionTokenForTest(t, suite.EmailSender, suite.DeletionRepo, registerResult.User.ID)
|
||||
|
||||
err = authService.ConfirmAccountDeletion(deletionToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to confirm account deletion: %v", err)
|
||||
}
|
||||
|
||||
if err := authService.ConfirmAccountDeletion(deletionToken); !errors.Is(err, ErrInvalidDeletionToken) {
|
||||
t.Fatalf("Expected token reuse to fail with ErrInvalidDeletionToken, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Password_Reset_Workflow", func(t *testing.T) {
|
||||
suite.EmailSender.Reset()
|
||||
|
||||
_, err := authService.Register("reset_user", "reset@example.com", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register user: %v", err)
|
||||
}
|
||||
|
||||
verificationToken := setupVerificationTokenForTest(t, suite.EmailSender, suite.UserRepo, "reset_user")
|
||||
|
||||
_, err = authService.ConfirmEmail(verificationToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to confirm email: %v", err)
|
||||
}
|
||||
|
||||
err = authService.RequestPasswordReset("reset@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to request password reset: %v", err)
|
||||
}
|
||||
|
||||
resetToken := setupPasswordResetTokenForTest(t, suite.EmailSender, suite.UserRepo, "reset@example.com")
|
||||
|
||||
err = authService.ResetPassword(resetToken, "NewSecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to reset password: %v", err)
|
||||
}
|
||||
|
||||
if err := authService.ResetPassword(resetToken, "AnotherPass123!"); err == nil {
|
||||
t.Fatal("expected reset token reuse to fail")
|
||||
}
|
||||
|
||||
_, err = authService.Login("reset_user", "NewSecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to login with new password: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService_Integration_Error_Handling(t *testing.T) {
|
||||
suite := testutils.NewServiceSuite(t)
|
||||
|
||||
authService, err := NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auth service: %v", err)
|
||||
}
|
||||
|
||||
t.Run("Validation_Errors", func(t *testing.T) {
|
||||
_, err := authService.Register("weak_user", "weak@example.com", "123")
|
||||
if err == nil {
|
||||
t.Error("Expected error for weak password")
|
||||
}
|
||||
|
||||
_, err = authService.Register("invalid_user", "not-an-email", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid email")
|
||||
}
|
||||
|
||||
_, err = authService.Register("", "test@example.com", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for empty username")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Duplicate_Constraints", func(t *testing.T) {
|
||||
_, err := authService.Register("duplicate_user", "duplicate1@example.com", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register first user: %v", err)
|
||||
}
|
||||
|
||||
_, err = authService.Register("duplicate_user", "duplicate2@example.com", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate username")
|
||||
}
|
||||
|
||||
_, err = authService.Register("another_user", "duplicate1@example.com", "SecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate email")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Duplicate_LongUsername", func(t *testing.T) {
|
||||
longUsername := strings.Repeat("x", 50)
|
||||
user := &database.User{
|
||||
Username: longUsername,
|
||||
Email: "longuser@example.com",
|
||||
Password: testutils.HashPassword("SecurePass123!"),
|
||||
EmailVerified: true,
|
||||
}
|
||||
if err := suite.UserRepo.Create(user); err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
_, err := authService.Register(longUsername, "longuser2@example.com", "SecurePass123!")
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Fatalf("expected ErrUsernameTaken for duplicate long username, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Authentication_Errors", func(t *testing.T) {
|
||||
_, err := authService.Login("nonexistent", "password")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent user")
|
||||
}
|
||||
|
||||
_, err = authService.Register("auth_user", "auth@example.com", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register user: %v", err)
|
||||
}
|
||||
|
||||
verificationToken := suite.EmailSender.VerificationToken()
|
||||
if verificationToken == "" {
|
||||
t.Fatal("Expected verification token to be generated")
|
||||
}
|
||||
|
||||
_, err = authService.ConfirmEmail(verificationToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to confirm email: %v", err)
|
||||
}
|
||||
|
||||
_, err = authService.Login("auth_user", "wrongpassword")
|
||||
if err == nil {
|
||||
t.Error("Expected error for wrong password")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Pre_Verified_User_Operations", func(t *testing.T) {
|
||||
user := createTestUserWithAuth(authService, suite.EmailSender, suite.UserRepo, "preverified_user", "preverified@example.com")
|
||||
|
||||
if user.Username != "preverified_user" {
|
||||
t.Errorf("Expected username 'preverified_user', got '%s'", user.Username)
|
||||
}
|
||||
|
||||
if user.Email != "preverified@example.com" {
|
||||
t.Errorf("Expected email 'preverified@example.com', got '%s'", user.Email)
|
||||
}
|
||||
|
||||
if !user.EmailVerified {
|
||||
t.Error("Expected user to be email verified")
|
||||
}
|
||||
|
||||
loginResult, err := authService.Login("preverified_user", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful login for pre-verified user, got error: %v", err)
|
||||
}
|
||||
|
||||
if loginResult.User.ID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, loginResult.User.ID)
|
||||
}
|
||||
|
||||
updateResult, err := authService.UpdateUsername(user.ID, "updated_preverified_user")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful username update, got error: %v", err)
|
||||
}
|
||||
|
||||
if updateResult.Username != "updated_preverified_user" {
|
||||
t.Errorf("Expected updated username, got '%s'", updateResult.Username)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func createTestUserWithAuth(authService interface {
|
||||
Register(username, email, password string) (*RegistrationResult, error)
|
||||
ConfirmEmail(token string) (*database.User, error)
|
||||
}, emailSender interface {
|
||||
Reset()
|
||||
VerificationToken() string
|
||||
}, userRepo repositories.UserRepository, username, email string) *database.User {
|
||||
emailSender.Reset()
|
||||
|
||||
_, err := authService.Register(username, email, "SecurePass123!")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to register user: %v", err))
|
||||
}
|
||||
|
||||
verificationToken := emailSender.VerificationToken()
|
||||
if verificationToken == "" {
|
||||
panic("Failed to capture verification token during test setup")
|
||||
}
|
||||
|
||||
hashedToken := testutils.HashVerificationToken(verificationToken)
|
||||
|
||||
user, err := userRepo.GetByUsername(username)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get user: %v", err))
|
||||
}
|
||||
user.EmailVerificationToken = hashedToken
|
||||
if err := userRepo.Update(user); err != nil {
|
||||
panic(fmt.Sprintf("Failed to update user with hashed token: %v", err))
|
||||
}
|
||||
|
||||
confirmResult, err := authService.ConfirmEmail(verificationToken)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to confirm email: %v", err))
|
||||
}
|
||||
|
||||
return confirmResult
|
||||
}
|
||||
|
||||
func setupVerificationTokenForTest(t *testing.T, emailSender *testutils.MockEmailSender, userRepo repositories.UserRepository, username string) string {
|
||||
t.Helper()
|
||||
|
||||
verificationToken := emailSender.VerificationToken()
|
||||
if verificationToken == "" {
|
||||
t.Fatal("Expected verification token to be generated")
|
||||
}
|
||||
|
||||
hashedToken := testutils.HashVerificationToken(verificationToken)
|
||||
|
||||
user, err := userRepo.GetByUsername(username)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user: %v", err)
|
||||
}
|
||||
user.EmailVerificationToken = hashedToken
|
||||
if err := userRepo.Update(user); err != nil {
|
||||
t.Fatalf("Failed to update user with hashed token: %v", err)
|
||||
}
|
||||
|
||||
return verificationToken
|
||||
}
|
||||
|
||||
func setupDeletionTokenForTest(t *testing.T, emailSender *testutils.MockEmailSender, deletionRepo repositories.AccountDeletionRepository, userID uint) string {
|
||||
t.Helper()
|
||||
|
||||
deletionToken := emailSender.DeletionToken()
|
||||
if deletionToken == "" {
|
||||
t.Fatal("Expected deletion token to be generated")
|
||||
}
|
||||
|
||||
hashedToken := testutils.HashVerificationToken(deletionToken)
|
||||
|
||||
if err := deletionRepo.DeleteByUserID(userID); err != nil {
|
||||
t.Fatalf("Cannot delete user %d", userID)
|
||||
}
|
||||
|
||||
req := &database.AccountDeletionRequest{
|
||||
UserID: userID,
|
||||
TokenHash: hashedToken,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
if err := deletionRepo.Create(req); err != nil {
|
||||
t.Fatalf("Failed to create account deletion request: %v", err)
|
||||
}
|
||||
|
||||
return deletionToken
|
||||
}
|
||||
|
||||
func setupPasswordResetTokenForTest(t *testing.T, emailSender *testutils.MockEmailSender, userRepo repositories.UserRepository, email string) string {
|
||||
t.Helper()
|
||||
|
||||
resetToken := emailSender.PasswordResetToken()
|
||||
if resetToken == "" {
|
||||
t.Fatal("Expected password reset token to be generated")
|
||||
}
|
||||
|
||||
hashedToken := testutils.HashVerificationToken(resetToken)
|
||||
|
||||
user, err := userRepo.GetByEmail(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user: %v", err)
|
||||
}
|
||||
user.PasswordResetToken = hashedToken
|
||||
if err := userRepo.Update(user); err != nil {
|
||||
t.Fatalf("Failed to update user with hashed reset token: %v", err)
|
||||
}
|
||||
|
||||
return resetToken
|
||||
}
|
||||
35
internal/services/auth_types.go
Normal file
35
internal/services/auth_types.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrInvalidToken = errors.New("invalid or expired token")
|
||||
ErrUsernameTaken = errors.New("username already exists")
|
||||
ErrEmailTaken = errors.New("email already exists")
|
||||
ErrInvalidEmail = errors.New("invalid email address")
|
||||
ErrPasswordTooShort = errors.New("password too short")
|
||||
ErrEmailNotVerified = errors.New("email not verified")
|
||||
ErrAccountLocked = errors.New("account is locked")
|
||||
ErrInvalidVerificationToken = errors.New("invalid verification token")
|
||||
ErrEmailSenderUnavailable = errors.New("email sender not configured")
|
||||
ErrDeletionEmailFailed = errors.New("account deletion email failed")
|
||||
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrDeletionRequestNotFound = errors.New("deletion request not found")
|
||||
)
|
||||
|
||||
type AuthResult struct {
|
||||
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
|
||||
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
|
||||
User *database.User `json:"user"`
|
||||
}
|
||||
|
||||
type RegistrationResult struct {
|
||||
User *database.User `json:"user"`
|
||||
VerificationSent bool `json:"verification_sent"`
|
||||
}
|
||||
59
internal/services/auth_utils.go
Normal file
59
internal/services/auth_utils.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTokenExpirationHours = 24
|
||||
verificationTokenBytes = 32
|
||||
deletionTokenExpirationHours = 24
|
||||
)
|
||||
|
||||
func normalizeEmail(email string) (string, error) {
|
||||
trimmed := strings.TrimSpace(email)
|
||||
if trimmed == "" {
|
||||
return "", fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
parsed, err := mail.ParseAddress(trimmed)
|
||||
if err != nil {
|
||||
return "", ErrInvalidEmail
|
||||
}
|
||||
|
||||
return strings.ToLower(parsed.Address), nil
|
||||
}
|
||||
|
||||
func generateVerificationToken() (string, string, error) {
|
||||
buf := make([]byte, verificationTokenBytes)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", "", fmt.Errorf("generate verification token: %w", err)
|
||||
}
|
||||
|
||||
token := hex.EncodeToString(buf)
|
||||
hashed := HashVerificationToken(token)
|
||||
return token, hashed, nil
|
||||
}
|
||||
|
||||
func HashVerificationToken(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func sanitizeUser(user *database.User) *database.User {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy := *user
|
||||
copy.Password = ""
|
||||
copy.EmailVerificationToken = ""
|
||||
return ©
|
||||
}
|
||||
136
internal/services/common.go
Normal file
136
internal/services/common.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
func TrimString(s string) string {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultBcryptCost = 10
|
||||
)
|
||||
|
||||
func HashPassword(password string, cost int) (string, error) {
|
||||
if cost <= 0 {
|
||||
cost = DefaultBcryptCost
|
||||
}
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), cost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
return string(hashedPassword), nil
|
||||
}
|
||||
|
||||
func IsRecordNotFound(err error) bool {
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
func HandleUniqueConstraintError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var pqErr *pq.Error
|
||||
if !errors.As(err, &pqErr) || pqErr.Code != "23505" {
|
||||
return err
|
||||
}
|
||||
|
||||
constraintLower := strings.ToLower(pqErr.Constraint)
|
||||
errMsgLower := strings.ToLower(pqErr.Message)
|
||||
|
||||
if strings.Contains(constraintLower, "username") ||
|
||||
strings.Contains(errMsgLower, "username") ||
|
||||
strings.Contains(errMsgLower, "users_username_key") ||
|
||||
strings.Contains(errMsgLower, "users.username") {
|
||||
return ErrUsernameTaken
|
||||
}
|
||||
|
||||
if strings.Contains(constraintLower, "email") ||
|
||||
strings.Contains(errMsgLower, "email") ||
|
||||
strings.Contains(errMsgLower, "users_email_key") ||
|
||||
strings.Contains(errMsgLower, "users.email") {
|
||||
return ErrEmailTaken
|
||||
}
|
||||
|
||||
return ErrUsernameTaken
|
||||
}
|
||||
|
||||
func HandleUniqueConstraintErrorWithMessage(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if handled := HandleUniqueConstraintError(err); handled != err {
|
||||
return handled
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
errMsgLower := strings.ToLower(errMsg)
|
||||
|
||||
isUniqueError := strings.Contains(errMsgLower, "duplicate key") ||
|
||||
strings.Contains(errMsgLower, "unique constraint") ||
|
||||
strings.Contains(errMsgLower, "violates unique constraint") ||
|
||||
strings.Contains(errMsgLower, "unique constraint failed") ||
|
||||
strings.Contains(errMsgLower, "constraint failed") ||
|
||||
(strings.Contains(errMsgLower, "constraint") && strings.Contains(errMsgLower, "unique"))
|
||||
|
||||
if !isUniqueError {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.Contains(errMsgLower, "username") ||
|
||||
strings.Contains(errMsgLower, "users_username_key") ||
|
||||
strings.Contains(errMsgLower, "users.username") ||
|
||||
strings.Contains(errMsg, "username") ||
|
||||
strings.Contains(errMsg, "users_username_key") ||
|
||||
strings.Contains(errMsg, "users.username") {
|
||||
return ErrUsernameTaken
|
||||
}
|
||||
|
||||
if strings.Contains(errMsgLower, "email") ||
|
||||
strings.Contains(errMsgLower, "users_email_key") ||
|
||||
strings.Contains(errMsgLower, "users.email") ||
|
||||
strings.Contains(errMsg, "email") ||
|
||||
strings.Contains(errMsg, "users_email_key") ||
|
||||
strings.Contains(errMsg, "users.email") {
|
||||
return ErrEmailTaken
|
||||
}
|
||||
|
||||
return ErrUsernameTaken
|
||||
}
|
||||
|
||||
func NewAuthFacadeForTest(cfg *config.Config, userRepo repositories.UserRepository, postRepo repositories.PostRepository, deletionRepo repositories.AccountDeletionRepository, refreshRepo repositories.RefreshTokenRepositoryInterface, emailSender EmailSender) (*AuthFacade, error) {
|
||||
emailService, err := NewEmailService(cfg, emailSender)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create email service: %w", err)
|
||||
}
|
||||
|
||||
jwtService := NewJWTService(&cfg.JWT, userRepo, refreshRepo)
|
||||
|
||||
registrationService := NewRegistrationService(userRepo, emailService, cfg)
|
||||
passwordResetService := NewPasswordResetService(userRepo, emailService)
|
||||
deletionService := NewAccountDeletionService(userRepo, postRepo, deletionRepo, emailService)
|
||||
sessionService := NewSessionService(jwtService, userRepo)
|
||||
userManagementService := NewUserManagementService(userRepo, postRepo, emailService)
|
||||
|
||||
authFacade := NewAuthFacade(
|
||||
registrationService,
|
||||
passwordResetService,
|
||||
deletionService,
|
||||
sessionService,
|
||||
userManagementService,
|
||||
cfg,
|
||||
)
|
||||
|
||||
return authFacade, nil
|
||||
}
|
||||
822
internal/services/common_test.go
Normal file
822
internal/services/common_test.go
Normal file
@@ -0,0 +1,822 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/testutils"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestTrimString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no_whitespace",
|
||||
input: "test",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "leading_spaces",
|
||||
input: " test",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "trailing_spaces",
|
||||
input: "test ",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "leading_and_trailing_spaces",
|
||||
input: " test ",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "only_spaces",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty_string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "tabs_and_newlines",
|
||||
input: "\t\n test \n\t",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "internal_spaces_preserved",
|
||||
input: " test string ",
|
||||
expected: "test string",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := TrimString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("TrimString(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
t.Run("successful_hash", func(t *testing.T) {
|
||||
password := "testpassword123"
|
||||
hashed, err := HashPassword(password, DefaultBcryptCost)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v, want no error", err)
|
||||
}
|
||||
if hashed == "" {
|
||||
t.Error("HashPassword() returned empty string")
|
||||
}
|
||||
if hashed == password {
|
||||
t.Error("HashPassword() returned plain password")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
|
||||
if err != nil {
|
||||
t.Errorf("HashPassword() produced invalid hash: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different_passwords_produce_different_hashes", func(t *testing.T) {
|
||||
password1 := "password1"
|
||||
password2 := "password2"
|
||||
|
||||
hash1, err := HashPassword(password1, DefaultBcryptCost)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
hash2, err := HashPassword(password2, DefaultBcryptCost)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error("Different passwords produced same hash")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("same_password_produces_different_hashes", func(t *testing.T) {
|
||||
password := "samepassword"
|
||||
hash1, err := HashPassword(password, DefaultBcryptCost)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
hash2, err := HashPassword(password, DefaultBcryptCost)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error("Same password produced same hash (should be different due to salt)")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash1), []byte(password)); err != nil {
|
||||
t.Errorf("First hash doesn't verify: %v", err)
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash2), []byte(password)); err != nil {
|
||||
t.Errorf("Second hash doesn't verify: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom_cost", func(t *testing.T) {
|
||||
password := "testpassword"
|
||||
customCost := 12
|
||||
|
||||
hashed, err := HashPassword(password, customCost)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
|
||||
if err != nil {
|
||||
t.Errorf("HashPassword() with custom cost produced invalid hash: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero_cost_uses_default", func(t *testing.T) {
|
||||
password := "testpassword"
|
||||
|
||||
hashed, err := HashPassword(password, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
|
||||
if err != nil {
|
||||
t.Errorf("HashPassword() with zero cost produced invalid hash: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("negative_cost_uses_default", func(t *testing.T) {
|
||||
password := "testpassword"
|
||||
|
||||
hashed, err := HashPassword(password, -1)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(password))
|
||||
if err != nil {
|
||||
t.Errorf("HashPassword() with negative cost produced invalid hash: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_password", func(t *testing.T) {
|
||||
hashed, err := HashPassword("", DefaultBcryptCost)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashed), []byte(""))
|
||||
if err != nil {
|
||||
t.Errorf("HashPassword() with empty password produced invalid hash: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRecordNotFound(t *testing.T) {
|
||||
t.Run("gorm_record_not_found", func(t *testing.T) {
|
||||
err := gorm.ErrRecordNotFound
|
||||
if !IsRecordNotFound(err) {
|
||||
t.Error("IsRecordNotFound() = false, want true for gorm.ErrRecordNotFound")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrapped_gorm_record_not_found", func(t *testing.T) {
|
||||
err := errors.New("some context")
|
||||
wrappedErr := errors.Join(err, gorm.ErrRecordNotFound)
|
||||
if !IsRecordNotFound(wrappedErr) {
|
||||
t.Error("IsRecordNotFound() = false, want true for wrapped gorm.ErrRecordNotFound")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("other_error", func(t *testing.T) {
|
||||
err := errors.New("some other error")
|
||||
if IsRecordNotFound(err) {
|
||||
t.Error("IsRecordNotFound() = true, want false for other error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil_error", func(t *testing.T) {
|
||||
if IsRecordNotFound(nil) {
|
||||
t.Error("IsRecordNotFound() = true, want false for nil error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleUniqueConstraintError(t *testing.T) {
|
||||
t.Run("nil_error", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintError(nil)
|
||||
if err != nil {
|
||||
t.Errorf("HandleUniqueConstraintError(nil) = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non_pq_error", func(t *testing.T) {
|
||||
originalErr := errors.New("some other error")
|
||||
err := HandleUniqueConstraintError(originalErr)
|
||||
if err != originalErr {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want %v", err, originalErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pq_error_wrong_code", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23503",
|
||||
Message: "some error",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if err != pqErr {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want %v", err, pqErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("username_constraint_in_constraint_field", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "users_username_key",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("username_constraint_in_message", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "some_constraint",
|
||||
Message: "duplicate key value violates unique constraint users.username",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("username_constraint_case_insensitive", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "USERS_USERNAME_KEY",
|
||||
Message: "DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("email_constraint_in_constraint_field", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "users_email_key",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("email_constraint_in_message", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "some_constraint",
|
||||
Message: "duplicate key value violates unique constraint users.email",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("email_constraint_case_insensitive", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "USERS_EMAIL_KEY",
|
||||
Message: "DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown_unique_constraint_defaults_to_username", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "some_other_constraint",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
}
|
||||
err := HandleUniqueConstraintError(pqErr)
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintError() = %v, want ErrUsernameTaken (default)", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleUniqueConstraintErrorWithMessage(t *testing.T) {
|
||||
t.Run("nil_error", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(nil)
|
||||
if err != nil {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage(nil) = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pq_error_username_handled_by_HandleUniqueConstraintError", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "users_username_key",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
}
|
||||
err := HandleUniqueConstraintErrorWithMessage(pqErr)
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pq_error_email_handled_by_HandleUniqueConstraintError", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "users_email_key",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
}
|
||||
err := HandleUniqueConstraintErrorWithMessage(pqErr)
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_duplicate_key_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint username"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_unique_constraint_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("unique constraint failed on username"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_violates_unique_constraint_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("violates unique constraint users_username_key"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_unique_constraint_failed_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("unique constraint failed on users.username"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_constraint_failed_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("constraint failed on username"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_constraint_and_unique_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("constraint unique username failed"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_duplicate_key_email", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint email"))
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_unique_constraint_email", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("unique constraint failed on users_email_key"))
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_violates_unique_constraint_email", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("violates unique constraint users.email"))
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_case_insensitive_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT USERNAME"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("message_based_case_insensitive_email", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("DUPLICATE KEY VALUE VIOLATES UNIQUE CONSTRAINT EMAIL"))
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non_unique_error_passed_through", func(t *testing.T) {
|
||||
originalErr := errors.New("some other database error")
|
||||
err := HandleUniqueConstraintErrorWithMessage(originalErr)
|
||||
if err != originalErr {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want %v", err, originalErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unique_error_without_username_or_email_defaults_to_username", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken (default)", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrapped_pq_error", func(t *testing.T) {
|
||||
pqErr := &pq.Error{
|
||||
Code: "23505",
|
||||
Constraint: "users_username_key",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
}
|
||||
wrappedErr := errors.New("context: " + pqErr.Error())
|
||||
err := HandleUniqueConstraintErrorWithMessage(wrappedErr)
|
||||
if err == nil {
|
||||
t.Error("HandleUniqueConstraintErrorWithMessage() returned nil, expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleUniqueConstraintErrorWithMessage_EdgeCases(t *testing.T) {
|
||||
t.Run("mixed_case_username_in_message", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint UserName"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mixed_case_email_in_message", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint Email"))
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("username_substring_not_matched", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint usernames"))
|
||||
if !errors.Is(err, ErrUsernameTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrUsernameTaken", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("email_substring_not_matched", func(t *testing.T) {
|
||||
err := HandleUniqueConstraintErrorWithMessage(errors.New("duplicate key value violates unique constraint emails"))
|
||||
if !errors.Is(err, ErrEmailTaken) {
|
||||
t.Errorf("HandleUniqueConstraintErrorWithMessage() = %v, want ErrEmailTaken", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkTrimString(b *testing.B) {
|
||||
input := " test string with spaces "
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = TrimString(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHashPassword(b *testing.B) {
|
||||
password := "testpassword123"
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = HashPassword(password, DefaultBcryptCost)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsRecordNotFound(b *testing.B) {
|
||||
err := gorm.ErrRecordNotFound
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = IsRecordNotFound(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthFacadeForTest(t *testing.T) {
|
||||
t.Run("successful_creation", func(t *testing.T) {
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authFacade, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
|
||||
}
|
||||
|
||||
if authFacade == nil {
|
||||
t.Fatal("NewAuthFacadeForTest() returned nil AuthFacade")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful_creation_with_custom_config", func(t *testing.T) {
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
customConfig := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "custom-secret-key-for-testing-purposes-only",
|
||||
Expiration: 48,
|
||||
},
|
||||
App: config.AppConfig{
|
||||
BaseURL: "http://localhost:3000",
|
||||
BcryptCost: 12,
|
||||
},
|
||||
}
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authFacade, err := NewAuthFacadeForTest(customConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
|
||||
}
|
||||
|
||||
if authFacade == nil {
|
||||
t.Fatal("NewAuthFacadeForTest() returned nil AuthFacade")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error_on_empty_base_url", func(t *testing.T) {
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
invalidConfig := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret-key",
|
||||
Expiration: 24,
|
||||
},
|
||||
App: config.AppConfig{
|
||||
BaseURL: "",
|
||||
BcryptCost: 10,
|
||||
},
|
||||
}
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authFacade, err := NewAuthFacadeForTest(invalidConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err == nil {
|
||||
t.Fatal("NewAuthFacadeForTest() expected error for empty base URL, got nil")
|
||||
}
|
||||
|
||||
if authFacade != nil {
|
||||
t.Fatal("NewAuthFacadeForTest() expected nil AuthFacade on error, got non-nil")
|
||||
}
|
||||
|
||||
if err.Error() != "create email service: APP_BASE_URL is required and must be externally reachable" {
|
||||
t.Errorf("NewAuthFacadeForTest() error = %v, want error about APP_BASE_URL", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error_on_whitespace_only_base_url", func(t *testing.T) {
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
invalidConfig := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret-key",
|
||||
Expiration: 24,
|
||||
},
|
||||
App: config.AppConfig{
|
||||
BaseURL: " ",
|
||||
BcryptCost: 10,
|
||||
},
|
||||
}
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authFacade, err := NewAuthFacadeForTest(invalidConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err == nil {
|
||||
t.Fatal("NewAuthFacadeForTest() expected error for whitespace-only base URL, got nil")
|
||||
}
|
||||
|
||||
if authFacade != nil {
|
||||
t.Fatal("NewAuthFacadeForTest() expected nil AuthFacade on error, got non-nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("facade_can_perform_operations", func(t *testing.T) {
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authFacade, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
|
||||
}
|
||||
|
||||
result, err := authFacade.Register("testuser", "test@example.com", "SecurePass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthFacade.Register() error = %v, want no error", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("AuthFacade.Register() returned nil result")
|
||||
}
|
||||
|
||||
if result.User.Username != "testuser" {
|
||||
t.Errorf("AuthFacade.Register() username = %v, want 'testuser'", result.User.Username)
|
||||
}
|
||||
|
||||
if result.User.Email != "test@example.com" {
|
||||
t.Errorf("AuthFacade.Register() email = %v, want 'test@example.com'", result.User.Email)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("facade_can_login", func(t *testing.T) {
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
postRepo := repositories.NewPostRepository(db)
|
||||
deletionRepo := repositories.NewAccountDeletionRepository(db)
|
||||
refreshTokenRepo := repositories.NewRefreshTokenRepository(db)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
authFacade, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo, postRepo, deletionRepo, refreshTokenRepo, emailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthFacadeForTest() error = %v, want no error", err)
|
||||
}
|
||||
|
||||
password := "SecurePass123!"
|
||||
_, err = authFacade.Register("loginuser", "login@example.com", password)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthFacade.Register() error = %v, want no error", err)
|
||||
}
|
||||
|
||||
user, err := userRepo.GetByUsername("loginuser")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUsername() error = %v", err)
|
||||
}
|
||||
user.EmailVerified = true
|
||||
if err := userRepo.Update(user); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
authResult, err := authFacade.Login("loginuser", password)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthFacade.Login() error = %v, want no error", err)
|
||||
}
|
||||
|
||||
if authResult == nil {
|
||||
t.Fatal("AuthFacade.Login() returned nil result")
|
||||
}
|
||||
|
||||
if authResult.AccessToken == "" {
|
||||
t.Error("AuthFacade.Login() returned empty access token")
|
||||
}
|
||||
|
||||
if authResult.RefreshToken == "" {
|
||||
t.Error("AuthFacade.Login() returned empty refresh token")
|
||||
}
|
||||
|
||||
if authResult.User == nil {
|
||||
t.Error("AuthFacade.Login() returned nil user")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple_facades_independent", func(t *testing.T) {
|
||||
db1 := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db1.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
db2 := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db2.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
userRepo1 := repositories.NewUserRepository(db1)
|
||||
postRepo1 := repositories.NewPostRepository(db1)
|
||||
deletionRepo1 := repositories.NewAccountDeletionRepository(db1)
|
||||
refreshTokenRepo1 := repositories.NewRefreshTokenRepository(db1)
|
||||
emailSender1 := &testutils.MockEmailSender{}
|
||||
|
||||
userRepo2 := repositories.NewUserRepository(db2)
|
||||
postRepo2 := repositories.NewPostRepository(db2)
|
||||
deletionRepo2 := repositories.NewAccountDeletionRepository(db2)
|
||||
refreshTokenRepo2 := repositories.NewRefreshTokenRepository(db2)
|
||||
emailSender2 := &testutils.MockEmailSender{}
|
||||
|
||||
authFacade1, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo1, postRepo1, deletionRepo1, refreshTokenRepo1, emailSender1)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthFacadeForTest() error = %v", err)
|
||||
}
|
||||
|
||||
authFacade2, err := NewAuthFacadeForTest(testutils.AppTestConfig, userRepo2, postRepo2, deletionRepo2, refreshTokenRepo2, emailSender2)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthFacadeForTest() error = %v", err)
|
||||
}
|
||||
|
||||
if authFacade1 == authFacade2 {
|
||||
t.Error("NewAuthFacadeForTest() returned same instance for different repositories")
|
||||
}
|
||||
|
||||
_, err = authFacade1.Register("user1", "user1@example.com", "Pass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthFacade1.Register() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = authFacade2.Register("user2", "user2@example.com", "Pass123!")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthFacade2.Register() error = %v", err)
|
||||
}
|
||||
|
||||
user1, err := userRepo1.GetByUsername("user1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUsername() error = %v", err)
|
||||
}
|
||||
user1.EmailVerified = true
|
||||
if err := userRepo1.Update(user1); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
user2, err := userRepo2.GetByUsername("user2")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUsername() error = %v", err)
|
||||
}
|
||||
user2.EmailVerified = true
|
||||
if err := userRepo2.Update(user2); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = authFacade1.Login("user1", "Pass123!")
|
||||
if err != nil {
|
||||
t.Errorf("AuthFacade1.Login() error = %v, should find user1", err)
|
||||
}
|
||||
|
||||
_, err = authFacade2.Login("user2", "Pass123!")
|
||||
if err != nil {
|
||||
t.Errorf("AuthFacade2.Login() error = %v, should find user2", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
99
internal/services/email_sender.go
Normal file
99
internal/services/email_sender.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EmailSender interface {
|
||||
Send(to, subject, body string) error
|
||||
}
|
||||
|
||||
type SMTPSender struct {
|
||||
host string
|
||||
port int
|
||||
username string
|
||||
password string
|
||||
from string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewSMTPSender(cfgHost string, cfgPort int, cfgUsername, cfgPassword, cfgFrom string) *SMTPSender {
|
||||
return &SMTPSender{
|
||||
host: cfgHost,
|
||||
port: cfgPort,
|
||||
username: cfgUsername,
|
||||
password: cfgPassword,
|
||||
from: cfgFrom,
|
||||
timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSMTPSenderWithTimeout(cfgHost string, cfgPort int, cfgUsername, cfgPassword, cfgFrom string, timeout time.Duration) *SMTPSender {
|
||||
return &SMTPSender{
|
||||
host: cfgHost,
|
||||
port: cfgPort,
|
||||
username: cfgUsername,
|
||||
password: cfgPassword,
|
||||
from: cfgFrom,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMTPSender) Send(to, subject, body string) error {
|
||||
if s == nil {
|
||||
return fmt.Errorf("smtp sender is not configured")
|
||||
}
|
||||
|
||||
to = strings.TrimSpace(to)
|
||||
if to == "" {
|
||||
return fmt.Errorf("recipient address is required")
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", s.host, s.port)
|
||||
|
||||
headers := []string{
|
||||
fmt.Sprintf("From: %s", s.from),
|
||||
fmt.Sprintf("To: %s", to),
|
||||
fmt.Sprintf("Subject: %s", subject),
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: text/html; charset=\"UTF-8\"",
|
||||
}
|
||||
|
||||
message := strings.Join(headers, "\r\n") + "\r\n\r\n" + body + "\r\n"
|
||||
|
||||
var auth smtp.Auth
|
||||
if strings.TrimSpace(s.username) != "" {
|
||||
auth = smtp.PlainAuth("", s.username, s.password, s.host)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- smtp.SendMail(address, auth, s.from, []string{to}, []byte(message))
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-time.After(s.timeout):
|
||||
return fmt.Errorf("email sending timeout after %v", s.timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMTPSender) SendAsync(to, subject, body string) <-chan error {
|
||||
result := make(chan error, 1)
|
||||
go func() {
|
||||
result <- s.Send(to, subject, body)
|
||||
}()
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *SMTPSender) SetTimeout(timeout time.Duration) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.timeout = timeout
|
||||
}
|
||||
1359
internal/services/email_sender_test.go
Normal file
1359
internal/services/email_sender_test.go
Normal file
File diff suppressed because it is too large
Load Diff
574
internal/services/email_service.go
Normal file
574
internal/services/email_service.go
Normal file
@@ -0,0 +1,574 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
type EmailService struct {
|
||||
EmailSender EmailSender
|
||||
baseURL string
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewEmailService(cfg *config.Config, sender EmailSender) (*EmailService, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(cfg.App.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("APP_BASE_URL is required and must be externally reachable")
|
||||
}
|
||||
|
||||
return &EmailService{
|
||||
EmailSender: sender,
|
||||
baseURL: baseURL,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendVerificationEmail(user *database.User, token string) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
verificationURL := fmt.Sprintf("%s/confirm?token=%s", s.baseURL, url.QueryEscape(token))
|
||||
subject := fmt.Sprintf("🎉 Welcome to %s! Confirm your email address", s.config.App.Title)
|
||||
body := s.GenerateVerificationEmailBody(user.Username, verificationURL)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send verification email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendEmailChangeVerificationEmail(user *database.User, token string) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
verificationURL := fmt.Sprintf("%s/confirm?token=%s", s.baseURL, url.QueryEscape(token))
|
||||
subject := "📧 Confirm your new email address"
|
||||
body := s.GenerateEmailChangeVerificationEmailBody(user.Username, verificationURL)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send email change verification email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendResendVerificationEmail(user *database.User, token string) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
verificationURL := fmt.Sprintf("%s/confirm?token=%s", s.baseURL, url.QueryEscape(token))
|
||||
subject := fmt.Sprintf("🔄 Resend: Confirm your %s account", s.config.App.Title)
|
||||
body := s.GenerateResendVerificationEmailBody(user.Username, verificationURL)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send verification email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendPasswordResetEmail(user *database.User, token string) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
resetURL := fmt.Sprintf("%s/reset-password?token=%s", s.baseURL, url.QueryEscape(token))
|
||||
subject := fmt.Sprintf("Reset your %s password", s.config.App.Title)
|
||||
body := s.GeneratePasswordResetEmailBody(user.Username, resetURL)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send password reset email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendAccountDeletionEmail(user *database.User, token string) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
confirmationURL := fmt.Sprintf("%s/settings/delete/confirm?token=%s", s.baseURL, url.QueryEscape(token))
|
||||
subject := "Confirm Account Deletion"
|
||||
body := s.GenerateAccountDeletionEmailBody(user.Username, confirmationURL)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send account deletion email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendAccountDeletionNotificationEmail(user *database.User, deletedPosts bool) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
subject, body := GenerateAccountDeletionNotificationEmail(user.Username, s.config.App.AdminEmail, s.baseURL, s.config.App.Title, deletedPosts)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send account deletion notification email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendAccountLockNotificationEmail(user *database.User) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
subject, body := GenerateAccountLockNotificationEmail(user.Username, s.config.App.AdminEmail, s.config.App.Title)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send account lock notification email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) SendAccountUnlockNotificationEmail(user *database.User) error {
|
||||
if s.EmailSender == nil {
|
||||
return ErrEmailSenderUnavailable
|
||||
}
|
||||
|
||||
subject, body := GenerateAccountUnlockNotificationEmail(user.Username, s.config.App.AdminEmail, s.baseURL, s.config.App.Title)
|
||||
|
||||
if err := s.EmailSender.Send(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("send account unlock notification email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailService) GenerateVerificationEmailBody(username, verificationURL string) string {
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeURL := html.EscapeString(verificationURL)
|
||||
siteTitle := s.config.App.Title
|
||||
|
||||
return s.generateStyledEmailBody(siteTitle, "🎉 Welcome! Confirm your email address",
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
fmt.Sprintf("Welcome to %s! Please confirm your email address by clicking the link below:", siteTitle),
|
||||
safeURL, "Confirm Email Address",
|
||||
"If the link doesn't work, you can copy and paste it into your browser.\n\nOnce confirmed, you'll be able to:\n- Create and share posts with the community\n- Vote on content you find interesting\n- Connect with other members\n- Customize your profile and preferences\n\nIf you didn't create this account, you can safely ignore this email.")
|
||||
}
|
||||
|
||||
func (s *EmailService) GenerateEmailChangeVerificationEmailBody(username, verificationURL string) string {
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeURL := html.EscapeString(verificationURL)
|
||||
siteTitle := s.config.App.Title
|
||||
|
||||
return s.generateStyledEmailBody(siteTitle, "📧 Confirm your new email address",
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
"You've requested to change your email address. To complete this change, please confirm your new email address by clicking the link below:\n\nIf the link doesn't work, you can copy and paste it into your browser.\n\nOnce confirmed, your new email address will be active and you'll need to use it for future logins.\n\nIf you didn't request this email change, please contact our support team immediately.",
|
||||
safeURL, "Confirm New Email Address",
|
||||
"")
|
||||
}
|
||||
|
||||
func (s *EmailService) GenerateResendVerificationEmailBody(username, verificationURL string) string {
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeURL := html.EscapeString(verificationURL)
|
||||
siteTitle := s.config.App.Title
|
||||
|
||||
return s.generateStyledEmailBody(siteTitle, fmt.Sprintf("🔄 Resend: Confirm your %s account", siteTitle),
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
"We've sent you a new verification link.\n\nPlease confirm your email address by clicking the link below:\n\nIf you're having trouble with the verification link:\n- Check your spam/junk folder\n- Make sure you're clicking the most recent email\n- Try copying the link and pasting it in your browser\n- Contact support if the problem persists\n\nIf you didn't request this email, you can safely ignore this message.",
|
||||
safeURL, "Confirm Email Address",
|
||||
"")
|
||||
}
|
||||
|
||||
func (s *EmailService) GeneratePasswordResetEmailBody(username, resetURL string) string {
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeURL := html.EscapeString(resetURL)
|
||||
siteTitle := s.config.App.Title
|
||||
|
||||
return s.generateStyledEmailBody(siteTitle, fmt.Sprintf("Reset your %s password", siteTitle),
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
"We received a request to reset your password. To reset it, click the link below:",
|
||||
safeURL, "Reset Password",
|
||||
"This link will expire in 24 hours. If you didn't make this request, you can safely ignore this message.")
|
||||
}
|
||||
|
||||
func (s *EmailService) GenerateAccountDeletionEmailBody(username, confirmationURL string) string {
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeURL := html.EscapeString(confirmationURL)
|
||||
siteTitle := s.config.App.Title
|
||||
|
||||
return s.generateStyledEmailBody(siteTitle, "Confirm Account Deletion",
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
fmt.Sprintf("We received a request to delete your %s account.\n\nTo confirm, click the link below:", siteTitle),
|
||||
safeURL, "Confirm Account Deletion",
|
||||
"This link will expire in 24 hours.\n\nIf you didn't make this request, you can safely ignore this message.")
|
||||
}
|
||||
|
||||
func (s *EmailService) generateStyledEmailBody(siteTitle, emailTitle, greeting, message, actionURL, actionText, footer string) string {
|
||||
safeAdminEmail := html.EscapeString(s.config.App.AdminEmail)
|
||||
|
||||
var buttonHTML string
|
||||
if actionURL != "" && actionText != "" {
|
||||
buttonHTML = fmt.Sprintf(`<div style="text-align: center;">
|
||||
<a href="%s" class="action-button">%s</a>
|
||||
</div>`, actionURL, actionText)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>%s</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: #333;
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
background-color: #f8fafc;
|
||||
}
|
||||
.email-container {
|
||||
background: white;
|
||||
border-radius: 12px;
|
||||
padding: 40px;
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
|
||||
border: 1px solid #e2e8f0;
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.logo {
|
||||
font-size: 28px;
|
||||
font-weight: 700;
|
||||
color: #0fb9b1;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.title {
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
color: #1a202c;
|
||||
margin: 0;
|
||||
}
|
||||
.content {
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.greeting {
|
||||
font-size: 16px;
|
||||
margin-bottom: 20px;
|
||||
color: #2d3748;
|
||||
}
|
||||
.message {
|
||||
font-size: 16px;
|
||||
margin-bottom: 30px;
|
||||
color: #4a5568;
|
||||
white-space: pre-line;
|
||||
}
|
||||
.action-button {
|
||||
display: inline-block;
|
||||
background: #0fb9b1;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
padding: 12px 24px;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
text-align: center;
|
||||
margin: 20px 0;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
.action-button:hover {
|
||||
background: #0ea5a0;
|
||||
}
|
||||
.footer {
|
||||
font-size: 14px;
|
||||
color: #718096;
|
||||
margin-top: 30px;
|
||||
padding-top: 20px;
|
||||
border-top: 1px solid #e2e8f0;
|
||||
white-space: pre-line;
|
||||
}
|
||||
.link {
|
||||
color: #0fb9b1;
|
||||
text-decoration: none;
|
||||
}
|
||||
.link:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
@media (max-width: 600px) {
|
||||
body {
|
||||
padding: 10px;
|
||||
}
|
||||
.email-container {
|
||||
padding: 20px;
|
||||
}
|
||||
.title {
|
||||
font-size: 20px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="email-container">
|
||||
<div class="header">
|
||||
<div class="logo">%s</div>
|
||||
<h1 class="title">%s</h1>
|
||||
</div>
|
||||
|
||||
<div class="content">
|
||||
<div class="greeting">%s</div>
|
||||
<div class="message">%s</div>
|
||||
%s
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
%s
|
||||
|
||||
If you have any questions or concerns, please <a href="mailto:%s" class="link">contact our support team</a>.<br>
|
||||
Best regards,<br>
|
||||
The %s Team
|
||||
</div>
|
||||
|
||||
<div class="powered-by" style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #e2e8f0; font-size: 12px; color: #718096;">
|
||||
Powered with ❤️ by <a href="https://goyco" style="color: #0fb9b1; text-decoration: none;">Goyco</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, emailTitle, siteTitle, emailTitle, greeting, message, buttonHTML, footer, safeAdminEmail, siteTitle)
|
||||
}
|
||||
|
||||
func GenerateAccountDeletionNotificationEmail(username, adminEmail, baseURL, siteTitle string, deletedPosts bool) (string, string) {
|
||||
subject := "Account Deleted"
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeAdminEmail := html.EscapeString(adminEmail)
|
||||
var body string
|
||||
if deletedPosts {
|
||||
body = fmt.Sprintf(`Your %s account has been permanently deleted.
|
||||
|
||||
All your posts have also been removed.
|
||||
|
||||
If you did not request this change, please <a href="mailto:%s" class="link">contact us</a>.
|
||||
|
||||
Thank you for being part of the community.`, siteTitle, safeAdminEmail)
|
||||
} else {
|
||||
body = fmt.Sprintf(`Your %s account has been permanently deleted.
|
||||
|
||||
Your posts have been preserved and are now anonymous.
|
||||
|
||||
If you did not request this change, please <a href="mailto:%s" class="link">contact us</a>.
|
||||
|
||||
Thank you for being part of the community.`, siteTitle, safeAdminEmail)
|
||||
}
|
||||
|
||||
mainPageURL := strings.TrimRight(baseURL, "/")
|
||||
safeMainPageURL := html.EscapeString(mainPageURL)
|
||||
|
||||
styledBody := generateStyledEmailBodyStatic(siteTitle, "Account Deleted",
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
body,
|
||||
safeMainPageURL, fmt.Sprintf("Visit %s", siteTitle),
|
||||
"", adminEmail)
|
||||
|
||||
return subject, styledBody
|
||||
}
|
||||
|
||||
func GenerateAdminAccountDeletionNotificationEmail(username, adminEmail, baseURL, siteTitle string, deletedPosts bool) (string, string) {
|
||||
subject := "Account Deleted by Administrator"
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeAdminEmail := html.EscapeString(adminEmail)
|
||||
|
||||
var message string
|
||||
if deletedPosts {
|
||||
message = fmt.Sprintf("Your %s account has been permanently deleted by an administrator.\n\nAll your posts have also been removed.\n\nIf you did not request this change, please <a href=\"mailto:%s\" class=\"link\">contact us</a>.\n\nThank you for being part of the community.", siteTitle, safeAdminEmail)
|
||||
} else {
|
||||
message = fmt.Sprintf("Your %s account has been permanently deleted by an administrator.\n\nYour posts have been preserved and are now anonymous.\n\nIf you did not request this change, please <a href=\"mailto:%s\" class=\"link\">contact us</a>.\n\nThank you for being part of the community.", siteTitle, safeAdminEmail)
|
||||
}
|
||||
|
||||
mainPageURL := strings.TrimRight(baseURL, "/")
|
||||
safeMainPageURL := html.EscapeString(mainPageURL)
|
||||
|
||||
body := generateStyledEmailBodyStatic(siteTitle, "Account Deleted by Administrator",
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
message,
|
||||
safeMainPageURL, fmt.Sprintf("Visit %s", siteTitle),
|
||||
"", adminEmail)
|
||||
|
||||
return subject, body
|
||||
}
|
||||
|
||||
func GenerateAccountLockNotificationEmail(username, adminEmail, siteTitle string) (string, string) {
|
||||
subject := "Account Locked"
|
||||
safeUsername := html.EscapeString(username)
|
||||
safeAdminEmail := html.EscapeString(adminEmail)
|
||||
message := fmt.Sprintf("Your %s account has been locked by an administrator.\n\nYou will not be able to log in or access your account until it is unlocked.\n\nIf you believe this is an error or need to discuss your account status, please <a href=\"mailto:%s\" class=\"link\">contact us</a>.\n\nThis action was taken to protect the security and integrity of our platform.", siteTitle, safeAdminEmail)
|
||||
|
||||
body := generateStyledEmailBodyStatic(siteTitle, "Account Locked",
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
message,
|
||||
"", "",
|
||||
"", adminEmail)
|
||||
|
||||
return subject, body
|
||||
}
|
||||
|
||||
func GenerateAccountUnlockNotificationEmail(username, adminEmail, baseURL, siteTitle string) (string, string) {
|
||||
subject := "Account Unlocked"
|
||||
safeUsername := html.EscapeString(username)
|
||||
message := fmt.Sprintf("Your %s account has been unlocked by an administrator.\n\nYou can now log in and access all your account features normally.\n\nAll your previous data and settings have been preserved.\n\nWelcome back!", siteTitle)
|
||||
|
||||
loginURL := fmt.Sprintf("%s/login", strings.TrimRight(baseURL, "/"))
|
||||
safeLoginURL := html.EscapeString(loginURL)
|
||||
body := generateStyledEmailBodyStatic(siteTitle, "Account Unlocked",
|
||||
fmt.Sprintf("Hello %s,", safeUsername),
|
||||
message,
|
||||
safeLoginURL, fmt.Sprintf("Login to %s", siteTitle),
|
||||
"", adminEmail)
|
||||
|
||||
return subject, body
|
||||
}
|
||||
|
||||
func generateStyledEmailBodyStatic(siteTitle, emailTitle, greeting, message, actionURL, actionText, footer, adminEmail string) string {
|
||||
safeAdminEmail := html.EscapeString(adminEmail)
|
||||
|
||||
var buttonHTML string
|
||||
if actionURL != "" && actionText != "" {
|
||||
buttonHTML = fmt.Sprintf(`<div style="text-align: center;">
|
||||
<a href="%s" class="action-button">%s</a>
|
||||
</div>`, actionURL, actionText)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>%s</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: #333;
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
background-color: #f8fafc;
|
||||
}
|
||||
.email-container {
|
||||
background: white;
|
||||
border-radius: 12px;
|
||||
padding: 40px;
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
|
||||
border: 1px solid #e2e8f0;
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.logo {
|
||||
font-size: 28px;
|
||||
font-weight: 700;
|
||||
color: #0fb9b1;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.title {
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
color: #1a202c;
|
||||
margin: 0;
|
||||
}
|
||||
.content {
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.greeting {
|
||||
font-size: 16px;
|
||||
margin-bottom: 20px;
|
||||
color: #2d3748;
|
||||
}
|
||||
.message {
|
||||
font-size: 16px;
|
||||
margin-bottom: 30px;
|
||||
color: #4a5568;
|
||||
white-space: pre-line;
|
||||
}
|
||||
.action-button {
|
||||
display: inline-block;
|
||||
background: #0fb9b1;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
padding: 12px 24px;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
text-align: center;
|
||||
margin: 20px 0;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
.action-button:hover {
|
||||
background: #0ea5a0;
|
||||
}
|
||||
.footer {
|
||||
font-size: 14px;
|
||||
color: #718096;
|
||||
margin-top: 30px;
|
||||
padding-top: 20px;
|
||||
border-top: 1px solid #e2e8f0;
|
||||
white-space: pre-line;
|
||||
}
|
||||
.link {
|
||||
color: #0fb9b1;
|
||||
text-decoration: none;
|
||||
}
|
||||
.link:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
@media (max-width: 600px) {
|
||||
body {
|
||||
padding: 10px;
|
||||
}
|
||||
.email-container {
|
||||
padding: 20px;
|
||||
}
|
||||
.title {
|
||||
font-size: 20px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="email-container">
|
||||
<div class="header">
|
||||
<div class="logo">%s</div>
|
||||
<h1 class="title">%s</h1>
|
||||
</div>
|
||||
|
||||
<div class="content">
|
||||
<div class="greeting">%s</div>
|
||||
<div class="message">%s</div>
|
||||
%s
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
%s
|
||||
|
||||
If you have any questions or concerns, please <a href="mailto:%s" class="link">contact our support team</a>.<br>
|
||||
Best regards,<br>
|
||||
The %s Team
|
||||
</div>
|
||||
|
||||
<div class="powered-by" style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #e2e8f0; font-size: 12px; color: #718096;">
|
||||
Powered with ❤️ by <a href="https://goyco" style="color: #0fb9b1; text-decoration: none;">Goyco</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, emailTitle, siteTitle, emailTitle, greeting, message, buttonHTML, footer, safeAdminEmail, siteTitle)
|
||||
}
|
||||
275
internal/services/email_service_test.go
Normal file
275
internal/services/email_service_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func TestEmailService_IntegrationWithRealSMTP(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sender := testutils.GetSMTPSenderFromEnv(t)
|
||||
|
||||
recipient := os.Getenv("SMTP_TEST_RECIPIENT")
|
||||
if strings.TrimSpace(recipient) == "" {
|
||||
recipient = sender.From
|
||||
}
|
||||
|
||||
config := testutils.NewEmailTestConfig("https://example.com")
|
||||
service, err := NewEmailService(config, sender)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping SMTP integration test: failed to create email service: %v", err)
|
||||
}
|
||||
|
||||
user := testutils.NewEmailTestUser("integrationuser", recipient)
|
||||
token := "integration-test-token"
|
||||
|
||||
t.Run("VerificationEmail_RealSMTP", func(t *testing.T) {
|
||||
err := service.SendVerificationEmail(user, token)
|
||||
if err != nil {
|
||||
t.Errorf("SendVerificationEmail failed: %v", err)
|
||||
}
|
||||
|
||||
body := service.GenerateVerificationEmailBody(user.Username, fmt.Sprintf("%s/confirm?token=%s", config.App.BaseURL, token))
|
||||
if !strings.Contains(body, "https://example.com/confirm?token=integration-test-token") {
|
||||
t.Errorf("Expected body to contain verification URL")
|
||||
}
|
||||
if !strings.Contains(body, "<!DOCTYPE html>") {
|
||||
t.Errorf("Expected HTML content in email body")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PasswordResetEmail_RealSMTP", func(t *testing.T) {
|
||||
err := service.SendPasswordResetEmail(user, token)
|
||||
if err != nil {
|
||||
t.Errorf("SendPasswordResetEmail failed: %v", err)
|
||||
}
|
||||
|
||||
body := service.GeneratePasswordResetEmailBody(user.Username, fmt.Sprintf("%s/reset-password?token=%s", config.App.BaseURL, token))
|
||||
if !strings.Contains(body, "https://example.com/reset-password?token=integration-test-token") {
|
||||
t.Errorf("Expected body to contain reset URL")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AccountDeletionEmail_RealSMTP", func(t *testing.T) {
|
||||
err := service.SendAccountDeletionEmail(user, token)
|
||||
if err != nil {
|
||||
t.Errorf("SendAccountDeletionEmail failed: %v", err)
|
||||
}
|
||||
|
||||
body := service.GenerateAccountDeletionEmailBody(user.Username, fmt.Sprintf("%s/settings/delete/confirm?token=%s", config.App.BaseURL, token))
|
||||
if !strings.Contains(body, "https://example.com/settings/delete/confirm?token=integration-test-token") {
|
||||
t.Errorf("Expected body to contain deletion confirmation URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmailService_Performance(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping performance test in short mode")
|
||||
}
|
||||
|
||||
config := testutils.NewEmailTestConfig("https://example.com")
|
||||
service, err := NewEmailService(config, &testutils.MockEmailSender{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create email service: %v", err)
|
||||
}
|
||||
|
||||
user := testutils.NewEmailTestUser("perfuser", "perf@example.com")
|
||||
|
||||
service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test")
|
||||
|
||||
start := time.Now()
|
||||
iterations := 1000
|
||||
for i := 0; i < iterations; i++ {
|
||||
service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test")
|
||||
}
|
||||
duration := time.Since(start)
|
||||
|
||||
maxDuration := 500 * time.Millisecond
|
||||
if duration > maxDuration {
|
||||
t.Errorf("HTML generation took too long: %v (expected < %v for %d iterations, %.2fms per template)",
|
||||
duration, maxDuration, iterations, float64(duration.Nanoseconds())/float64(iterations)/1e6)
|
||||
}
|
||||
|
||||
t.Logf("Generated %d HTML emails in %v (%.2fms per template)",
|
||||
iterations, duration, float64(duration.Nanoseconds())/float64(iterations)/1e6)
|
||||
}
|
||||
|
||||
func TestEmailService_EdgeCases(t *testing.T) {
|
||||
config := testutils.NewEmailTestConfig("https://example.com")
|
||||
service, err := NewEmailService(config, &testutils.MockEmailSender{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create email service: %v", err)
|
||||
}
|
||||
|
||||
t.Run("EmptyUsername", func(t *testing.T) {
|
||||
body := service.GenerateVerificationEmailBody("", "https://example.com/confirm?token=test")
|
||||
if !strings.Contains(body, "Hello ,") {
|
||||
t.Error("Expected empty username to be handled gracefully")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("VeryLongUsername", func(t *testing.T) {
|
||||
longUsername := strings.Repeat("a", 1000)
|
||||
body := service.GenerateVerificationEmailBody(longUsername, "https://example.com/confirm?token=test")
|
||||
if !strings.Contains(body, longUsername) {
|
||||
t.Error("Expected long username to be included in email")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SpecialCharactersInUsername", func(t *testing.T) {
|
||||
specialUsername := "user@domain.com & <script>alert('xss')</script>"
|
||||
body := service.GenerateVerificationEmailBody(specialUsername, "https://example.com/confirm?token=test")
|
||||
escapedUsername := html.EscapeString(specialUsername)
|
||||
if !strings.Contains(body, escapedUsername) {
|
||||
t.Errorf("Expected escaped username %q to be included", escapedUsername)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyToken", func(t *testing.T) {
|
||||
body := service.GenerateVerificationEmailBody("testuser", "https://example.com/confirm?token=")
|
||||
if !strings.Contains(body, "https://example.com/confirm?token=") {
|
||||
t.Error("Expected empty token to be handled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("VeryLongToken", func(t *testing.T) {
|
||||
longToken := strings.Repeat("a", 1000)
|
||||
url := fmt.Sprintf("https://example.com/confirm?token=%s", longToken)
|
||||
body := service.GenerateVerificationEmailBody("testuser", url)
|
||||
if !strings.Contains(body, url) {
|
||||
t.Error("Expected long token to be included in email")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewEmailService(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *config.Config
|
||||
sender EmailSender
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid configuration",
|
||||
config: testutils.NewEmailTestConfig("https://example.com"),
|
||||
sender: &testutils.MockEmailSender{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty base URL",
|
||||
config: testutils.NewEmailTestConfig(""),
|
||||
sender: &testutils.MockEmailSender{},
|
||||
expectError: true,
|
||||
errorMsg: "APP_BASE_URL is required",
|
||||
},
|
||||
{
|
||||
name: "Whitespace base URL",
|
||||
config: testutils.NewEmailTestConfig(" "),
|
||||
sender: &testutils.MockEmailSender{},
|
||||
expectError: true,
|
||||
errorMsg: "APP_BASE_URL is required",
|
||||
},
|
||||
{
|
||||
name: "Base URL with trailing slash",
|
||||
config: testutils.NewEmailTestConfig("https://example.com/"),
|
||||
sender: &testutils.MockEmailSender{},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service, err := NewEmailService(tt.config, tt.sender)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("Expected error to contain '%s', got '%s'", tt.errorMsg, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if service == nil {
|
||||
t.Error("Expected service, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
expectedBaseURL := strings.TrimRight(strings.TrimSpace(tt.config.App.BaseURL), "/")
|
||||
if service.baseURL != expectedBaseURL {
|
||||
t.Errorf("Expected baseURL '%s', got '%s'", expectedBaseURL, service.baseURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailService_DynamicTitle(t *testing.T) {
|
||||
const (
|
||||
placeholderTitle = "My Custom Site"
|
||||
customTitle = "Custom Community"
|
||||
)
|
||||
|
||||
cfg := &config.Config{
|
||||
App: config.AppConfig{
|
||||
Title: customTitle,
|
||||
BaseURL: "https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
sender := &testutils.MockEmailSender{}
|
||||
service, err := NewEmailService(cfg, sender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create email service: %v", err)
|
||||
}
|
||||
|
||||
body := service.GenerateVerificationEmailBody("testuser", "https://example.com/confirm?token=abc123")
|
||||
|
||||
if !strings.Contains(body, customTitle) {
|
||||
t.Error("Expected email body to contain custom site title")
|
||||
}
|
||||
|
||||
if strings.Contains(body, placeholderTitle) {
|
||||
t.Errorf("Expected email body to not contain placeholder title %q", placeholderTitle)
|
||||
}
|
||||
|
||||
if strings.Contains(body, "The Goyco Team") {
|
||||
t.Error("Expected email body to not contain default team name when custom title is set")
|
||||
}
|
||||
|
||||
cfgDefault := &config.Config{
|
||||
App: config.AppConfig{
|
||||
Title: "Goyco",
|
||||
BaseURL: "https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
serviceDefault, err := NewEmailService(cfgDefault, sender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create email service: %v", err)
|
||||
}
|
||||
|
||||
bodyDefault := serviceDefault.GenerateVerificationEmailBody("testuser", "https://example.com/confirm?token=abc123")
|
||||
|
||||
if !strings.Contains(bodyDefault, "Goyco") {
|
||||
t.Error("Expected email body to contain default site title")
|
||||
}
|
||||
}
|
||||
360
internal/services/jwt_service.go
Normal file
360
internal/services/jwt_service.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenTypeAccess = "access"
|
||||
TokenTypeRefresh = "refresh"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidTokenType = errors.New("invalid token type")
|
||||
ErrTokenExpired = errors.New("token expired")
|
||||
ErrInvalidIssuer = errors.New("invalid issuer")
|
||||
ErrInvalidAudience = errors.New("invalid audience")
|
||||
ErrInvalidKeyID = errors.New("invalid key ID")
|
||||
ErrRefreshTokenExpired = errors.New("refresh token expired")
|
||||
ErrRefreshTokenInvalid = errors.New("refresh token invalid")
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
UserID uint `json:"sub"`
|
||||
Username string `json:"username"`
|
||||
SessionVersion uint `json:"session_version"`
|
||||
TokenType string `json:"type"`
|
||||
KeyID string `json:"kid,omitempty"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTService struct {
|
||||
config *config.JWTConfig
|
||||
userRepo UserRepository
|
||||
refreshRepo repositories.RefreshTokenRepositoryInterface
|
||||
}
|
||||
|
||||
type verificationKey struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
type UserRepository interface {
|
||||
GetByID(id uint) (*database.User, error)
|
||||
GetByUsername(username string) (*database.User, error)
|
||||
Update(user *database.User) error
|
||||
}
|
||||
|
||||
func NewJWTService(cfg *config.JWTConfig, userRepo UserRepository, refreshRepo repositories.RefreshTokenRepositoryInterface) *JWTService {
|
||||
return &JWTService{
|
||||
config: cfg,
|
||||
userRepo: userRepo,
|
||||
refreshRepo: refreshRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (j *JWTService) GenerateAccessToken(user *database.User) (string, error) {
|
||||
return j.generateToken(user, TokenTypeAccess, time.Duration(j.config.Expiration)*time.Hour)
|
||||
}
|
||||
|
||||
func (j *JWTService) GenerateRefreshToken(user *database.User) (string, error) {
|
||||
if user == nil {
|
||||
return "", ErrInvalidCredentials
|
||||
}
|
||||
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", fmt.Errorf("generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
tokenString := hex.EncodeToString(tokenBytes)
|
||||
tokenHash := j.hashToken(tokenString)
|
||||
|
||||
refreshToken := &database.RefreshToken{
|
||||
UserID: user.ID,
|
||||
TokenHash: tokenHash,
|
||||
ExpiresAt: time.Now().Add(time.Duration(j.config.RefreshExpiration) * time.Hour),
|
||||
}
|
||||
|
||||
if err := j.refreshRepo.Create(refreshToken); err != nil {
|
||||
return "", fmt.Errorf("store refresh token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
|
||||
claims, err := j.parseToken(tokenString)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if claims.TokenType != TokenTypeAccess {
|
||||
return 0, ErrInvalidTokenType
|
||||
}
|
||||
|
||||
user, err := j.userRepo.GetByID(claims.UserID)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
return 0, fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
if user.Locked {
|
||||
return 0, ErrAccountLocked
|
||||
}
|
||||
|
||||
if user.SessionVersion != claims.SessionVersion {
|
||||
return 0, ErrInvalidToken
|
||||
}
|
||||
|
||||
return claims.UserID, nil
|
||||
}
|
||||
|
||||
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) {
|
||||
|
||||
tokenHash := j.hashToken(refreshTokenString)
|
||||
|
||||
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrRefreshTokenInvalid
|
||||
}
|
||||
return "", fmt.Errorf("lookup refresh token: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(refreshToken.ExpiresAt) {
|
||||
|
||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
return "", ErrRefreshTokenExpired
|
||||
}
|
||||
|
||||
user, err := j.userRepo.GetByID(refreshToken.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
return "", ErrRefreshTokenInvalid
|
||||
}
|
||||
return "", fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
if user.Locked {
|
||||
|
||||
j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
return "", ErrAccountLocked
|
||||
}
|
||||
|
||||
accessToken, err := j.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate access token: %w", err)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
|
||||
tokenHash := j.hashToken(refreshTokenString)
|
||||
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("lookup refresh token: %w", err)
|
||||
}
|
||||
|
||||
return j.refreshRepo.DeleteByID(refreshToken.ID)
|
||||
}
|
||||
|
||||
func (j *JWTService) RevokeAllRefreshTokens(userID uint) error {
|
||||
return j.refreshRepo.DeleteByUserID(userID)
|
||||
}
|
||||
|
||||
func (j *JWTService) CleanupExpiredTokens() error {
|
||||
return j.refreshRepo.DeleteExpired()
|
||||
}
|
||||
|
||||
func (j *JWTService) generateToken(user *database.User, tokenType string, expiration time.Duration) (string, error) {
|
||||
if user == nil {
|
||||
return "", ErrInvalidCredentials
|
||||
}
|
||||
|
||||
jtiBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(jtiBytes); err != nil {
|
||||
return "", fmt.Errorf("generate token ID: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
claims := TokenClaims{
|
||||
UserID: user.ID,
|
||||
Username: user.Username,
|
||||
SessionVersion: user.SessionVersion,
|
||||
TokenType: tokenType,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: hex.EncodeToString(jtiBytes),
|
||||
Issuer: j.config.Issuer,
|
||||
Audience: []string{j.config.Audience},
|
||||
Subject: fmt.Sprint(user.ID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
||||
},
|
||||
}
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
claims.KeyID = j.config.KeyRotation.KeyID
|
||||
}
|
||||
|
||||
var signingMethod jwt.SigningMethod
|
||||
var key any
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
signingMethod = jwt.SigningMethodHS256
|
||||
key = []byte(j.config.KeyRotation.CurrentKey)
|
||||
} else {
|
||||
signingMethod = jwt.SigningMethodHS256
|
||||
key = []byte(j.config.Secret)
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(signingMethod, claims)
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
token.Header["kid"] = j.config.KeyRotation.KeyID
|
||||
}
|
||||
|
||||
return token.SignedString(key)
|
||||
}
|
||||
|
||||
func (j *JWTService) parseToken(tokenString string) (*TokenClaims, error) {
|
||||
if TrimString(tokenString) == "" {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
parser := jwt.NewParser()
|
||||
unverified, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
headerKid, _ := unverified.Header["kid"].(string)
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
if headerKid != "" && headerKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
|
||||
return nil, ErrInvalidKeyID
|
||||
}
|
||||
} else if headerKid != "" {
|
||||
return nil, ErrInvalidKeyID
|
||||
}
|
||||
|
||||
keys := j.verificationKeys()
|
||||
if len(keys) == 0 {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, candidate := range keys {
|
||||
claims := &TokenClaims{}
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return candidate.key, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
lastErr = ErrInvalidToken
|
||||
continue
|
||||
}
|
||||
|
||||
if claims.ExpiresAt == nil || time.Until(claims.ExpiresAt.Time) < 0 {
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
lastErr = ErrInvalidToken
|
||||
continue
|
||||
}
|
||||
|
||||
if err := j.validateTokenMetadata(token, claims, headerKid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
func (j *JWTService) verificationKeys() []verificationKey {
|
||||
if j.config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
keys := []verificationKey{{key: []byte(j.config.KeyRotation.CurrentKey)}}
|
||||
if j.config.KeyRotation.PreviousKey != "" {
|
||||
keys = append(keys, verificationKey{key: []byte(j.config.KeyRotation.PreviousKey)})
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
return []verificationKey{{key: []byte(j.config.Secret)}}
|
||||
}
|
||||
|
||||
func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims, headerKid string) error {
|
||||
actualKid, _ := token.Header["kid"].(string)
|
||||
if actualKid == "" {
|
||||
actualKid = headerKid
|
||||
}
|
||||
|
||||
if j.config.KeyRotation.Enabled {
|
||||
if actualKid == "" {
|
||||
if claims.KeyID != "" {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
} else {
|
||||
if claims.KeyID == "" || claims.KeyID != actualKid {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
}
|
||||
|
||||
if actualKid != "" && actualKid != j.config.KeyRotation.KeyID && j.config.KeyRotation.PreviousKey == "" {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
} else {
|
||||
if actualKid != "" || claims.KeyID != "" {
|
||||
return ErrInvalidKeyID
|
||||
}
|
||||
}
|
||||
|
||||
if claims.Issuer != j.config.Issuer {
|
||||
return ErrInvalidIssuer
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, j.config.Audience) {
|
||||
return ErrInvalidAudience
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JWTService) hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
966
internal/services/jwt_service_test.go
Normal file
966
internal/services/jwt_service_test.go
Normal file
@@ -0,0 +1,966 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
)
|
||||
|
||||
type jwtMockUserRepo struct {
|
||||
users map[uint]*database.User
|
||||
}
|
||||
|
||||
func (m *jwtMockUserRepo) GetByID(id uint) (*database.User, error) {
|
||||
if user, exists := m.users[id]; exists {
|
||||
return user, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *jwtMockUserRepo) GetByUsername(username string) (*database.User, error) {
|
||||
for _, user := range m.users {
|
||||
if user.Username == username {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *jwtMockUserRepo) Update(user *database.User) error {
|
||||
if _, exists := m.users[user.ID]; !exists {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
m.users[user.ID] = user
|
||||
return nil
|
||||
}
|
||||
|
||||
type jwtMockRefreshTokenRepo struct {
|
||||
tokens map[string]*database.RefreshToken
|
||||
nextID uint
|
||||
}
|
||||
|
||||
func (m *jwtMockRefreshTokenRepo) Create(token *database.RefreshToken) error {
|
||||
if m.tokens == nil {
|
||||
m.tokens = make(map[string]*database.RefreshToken)
|
||||
}
|
||||
m.nextID++
|
||||
token.ID = m.nextID
|
||||
m.tokens[token.TokenHash] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *jwtMockRefreshTokenRepo) GetByTokenHash(tokenHash string) (*database.RefreshToken, error) {
|
||||
if token, exists := m.tokens[tokenHash]; exists {
|
||||
return token, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *jwtMockRefreshTokenRepo) DeleteByUserID(userID uint) error {
|
||||
for hash, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
delete(m.tokens, hash)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *jwtMockRefreshTokenRepo) DeleteExpired() error {
|
||||
now := time.Now()
|
||||
for hash, token := range m.tokens {
|
||||
if token.ExpiresAt.Before(now) {
|
||||
delete(m.tokens, hash)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *jwtMockRefreshTokenRepo) DeleteByID(id uint) error {
|
||||
for hash, token := range m.tokens {
|
||||
if token.ID == id {
|
||||
delete(m.tokens, hash)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *jwtMockRefreshTokenRepo) GetByUserID(userID uint) ([]database.RefreshToken, error) {
|
||||
var tokens []database.RefreshToken
|
||||
for _, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
tokens = append(tokens, *token)
|
||||
}
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (m *jwtMockRefreshTokenRepo) CountByUserID(userID uint) (int64, error) {
|
||||
var count int64
|
||||
for _, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func createTestJWTService() (*JWTService, *jwtMockUserRepo, *jwtMockRefreshTokenRepo) {
|
||||
cfg := &config.JWTConfig{
|
||||
Secret: "test-secret-key-that-is-long-enough-for-security",
|
||||
Expiration: 1,
|
||||
RefreshExpiration: 24,
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeyRotation: config.KeyRotationConfig{
|
||||
Enabled: false,
|
||||
CurrentKey: "",
|
||||
PreviousKey: "",
|
||||
KeyID: "",
|
||||
},
|
||||
}
|
||||
|
||||
userRepo := &jwtMockUserRepo{
|
||||
users: make(map[uint]*database.User),
|
||||
}
|
||||
refreshRepo := &jwtMockRefreshTokenRepo{
|
||||
tokens: make(map[string]*database.RefreshToken),
|
||||
}
|
||||
|
||||
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
|
||||
return jwtService, userRepo, refreshRepo
|
||||
}
|
||||
|
||||
func createTestUser() *database.User {
|
||||
return &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
Locked: false,
|
||||
SessionVersion: 1,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTService_GenerateAccessToken(t *testing.T) {
|
||||
jwtService, userRepo, _ := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Successful_Generation", func(t *testing.T) {
|
||||
token, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful token generation, got error: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Expected non-empty token")
|
||||
}
|
||||
|
||||
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
||||
return []byte(jwtService.config.Secret), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse generated token: %v", err)
|
||||
}
|
||||
|
||||
if !parsedToken.Valid {
|
||||
t.Error("Generated token should be valid")
|
||||
}
|
||||
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
|
||||
if claims["sub"] != float64(user.ID) {
|
||||
t.Errorf("Expected subject %d, got %v", user.ID, claims["sub"])
|
||||
}
|
||||
if claims["username"] != user.Username {
|
||||
t.Errorf("Expected username %s, got %v", user.Username, claims["username"])
|
||||
}
|
||||
if claims["session_version"] != float64(user.SessionVersion) {
|
||||
t.Errorf("Expected session_version %d, got %v", user.SessionVersion, claims["session_version"])
|
||||
}
|
||||
if claims["type"] != TokenTypeAccess {
|
||||
t.Errorf("Expected type %s, got %v", TokenTypeAccess, claims["type"])
|
||||
}
|
||||
if claims["iss"] != jwtService.config.Issuer {
|
||||
t.Errorf("Expected issuer %s, got %v", jwtService.config.Issuer, claims["iss"])
|
||||
}
|
||||
if aud, ok := claims["aud"].([]any); !ok || len(aud) != 1 || aud[0] != jwtService.config.Audience {
|
||||
t.Errorf("Expected audience [%s], got %v", jwtService.config.Audience, claims["aud"])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Nil_User", func(t *testing.T) {
|
||||
_, err := jwtService.GenerateAccessToken(nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil user")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Errorf("Expected ErrInvalidCredentials, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_GenerateRefreshToken(t *testing.T) {
|
||||
jwtService, userRepo, refreshRepo := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Successful_Generation", func(t *testing.T) {
|
||||
token, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful refresh token generation, got error: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Expected non-empty refresh token")
|
||||
}
|
||||
|
||||
tokenHash := jwtService.hashToken(token)
|
||||
storedToken, err := refreshRepo.GetByTokenHash(tokenHash)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected refresh token to be stored in database: %v", err)
|
||||
}
|
||||
|
||||
if storedToken.UserID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, storedToken.UserID)
|
||||
}
|
||||
|
||||
expectedExpiry := time.Now().Add(time.Duration(jwtService.config.RefreshExpiration) * time.Hour)
|
||||
if storedToken.ExpiresAt.Before(expectedExpiry.Add(-time.Minute)) || storedToken.ExpiresAt.After(expectedExpiry.Add(time.Minute)) {
|
||||
t.Errorf("Expected expiry around %v, got %v", expectedExpiry, storedToken.ExpiresAt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Nil_User", func(t *testing.T) {
|
||||
_, err := jwtService.GenerateRefreshToken(nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil user")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Errorf("Expected ErrInvalidCredentials, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_VerifyAccessToken(t *testing.T) {
|
||||
jwtService, userRepo, _ := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Valid_Token", func(t *testing.T) {
|
||||
token, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
userID, err := jwtService.VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful token verification, got error: %v", err)
|
||||
}
|
||||
|
||||
if userID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid_Token", func(t *testing.T) {
|
||||
_, err := jwtService.VerifyAccessToken("invalid-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty_Token", func(t *testing.T) {
|
||||
_, err := jwtService.VerifyAccessToken("")
|
||||
if err == nil {
|
||||
t.Error("Expected error for empty token")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("Expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("User_Not_Found", func(t *testing.T) {
|
||||
token, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
delete(userRepo.users, user.ID)
|
||||
|
||||
_, err = jwtService.VerifyAccessToken(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent user")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("Expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Locked_User", func(t *testing.T) {
|
||||
user.Locked = true
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
token, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.VerifyAccessToken(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for locked user")
|
||||
}
|
||||
if !errors.Is(err, ErrAccountLocked) {
|
||||
t.Errorf("Expected ErrAccountLocked, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Session_Version_Mismatch", func(t *testing.T) {
|
||||
user.Locked = false
|
||||
user.SessionVersion = 2
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
oldUser := *user
|
||||
oldUser.SessionVersion = 1
|
||||
token, err := jwtService.GenerateAccessToken(&oldUser)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.VerifyAccessToken(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for session version mismatch")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("Expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_RefreshAccessToken(t *testing.T) {
|
||||
jwtService, userRepo, refreshRepo := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Successful_Refresh", func(t *testing.T) {
|
||||
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
accessToken, err := jwtService.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful token refresh, got error: %v", err)
|
||||
}
|
||||
|
||||
if accessToken == "" {
|
||||
t.Error("Expected non-empty access token")
|
||||
}
|
||||
|
||||
userID, err := jwtService.VerifyAccessToken(accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected valid access token, got error: %v", err)
|
||||
}
|
||||
|
||||
if userID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid_Refresh_Token", func(t *testing.T) {
|
||||
_, err := jwtService.RefreshAccessToken("invalid-refresh-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid refresh token")
|
||||
}
|
||||
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
||||
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Expired_Refresh_Token", func(t *testing.T) {
|
||||
|
||||
refreshToken := &database.RefreshToken{
|
||||
UserID: user.ID,
|
||||
TokenHash: "expired-token-hash",
|
||||
ExpiresAt: time.Now().Add(-time.Hour),
|
||||
}
|
||||
refreshRepo.tokens["expired-token-hash"] = refreshToken
|
||||
|
||||
testToken := "test-expired-token"
|
||||
tokenHash := jwtService.hashToken(testToken)
|
||||
refreshToken.TokenHash = tokenHash
|
||||
refreshRepo.tokens[tokenHash] = refreshToken
|
||||
|
||||
_, err := jwtService.RefreshAccessToken(testToken)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired refresh token")
|
||||
}
|
||||
if !errors.Is(err, ErrRefreshTokenExpired) {
|
||||
t.Errorf("Expected ErrRefreshTokenExpired, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("User_Not_Found", func(t *testing.T) {
|
||||
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
delete(userRepo.users, user.ID)
|
||||
|
||||
_, err = jwtService.RefreshAccessToken(refreshToken)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent user")
|
||||
}
|
||||
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
||||
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Locked_User", func(t *testing.T) {
|
||||
user.Locked = true
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.RefreshAccessToken(refreshToken)
|
||||
if err == nil {
|
||||
t.Error("Expected error for locked user")
|
||||
}
|
||||
if !errors.Is(err, ErrAccountLocked) {
|
||||
t.Errorf("Expected ErrAccountLocked, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_RevokeRefreshToken(t *testing.T) {
|
||||
jwtService, userRepo, refreshRepo := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Successful_Revocation", func(t *testing.T) {
|
||||
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
tokenHash := jwtService.hashToken(refreshToken)
|
||||
_, err = refreshRepo.GetByTokenHash(tokenHash)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected refresh token to exist: %v", err)
|
||||
}
|
||||
|
||||
err = jwtService.RevokeRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful token revocation, got error: %v", err)
|
||||
}
|
||||
|
||||
_, err = refreshRepo.GetByTokenHash(tokenHash)
|
||||
if err == nil {
|
||||
t.Error("Expected refresh token to be removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non_Existent_Token", func(t *testing.T) {
|
||||
err := jwtService.RevokeRefreshToken("non-existent-token")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for non-existent token, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_RevokeAllRefreshTokens(t *testing.T) {
|
||||
jwtService, userRepo, refreshRepo := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Successful_Revocation", func(t *testing.T) {
|
||||
|
||||
_, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate first refresh token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate second refresh token: %v", err)
|
||||
}
|
||||
|
||||
count, err := refreshRepo.CountByUserID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to count tokens: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tokens, got %d", count)
|
||||
}
|
||||
|
||||
err = jwtService.RevokeAllRefreshTokens(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful token revocation, got error: %v", err)
|
||||
}
|
||||
|
||||
count, err = refreshRepo.CountByUserID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to count tokens: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Errorf("Expected 0 tokens, got %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_CleanupExpiredTokens(t *testing.T) {
|
||||
jwtService, userRepo, refreshRepo := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Successful_Cleanup", func(t *testing.T) {
|
||||
|
||||
expiredToken := &database.RefreshToken{
|
||||
UserID: user.ID,
|
||||
TokenHash: "expired-token-hash",
|
||||
ExpiresAt: time.Now().Add(-time.Hour),
|
||||
}
|
||||
refreshRepo.tokens["expired-token-hash"] = expiredToken
|
||||
|
||||
validToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate valid refresh token: %v", err)
|
||||
}
|
||||
|
||||
if len(refreshRepo.tokens) != 2 {
|
||||
t.Errorf("Expected 2 tokens, got %d", len(refreshRepo.tokens))
|
||||
}
|
||||
|
||||
err = jwtService.CleanupExpiredTokens()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful cleanup, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(refreshRepo.tokens) != 1 {
|
||||
t.Errorf("Expected 1 token after cleanup, got %d", len(refreshRepo.tokens))
|
||||
}
|
||||
|
||||
tokenHash := jwtService.hashToken(validToken)
|
||||
_, exists := refreshRepo.tokens[tokenHash]
|
||||
if !exists {
|
||||
t.Error("Expected valid token to remain after cleanup")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_KeyRotation(t *testing.T) {
|
||||
cfg := &config.JWTConfig{
|
||||
Secret: "old-secret-key-that-is-long-enough-for-security",
|
||||
Expiration: 1,
|
||||
RefreshExpiration: 24,
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeyRotation: config.KeyRotationConfig{
|
||||
Enabled: true,
|
||||
CurrentKey: "current-key-that-is-long-enough-for-security",
|
||||
PreviousKey: "previous-key-that-is-long-enough-for-security",
|
||||
KeyID: "current-key-id",
|
||||
},
|
||||
}
|
||||
|
||||
userRepo := &jwtMockUserRepo{
|
||||
users: make(map[uint]*database.User),
|
||||
}
|
||||
refreshRepo := &jwtMockRefreshTokenRepo{
|
||||
tokens: make(map[string]*database.RefreshToken),
|
||||
}
|
||||
|
||||
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Generate_Token_With_Key_Rotation", func(t *testing.T) {
|
||||
token, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful token generation with key rotation, got error: %v", err)
|
||||
}
|
||||
|
||||
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
||||
return []byte(cfg.KeyRotation.CurrentKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse token with key rotation: %v", err)
|
||||
}
|
||||
|
||||
if !parsedToken.Valid {
|
||||
t.Error("Generated token should be valid")
|
||||
}
|
||||
|
||||
if kid, ok := parsedToken.Header["kid"].(string); !ok || kid != cfg.KeyRotation.KeyID {
|
||||
t.Errorf("Expected key ID %s, got %v", cfg.KeyRotation.KeyID, parsedToken.Header["kid"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Verify_Token_With_Current_Key", func(t *testing.T) {
|
||||
token, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
userID, err := jwtService.VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful token verification with current key, got error: %v", err)
|
||||
}
|
||||
|
||||
if userID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Legacy_Token_Without_KID_Remains_Valid", func(t *testing.T) {
|
||||
legacyCfg := &config.JWTConfig{
|
||||
Secret: "legacy-secret-key-that-is-long-enough-for-security",
|
||||
Expiration: 1,
|
||||
RefreshExpiration: 24,
|
||||
Issuer: "legacy-issuer",
|
||||
Audience: "legacy-audience",
|
||||
KeyRotation: config.KeyRotationConfig{Enabled: false},
|
||||
}
|
||||
|
||||
legacyUserRepo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
|
||||
legacyRefreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
|
||||
legacyService := NewJWTService(legacyCfg, legacyUserRepo, legacyRefreshRepo)
|
||||
|
||||
legacyToken, err := legacyService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate legacy token: %v", err)
|
||||
}
|
||||
|
||||
legacyCfg.KeyRotation.Enabled = true
|
||||
legacyCfg.KeyRotation.CurrentKey = "rotated-current-key-that-is-long-enough-for-security"
|
||||
legacyCfg.KeyRotation.PreviousKey = legacyCfg.Secret
|
||||
legacyCfg.KeyRotation.KeyID = "rotated-key-id"
|
||||
|
||||
parsedUserID, err := legacyService.VerifyAccessToken(legacyToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Legacy token should remain valid after enabling rotation: %v", err)
|
||||
}
|
||||
if parsedUserID != user.ID {
|
||||
t.Fatalf("Expected user ID %d, got %d", user.ID, parsedUserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Legacy_Token_With_Previous_KID_Remains_Valid", func(t *testing.T) {
|
||||
rotCfg := &config.JWTConfig{
|
||||
Secret: "unused-secret-key-that-is-long-enough-for-security",
|
||||
Expiration: 1,
|
||||
RefreshExpiration: 24,
|
||||
Issuer: "rotation-issuer",
|
||||
Audience: "rotation-audience",
|
||||
KeyRotation: config.KeyRotationConfig{
|
||||
Enabled: true,
|
||||
CurrentKey: "rotation-key-v1-that-is-long-enough-for-security",
|
||||
KeyID: "key-id-v1",
|
||||
},
|
||||
}
|
||||
|
||||
rotUserRepo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
|
||||
rotRefreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
|
||||
rotService := NewJWTService(rotCfg, rotUserRepo, rotRefreshRepo)
|
||||
|
||||
tokenV1, err := rotService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate v1 token: %v", err)
|
||||
}
|
||||
|
||||
rotCfg.KeyRotation.PreviousKey = rotCfg.KeyRotation.CurrentKey
|
||||
rotCfg.KeyRotation.CurrentKey = "rotation-key-v2-that-is-long-enough-for-security"
|
||||
rotCfg.KeyRotation.KeyID = "key-id-v2"
|
||||
|
||||
parsedUserID, err := rotService.VerifyAccessToken(tokenV1)
|
||||
if err != nil {
|
||||
t.Fatalf("Token signed with previous key should remain valid: %v", err)
|
||||
}
|
||||
if parsedUserID != user.ID {
|
||||
t.Fatalf("Expected user ID %d, got %d", user.ID, parsedUserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unknown_KID_Is_Rejected", func(t *testing.T) {
|
||||
cfg := &config.JWTConfig{
|
||||
Secret: "unused-secret-key-that-is-long-enough-for-security",
|
||||
Expiration: 1,
|
||||
RefreshExpiration: 24,
|
||||
Issuer: "issuer",
|
||||
Audience: "audience",
|
||||
KeyRotation: config.KeyRotationConfig{
|
||||
Enabled: true,
|
||||
CurrentKey: "current-key-for-unknown-kid-test-that-is-long-enough",
|
||||
KeyID: "expected-key-id",
|
||||
},
|
||||
}
|
||||
|
||||
repo := &jwtMockUserRepo{users: map[uint]*database.User{user.ID: user}}
|
||||
refreshRepo := &jwtMockRefreshTokenRepo{tokens: make(map[string]*database.RefreshToken)}
|
||||
service := NewJWTService(cfg, repo, refreshRepo)
|
||||
|
||||
claims := TokenClaims{
|
||||
UserID: user.ID,
|
||||
Username: user.Username,
|
||||
SessionVersion: user.SessionVersion,
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: cfg.Issuer,
|
||||
Audience: []string{cfg.Audience},
|
||||
Subject: fmt.Sprint(user.ID),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["kid"] = "unexpected-key-id"
|
||||
tokenString, err := token.SignedString([]byte(cfg.KeyRotation.CurrentKey))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.VerifyAccessToken(tokenString)
|
||||
if !errors.Is(err, ErrInvalidKeyID) {
|
||||
t.Fatalf("Expected ErrInvalidKeyID, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_ErrorHandling(t *testing.T) {
|
||||
jwtService, _, _ := createTestJWTService()
|
||||
|
||||
t.Run("Invalid_Issuer", func(t *testing.T) {
|
||||
|
||||
claims := TokenClaims{
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
SessionVersion: 1,
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "wrong-issuer",
|
||||
Audience: []string{jwtService.config.Audience},
|
||||
Subject: "1",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.VerifyAccessToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid issuer")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidIssuer) {
|
||||
t.Errorf("Expected ErrInvalidIssuer, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid_Audience", func(t *testing.T) {
|
||||
|
||||
claims := TokenClaims{
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
SessionVersion: 1,
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: jwtService.config.Issuer,
|
||||
Audience: []string{"wrong-audience"},
|
||||
Subject: "1",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.VerifyAccessToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid audience")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidAudience) {
|
||||
t.Errorf("Expected ErrInvalidAudience, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Expired_Token", func(t *testing.T) {
|
||||
|
||||
claims := TokenClaims{
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
SessionVersion: 1,
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: jwtService.config.Issuer,
|
||||
Audience: []string{jwtService.config.Audience},
|
||||
Subject: "1",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-2 * time.Hour)),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.VerifyAccessToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired token")
|
||||
}
|
||||
if !errors.Is(err, ErrTokenExpired) {
|
||||
t.Errorf("Expected ErrTokenExpired, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Subject_Mismatch", func(t *testing.T) {
|
||||
claims := TokenClaims{
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
SessionVersion: 1,
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: jwtService.config.Issuer,
|
||||
Audience: []string{jwtService.config.Audience},
|
||||
Subject: "999",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(jwtService.config.Secret))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.VerifyAccessToken(tokenString)
|
||||
if !errors.Is(err, ErrInvalidToken) {
|
||||
t.Fatalf("Expected ErrInvalidToken for subject mismatch, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_HelperFunctions(t *testing.T) {
|
||||
jwtService, _, _ := createTestJWTService()
|
||||
|
||||
t.Run("HashToken", func(t *testing.T) {
|
||||
token := "test-token"
|
||||
hash1 := jwtService.hashToken(token)
|
||||
hash2 := jwtService.hashToken(token)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("Hash should be deterministic")
|
||||
}
|
||||
|
||||
if hash1 == token {
|
||||
t.Error("Hash should be different from original token")
|
||||
}
|
||||
|
||||
hash3 := jwtService.hashToken("different-token")
|
||||
if hash1 == hash3 {
|
||||
t.Error("Different tokens should produce different hashes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Contains", func(t *testing.T) {
|
||||
slice := []string{"item1", "item2", "item3"}
|
||||
|
||||
if !slices.Contains(slice, "item1") {
|
||||
t.Error("Should contain item1")
|
||||
}
|
||||
|
||||
if !slices.Contains(slice, "item2") {
|
||||
t.Error("Should contain item2")
|
||||
}
|
||||
|
||||
if slices.Contains(slice, "item4") {
|
||||
t.Error("Should not contain item4")
|
||||
}
|
||||
|
||||
if slices.Contains(slice, "") {
|
||||
t.Error("Should not contain empty string")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTService_Integration(t *testing.T) {
|
||||
jwtService, userRepo, _ := createTestJWTService()
|
||||
user := createTestUser()
|
||||
userRepo.users[user.ID] = user
|
||||
|
||||
t.Run("Complete_Flow", func(t *testing.T) {
|
||||
|
||||
accessToken, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate access token: %v", err)
|
||||
}
|
||||
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
userID, err := jwtService.VerifyAccessToken(accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to verify access token: %v", err)
|
||||
}
|
||||
if userID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||
}
|
||||
|
||||
newAccessToken, err := jwtService.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to refresh access token: %v", err)
|
||||
}
|
||||
|
||||
userID, err = jwtService.VerifyAccessToken(newAccessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to verify new access token: %v", err)
|
||||
}
|
||||
if userID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, userID)
|
||||
}
|
||||
|
||||
err = jwtService.RevokeRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to revoke refresh token: %v", err)
|
||||
}
|
||||
|
||||
_, err = jwtService.RefreshAccessToken(refreshToken)
|
||||
if err == nil {
|
||||
t.Error("Expected error when using revoked refresh token")
|
||||
}
|
||||
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
||||
t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
135
internal/services/password_reset_service.go
Normal file
135
internal/services/password_reset_service.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/validation"
|
||||
)
|
||||
|
||||
type PasswordResetService struct {
|
||||
userRepo repositories.UserRepository
|
||||
emailService *EmailService
|
||||
}
|
||||
|
||||
func NewPasswordResetService(userRepo repositories.UserRepository, emailService *EmailService) *PasswordResetService {
|
||||
return &PasswordResetService{
|
||||
userRepo: userRepo,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) RequestPasswordReset(usernameOrEmail string) error {
|
||||
trimmed := TrimString(usernameOrEmail)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("username or email is required")
|
||||
}
|
||||
|
||||
var user *database.User
|
||||
var err error
|
||||
|
||||
normalized, emailErr := normalizeEmail(trimmed)
|
||||
if emailErr == nil {
|
||||
user, err = s.userRepo.GetByEmail(normalized)
|
||||
if err != nil && !IsRecordNotFound(err) {
|
||||
return fmt.Errorf("lookup user by email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
user, err = s.userRepo.GetByUsername(trimmed)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("lookup user by username: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
token, hashed, err := generateVerificationToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(time.Duration(defaultTokenExpirationHours) * time.Hour)
|
||||
|
||||
user.PasswordResetToken = hashed
|
||||
user.PasswordResetSentAt = &now
|
||||
user.PasswordResetExpiresAt = &expiresAt
|
||||
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
return fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
if err := s.emailService.SendPasswordResetEmail(user, token); err != nil {
|
||||
user.PasswordResetToken = ""
|
||||
user.PasswordResetSentAt = nil
|
||||
user.PasswordResetExpiresAt = nil
|
||||
_ = s.userRepo.Update(user)
|
||||
return fmt.Errorf("send password reset email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) GetUserByResetToken(token string) (*database.User, error) {
|
||||
trimmed := TrimString(token)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("reset token is required")
|
||||
}
|
||||
|
||||
hashed := HashVerificationToken(trimmed)
|
||||
user, err := s.userRepo.GetByPasswordResetToken(hashed)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return nil, fmt.Errorf("invalid or expired reset token")
|
||||
}
|
||||
return nil, fmt.Errorf("lookup reset token: %w", err)
|
||||
}
|
||||
|
||||
if user.PasswordResetExpiresAt == nil || time.Now().After(*user.PasswordResetExpiresAt) {
|
||||
return nil, fmt.Errorf("invalid or expired reset token")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) ResetPassword(token, newPassword string) error {
|
||||
if err := validation.ValidatePassword(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := s.GetUserByResetToken(token)
|
||||
if err != nil {
|
||||
hashed := HashVerificationToken(TrimString(token))
|
||||
expiredUser, lookupErr := s.userRepo.GetByPasswordResetToken(hashed)
|
||||
if lookupErr == nil && expiredUser != nil {
|
||||
if expiredUser.PasswordResetExpiresAt == nil || time.Now().After(*expiredUser.PasswordResetExpiresAt) {
|
||||
expiredUser.PasswordResetToken = ""
|
||||
expiredUser.PasswordResetSentAt = nil
|
||||
expiredUser.PasswordResetExpiresAt = nil
|
||||
_ = s.userRepo.Update(expiredUser)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
hashedPassword, err := HashPassword(newPassword, DefaultBcryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = string(hashedPassword)
|
||||
user.PasswordResetToken = ""
|
||||
user.PasswordResetSentAt = nil
|
||||
user.PasswordResetExpiresAt = nil
|
||||
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
return fmt.Errorf("update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
417
internal/services/password_reset_service_test.go
Normal file
417
internal/services/password_reset_service_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func TestNewPasswordResetService(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
|
||||
service := NewPasswordResetService(userRepo, emailService)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("expected service to be created")
|
||||
}
|
||||
|
||||
if service.userRepo != userRepo {
|
||||
t.Error("expected userRepo to be set")
|
||||
}
|
||||
|
||||
if service.emailService != emailService {
|
||||
t.Error("expected emailService to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetService_RequestPasswordReset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
usernameOrEmail string
|
||||
setupMocks func() (*testutils.MockUserRepository, EmailSender)
|
||||
expectedError bool
|
||||
shouldSendEmail bool
|
||||
}{
|
||||
{
|
||||
name: "successful request by username",
|
||||
usernameOrEmail: "testuser",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
return userRepo, emailSender
|
||||
},
|
||||
expectedError: false,
|
||||
shouldSendEmail: true,
|
||||
},
|
||||
{
|
||||
name: "successful request by email",
|
||||
usernameOrEmail: "test@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
return userRepo, emailSender
|
||||
},
|
||||
expectedError: false,
|
||||
shouldSendEmail: true,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
usernameOrEmail: "",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
shouldSendEmail: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace only input",
|
||||
usernameOrEmail: " ",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
shouldSendEmail: false,
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
usernameOrEmail: "nonexistent",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: false,
|
||||
shouldSendEmail: false,
|
||||
},
|
||||
{
|
||||
name: "email service error",
|
||||
usernameOrEmail: "testuser",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
var errorSender errorEmailSender
|
||||
errorSender.err = errors.New("email service error")
|
||||
emailSender := &errorSender
|
||||
|
||||
return userRepo, emailSender
|
||||
},
|
||||
expectedError: true,
|
||||
shouldSendEmail: false,
|
||||
},
|
||||
{
|
||||
name: "prefers email over username",
|
||||
usernameOrEmail: "test@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
|
||||
return userRepo, emailSender
|
||||
},
|
||||
expectedError: false,
|
||||
shouldSendEmail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, emailSender := tt.setupMocks()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
|
||||
service := NewPasswordResetService(userRepo, emailService)
|
||||
|
||||
err := service.RequestPasswordReset(tt.usernameOrEmail)
|
||||
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.shouldSendEmail {
|
||||
user, _ := userRepo.GetByUsername("testuser")
|
||||
if user == nil {
|
||||
user, _ = userRepo.GetByEmail("test@example.com")
|
||||
}
|
||||
if user != nil && user.PasswordResetToken == "" {
|
||||
t.Error("expected password reset token to be set")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetService_ResetPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
newPassword string
|
||||
setupMocks func() (*testutils.MockUserRepository, EmailSender)
|
||||
expectedError bool
|
||||
verifyPassword bool
|
||||
}{
|
||||
{
|
||||
name: "successful password reset",
|
||||
token: "valid-token",
|
||||
newPassword: "NewSecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
PasswordResetToken: HashVerificationToken("valid-token"),
|
||||
PasswordResetExpiresAt: &expiresAt,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
return userRepo, &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: false,
|
||||
verifyPassword: true,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
newPassword: "NewSecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
verifyPassword: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace only token",
|
||||
token: " ",
|
||||
newPassword: "NewSecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
verifyPassword: false,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
token: "invalid-token",
|
||||
newPassword: "NewSecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
return testutils.NewMockUserRepository(), &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
verifyPassword: false,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
token: "expired-token",
|
||||
newPassword: "NewSecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
expiresAt := time.Now().Add(-time.Hour)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
PasswordResetToken: HashVerificationToken("expired-token"),
|
||||
PasswordResetExpiresAt: &expiresAt,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
return userRepo, &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
verifyPassword: false,
|
||||
},
|
||||
{
|
||||
name: "nil expiration date",
|
||||
token: "valid-token",
|
||||
newPassword: "NewSecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
PasswordResetToken: HashVerificationToken("valid-token"),
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
return userRepo, &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
verifyPassword: false,
|
||||
},
|
||||
{
|
||||
name: "invalid password",
|
||||
token: "valid-token",
|
||||
newPassword: "short",
|
||||
setupMocks: func() (*testutils.MockUserRepository, EmailSender) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
PasswordResetToken: HashVerificationToken("valid-token"),
|
||||
PasswordResetExpiresAt: &expiresAt,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
return userRepo, &testutils.MockEmailSender{}
|
||||
},
|
||||
expectedError: true,
|
||||
verifyPassword: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, emailSender := tt.setupMocks()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
|
||||
service := NewPasswordResetService(userRepo, emailService)
|
||||
|
||||
err := service.ResetPassword(tt.token, tt.newPassword)
|
||||
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.verifyPassword {
|
||||
user, _ := userRepo.GetByUsername("testuser")
|
||||
if user == nil {
|
||||
t.Fatal("expected user to exist")
|
||||
}
|
||||
|
||||
if user.PasswordResetToken != "" {
|
||||
t.Error("expected password reset token to be cleared")
|
||||
}
|
||||
|
||||
if user.PasswordResetExpiresAt != nil {
|
||||
t.Error("expected password reset expiration to be cleared")
|
||||
}
|
||||
|
||||
if user.Password == "" {
|
||||
t.Error("expected password to be set")
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(tt.newPassword))
|
||||
if err != nil {
|
||||
t.Errorf("password hash verification failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetService_ResetPassword_TokenClearedAfterExpiration(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
expiresAt := time.Now().Add(-time.Hour)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
PasswordResetToken: HashVerificationToken("expired-token"),
|
||||
PasswordResetExpiresAt: &expiresAt,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
service := NewPasswordResetService(userRepo, emailService)
|
||||
|
||||
err := service.ResetPassword("expired-token", "NewSecurePass123!")
|
||||
if err == nil {
|
||||
t.Error("expected error for expired token")
|
||||
}
|
||||
|
||||
updatedUser, _ := userRepo.GetByID(1)
|
||||
if updatedUser == nil {
|
||||
t.Fatal("expected user to exist")
|
||||
}
|
||||
|
||||
if updatedUser.PasswordResetToken != "" {
|
||||
t.Error("expected password reset token to be cleared after expiration")
|
||||
}
|
||||
|
||||
if updatedUser.PasswordResetExpiresAt != nil {
|
||||
t.Error("expected password reset expiration to be cleared after expiration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetService_RequestPasswordReset_EmailFailureRollback(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
|
||||
emailSender := &errorEmailSender{err: errors.New("email service error")}
|
||||
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
service := NewPasswordResetService(userRepo, emailService)
|
||||
|
||||
err := service.RequestPasswordReset("testuser")
|
||||
if err == nil {
|
||||
t.Error("expected error when email fails")
|
||||
}
|
||||
|
||||
updatedUser, _ := userRepo.GetByID(1)
|
||||
if updatedUser == nil {
|
||||
t.Fatal("expected user to exist")
|
||||
}
|
||||
|
||||
if updatedUser.PasswordResetToken != "" {
|
||||
t.Error("expected password reset token to be rolled back on email failure")
|
||||
}
|
||||
|
||||
if updatedUser.PasswordResetSentAt != nil {
|
||||
t.Error("expected password reset sent at to be rolled back on email failure")
|
||||
}
|
||||
|
||||
if updatedUser.PasswordResetExpiresAt != nil {
|
||||
t.Error("expected password reset expiration to be rolled back on email failure")
|
||||
}
|
||||
}
|
||||
123
internal/services/post_queries.go
Normal file
123
internal/services/post_queries.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
type PostQueries struct {
|
||||
postRepo repositories.PostRepository
|
||||
voteService *VoteService
|
||||
}
|
||||
|
||||
func NewPostQueries(postRepo repositories.PostRepository, voteService *VoteService) *PostQueries {
|
||||
return &PostQueries{
|
||||
postRepo: postRepo,
|
||||
voteService: voteService,
|
||||
}
|
||||
}
|
||||
|
||||
type QueryOptions struct {
|
||||
Limit int
|
||||
Offset int
|
||||
Sort string
|
||||
}
|
||||
|
||||
type VoteContext struct {
|
||||
UserID uint
|
||||
IPAddress string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
func (pq *PostQueries) enrichPostsWithVotes(posts []database.Post, ctx VoteContext) []database.Post {
|
||||
if pq.voteService == nil {
|
||||
return posts
|
||||
}
|
||||
|
||||
enriched := make([]database.Post, len(posts))
|
||||
for i := range posts {
|
||||
enriched[i] = posts[i]
|
||||
vote, err := pq.voteService.GetUserVote(ctx.UserID, posts[i].ID, ctx.IPAddress, ctx.UserAgent)
|
||||
if err == nil && vote != nil {
|
||||
enriched[i].CurrentVote = vote.Type
|
||||
}
|
||||
}
|
||||
|
||||
return enriched
|
||||
}
|
||||
|
||||
func (pq *PostQueries) enrichPostWithVote(post *database.Post, ctx VoteContext) *database.Post {
|
||||
if pq.voteService == nil || post == nil {
|
||||
return post
|
||||
}
|
||||
|
||||
vote, err := pq.voteService.GetUserVote(ctx.UserID, post.ID, ctx.IPAddress, ctx.UserAgent)
|
||||
if err == nil && vote != nil {
|
||||
post.CurrentVote = vote.Type
|
||||
}
|
||||
|
||||
return post
|
||||
}
|
||||
|
||||
func (pq *PostQueries) GetAll(opts QueryOptions, ctx VoteContext) ([]database.Post, error) {
|
||||
posts, err := pq.postRepo.GetAll(opts.Limit, opts.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pq.enrichPostsWithVotes(posts, ctx), nil
|
||||
}
|
||||
|
||||
func (pq *PostQueries) GetTop(limit int, ctx VoteContext) ([]database.Post, error) {
|
||||
posts, err := pq.postRepo.GetTopPosts(limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pq.enrichPostsWithVotes(posts, ctx), nil
|
||||
}
|
||||
|
||||
func (pq *PostQueries) GetNewest(limit int, ctx VoteContext) ([]database.Post, error) {
|
||||
posts, err := pq.postRepo.GetNewestPosts(limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pq.enrichPostsWithVotes(posts, ctx), nil
|
||||
}
|
||||
|
||||
func (pq *PostQueries) GetBySort(sort string, limit int, ctx VoteContext) ([]database.Post, error) {
|
||||
switch sort {
|
||||
case "new", "newest", "latest":
|
||||
return pq.GetNewest(limit, ctx)
|
||||
default:
|
||||
return pq.GetTop(limit, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (pq *PostQueries) GetSearch(query string, opts QueryOptions, ctx VoteContext) ([]database.Post, error) {
|
||||
posts, err := pq.postRepo.Search(query, opts.Limit, opts.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pq.enrichPostsWithVotes(posts, ctx), nil
|
||||
}
|
||||
|
||||
func (pq *PostQueries) GetByID(postID uint, ctx VoteContext) (*database.Post, error) {
|
||||
post, err := pq.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pq.enrichPostWithVote(post, ctx), nil
|
||||
}
|
||||
|
||||
func (pq *PostQueries) GetByUserID(userID uint, opts QueryOptions, ctx VoteContext) ([]database.Post, error) {
|
||||
posts, err := pq.postRepo.GetByUserID(userID, opts.Limit, opts.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pq.enrichPostsWithVotes(posts, ctx), nil
|
||||
}
|
||||
609
internal/services/post_queries_test.go
Normal file
609
internal/services/post_queries_test.go
Normal file
@@ -0,0 +1,609 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/testutils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestNewPostQueries(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
voteService := NewVoteService(testutils.NewMockVoteRepository(), repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
if postQueries == nil {
|
||||
t.Fatal("expected PostQueries to be created")
|
||||
}
|
||||
if postQueries.postRepo != repo {
|
||||
t.Error("expected postRepo to be set")
|
||||
}
|
||||
if postQueries.voteService != voteService {
|
||||
t.Error("expected voteService to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetAll(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRepo func() *testutils.MockPostRepository
|
||||
opts QueryOptions
|
||||
ctx VoteContext
|
||||
expectedCount int
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "success with pagination",
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.Create(&database.Post{ID: 1, Title: "Post 1", Score: 10})
|
||||
repo.Create(&database.Post{ID: 2, Title: "Post 2", Score: 5})
|
||||
repo.Create(&database.Post{ID: 3, Title: "Post 3", Score: 15})
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 2,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedCount: 2,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "success with offset",
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.Create(&database.Post{ID: 1, Title: "Post 1", Score: 10})
|
||||
repo.Create(&database.Post{ID: 2, Title: "Post 2", Score: 5})
|
||||
repo.Create(&database.Post{ID: 3, Title: "Post 3", Score: 15})
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 2,
|
||||
Offset: 1,
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedCount: 2,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "repository error",
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.GetErr = errors.New("database error")
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedCount: 0,
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "empty result",
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
return testutils.NewMockPostRepository()
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedCount: 0,
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := tt.setupRepo()
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
posts, err := postQueries.GetAll(tt.opts, tt.ctx)
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) != tt.expectedCount {
|
||||
t.Errorf("expected %d posts, got %d", tt.expectedCount, len(posts))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetAll_WithVoteEnrichment(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
|
||||
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 5}
|
||||
repo.Create(post1)
|
||||
repo.Create(post2)
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
userID := uint(1)
|
||||
voteRepo.Create(&database.Vote{
|
||||
UserID: &userID,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
})
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
ctx := VoteContext{
|
||||
UserID: 1,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
posts, err := postQueries.GetAll(QueryOptions{Limit: 10, Offset: 0}, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) != 2 {
|
||||
t.Fatalf("expected 2 posts, got %d", len(posts))
|
||||
}
|
||||
if posts[0].CurrentVote != database.VoteUp && posts[0].ID == 1 {
|
||||
if posts[1].ID == 1 && posts[1].CurrentVote != database.VoteUp {
|
||||
t.Error("expected post 1 to have CurrentVote set to VoteUp")
|
||||
}
|
||||
}
|
||||
for _, post := range posts {
|
||||
if post.ID == 2 && post.CurrentVote != "" {
|
||||
t.Errorf("expected post 2 to have no vote, got %s", post.CurrentVote)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetTop(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
|
||||
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 15}
|
||||
post3 := &database.Post{ID: 3, Title: "Post 3", Score: 5}
|
||||
repo.Create(post1)
|
||||
repo.Create(post2)
|
||||
repo.Create(post3)
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
posts, err := postQueries.GetTop(2, VoteContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) != 2 {
|
||||
t.Errorf("expected 2 posts, got %d", len(posts))
|
||||
}
|
||||
if len(posts) == 0 {
|
||||
t.Error("expected at least one post")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetNewest(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
now := time.Now()
|
||||
post1 := &database.Post{ID: 1, Title: "Post 1", CreatedAt: now.Add(-2 * time.Hour)}
|
||||
post2 := &database.Post{ID: 2, Title: "Post 2", CreatedAt: now.Add(-1 * time.Hour)}
|
||||
post3 := &database.Post{ID: 3, Title: "Post 3", CreatedAt: now}
|
||||
repo.Create(post1)
|
||||
repo.Create(post2)
|
||||
repo.Create(post3)
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
posts, err := postQueries.GetNewest(2, VoteContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) != 2 {
|
||||
t.Errorf("expected 2 posts, got %d", len(posts))
|
||||
}
|
||||
if len(posts) == 0 {
|
||||
t.Error("expected at least one post")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetBySort(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
|
||||
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 15}
|
||||
repo.Create(post1)
|
||||
repo.Create(post2)
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
tests := []struct {
|
||||
name string
|
||||
sort string
|
||||
expectTop bool
|
||||
}{
|
||||
{"new sort", "new", false},
|
||||
{"newest sort", "newest", false},
|
||||
{"latest sort", "latest", false},
|
||||
{"default sort", "", true},
|
||||
{"invalid sort", "invalid", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
posts, err := postQueries.GetBySort(tt.sort, 10, VoteContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) == 0 {
|
||||
t.Error("expected at least one post")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetSearch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
setupRepo func() *testutils.MockPostRepository
|
||||
opts QueryOptions
|
||||
ctx VoteContext
|
||||
expectedCount int
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "successful search",
|
||||
query: "test",
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.Create(&database.Post{ID: 1, Title: "Test Post", Score: 10})
|
||||
repo.Create(&database.Post{ID: 2, Title: "Another Post", Score: 5})
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{},
|
||||
expectedCount: 1,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "search with pagination",
|
||||
query: "post",
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.Create(&database.Post{ID: 1, Title: "Post 1", Score: 10})
|
||||
repo.Create(&database.Post{ID: 2, Title: "Post 2", Score: 5})
|
||||
repo.Create(&database.Post{ID: 3, Title: "Post 3", Score: 15})
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 2,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{},
|
||||
expectedCount: 2,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "search error",
|
||||
query: "test",
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.SearchErr = errors.New("search error")
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{},
|
||||
expectedCount: 0,
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := tt.setupRepo()
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
posts, err := postQueries.GetSearch(tt.query, tt.opts, tt.ctx)
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) < tt.expectedCount {
|
||||
t.Errorf("expected at least %d posts, got %d", tt.expectedCount, len(posts))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetByID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
postID uint
|
||||
setupRepo func() *testutils.MockPostRepository
|
||||
ctx VoteContext
|
||||
expectedError bool
|
||||
expectedID uint
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
postID: 1,
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.Create(&database.Post{ID: 1, Title: "Test Post", Score: 10})
|
||||
return repo
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedError: false,
|
||||
expectedID: 1,
|
||||
},
|
||||
{
|
||||
name: "post not found",
|
||||
postID: 999,
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
return testutils.NewMockPostRepository()
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedError: true,
|
||||
expectedID: 0,
|
||||
},
|
||||
{
|
||||
name: "repository error",
|
||||
postID: 1,
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.GetErr = errors.New("database error")
|
||||
return repo
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedError: true,
|
||||
expectedID: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := tt.setupRepo()
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
post, err := postQueries.GetByID(tt.postID, tt.ctx)
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if post == nil {
|
||||
t.Fatal("expected post to be returned")
|
||||
}
|
||||
if post.ID != tt.expectedID {
|
||||
t.Errorf("expected post ID %d, got %d", tt.expectedID, post.ID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetByID_WithVoteEnrichment(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
post := &database.Post{ID: 1, Title: "Test Post", Score: 10}
|
||||
repo.Create(post)
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
userID := uint(1)
|
||||
voteRepo.Create(&database.Vote{
|
||||
UserID: &userID,
|
||||
PostID: 1,
|
||||
Type: database.VoteDown,
|
||||
})
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
ctx := VoteContext{
|
||||
UserID: 1,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
retrievedPost, err := postQueries.GetByID(1, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if retrievedPost.CurrentVote != database.VoteDown {
|
||||
t.Errorf("expected CurrentVote to be VoteDown, got %s", retrievedPost.CurrentVote)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetByUserID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
setupRepo func() *testutils.MockPostRepository
|
||||
opts QueryOptions
|
||||
ctx VoteContext
|
||||
expectedCount int
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
userID: 1,
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
authorID1 := uint(1)
|
||||
authorID2 := uint(2)
|
||||
repo.Create(&database.Post{ID: 1, Title: "User 1 Post", AuthorID: &authorID1, Score: 10})
|
||||
repo.Create(&database.Post{ID: 2, Title: "User 2 Post", AuthorID: &authorID2, Score: 5})
|
||||
repo.Create(&database.Post{ID: 3, Title: "User 1 Post 2", AuthorID: &authorID1, Score: 15})
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedCount: 2,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "with pagination",
|
||||
userID: 1,
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
authorID := uint(1)
|
||||
repo.Create(&database.Post{ID: 1, Title: "Post 1", AuthorID: &authorID, Score: 10})
|
||||
repo.Create(&database.Post{ID: 2, Title: "Post 2", AuthorID: &authorID, Score: 5})
|
||||
repo.Create(&database.Post{ID: 3, Title: "Post 3", AuthorID: &authorID, Score: 15})
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 2,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedCount: 2,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "repository error",
|
||||
userID: 1,
|
||||
setupRepo: func() *testutils.MockPostRepository {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.GetErr = errors.New("database error")
|
||||
return repo
|
||||
},
|
||||
opts: QueryOptions{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
ctx: VoteContext{UserID: 1},
|
||||
expectedCount: 0,
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := tt.setupRepo()
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
posts, err := postQueries.GetByUserID(tt.userID, tt.opts, tt.ctx)
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) < tt.expectedCount {
|
||||
t.Errorf("expected at least %d posts, got %d", tt.expectedCount, len(posts))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_WithoutVoteService(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
repo.Create(&database.Post{ID: 1, Title: "Test Post", Score: 10})
|
||||
postQueries := NewPostQueries(repo, nil)
|
||||
ctx := VoteContext{
|
||||
UserID: 1,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
post, err := postQueries.GetByID(1, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if post == nil {
|
||||
t.Fatal("expected post to be returned")
|
||||
}
|
||||
if post.CurrentVote != "" {
|
||||
t.Errorf("expected CurrentVote to be empty when voteService is nil, got %s", post.CurrentVote)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_WithIPBasedVote(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
post := &database.Post{ID: 1, Title: "Test Post", Score: 10}
|
||||
repo.Create(post)
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
voteHash := voteService.GenerateVoteHash("127.0.0.1", "test-agent", 1)
|
||||
voteRepo.Create(&database.Vote{
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
VoteHash: &voteHash,
|
||||
})
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
ctx := VoteContext{
|
||||
UserID: 0,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
retrievedPost, err := postQueries.GetByID(1, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if retrievedPost.CurrentVote != database.VoteUp {
|
||||
t.Errorf("expected CurrentVote to be VoteUp for IP-based vote, got %s", retrievedPost.CurrentVote)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_EnrichPostsWithVotes_NoVotes(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
post1 := &database.Post{ID: 1, Title: "Post 1", Score: 10}
|
||||
post2 := &database.Post{ID: 2, Title: "Post 2", Score: 5}
|
||||
repo.Create(post1)
|
||||
repo.Create(post2)
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
ctx := VoteContext{
|
||||
UserID: 1,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
posts, err := postQueries.GetAll(QueryOptions{Limit: 10, Offset: 0}, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(posts) != 2 {
|
||||
t.Errorf("expected 2 posts, got %d", len(posts))
|
||||
}
|
||||
for _, post := range posts {
|
||||
if post.CurrentVote != "" {
|
||||
t.Errorf("expected CurrentVote to be empty when no votes exist, got %s", post.CurrentVote)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostQueries_GetByID_NotFound(t *testing.T) {
|
||||
repo := testutils.NewMockPostRepository()
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := NewVoteService(voteRepo, repo, nil)
|
||||
postQueries := NewPostQueries(repo, voteService)
|
||||
post, err := postQueries.GetByID(999, VoteContext{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-existent post")
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Errorf("expected gorm.ErrRecordNotFound, got %v", err)
|
||||
}
|
||||
if post != nil {
|
||||
t.Error("expected nil post when not found")
|
||||
}
|
||||
}
|
||||
178
internal/services/registration_service.go
Normal file
178
internal/services/registration_service.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/validation"
|
||||
)
|
||||
|
||||
type RegistrationService struct {
|
||||
userRepo repositories.UserRepository
|
||||
emailService *EmailService
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewRegistrationService(userRepo repositories.UserRepository, emailService *EmailService, config *config.Config) *RegistrationService {
|
||||
return &RegistrationService{
|
||||
userRepo: userRepo,
|
||||
emailService: emailService,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RegistrationService) Register(username, email, password string) (*RegistrationResult, error) {
|
||||
trimmedUsername := TrimString(username)
|
||||
if err := validation.ValidateUsername(trimmedUsername); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validation.ValidatePassword(password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalizedEmail, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userCheck, err := s.userRepo.GetByUsername(trimmedUsername)
|
||||
if err == nil {
|
||||
if userCheck != nil {
|
||||
return nil, ErrUsernameTaken
|
||||
}
|
||||
} else if !IsRecordNotFound(err) {
|
||||
if handled := HandleUniqueConstraintError(err); handled != err {
|
||||
return nil, handled
|
||||
}
|
||||
return nil, fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
emailCheck, err := s.userRepo.GetByEmail(normalizedEmail)
|
||||
if err == nil {
|
||||
if emailCheck != nil {
|
||||
return nil, ErrEmailTaken
|
||||
}
|
||||
} else if !IsRecordNotFound(err) {
|
||||
if handled := HandleUniqueConstraintError(err); handled != err {
|
||||
return nil, handled
|
||||
}
|
||||
return nil, fmt.Errorf("lookup email: %w", err)
|
||||
}
|
||||
|
||||
hashedPassword, err := HashPassword(password, s.config.App.BcryptCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, hashedToken, err := generateVerificationToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
user := &database.User{
|
||||
Username: trimmedUsername,
|
||||
Email: normalizedEmail,
|
||||
Password: string(hashedPassword),
|
||||
EmailVerified: false,
|
||||
EmailVerificationToken: hashedToken,
|
||||
EmailVerificationSentAt: &now,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
if handled := HandleUniqueConstraintErrorWithMessage(err); handled != err {
|
||||
return nil, handled
|
||||
}
|
||||
return nil, fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
if err := s.emailService.SendVerificationEmail(user, token); err != nil {
|
||||
if deleteErr := s.userRepo.HardDelete(user.ID); deleteErr != nil {
|
||||
return nil, fmt.Errorf("verification email failed and user cleanup failed: email=%w, cleanup=%v", err, deleteErr)
|
||||
}
|
||||
return nil, fmt.Errorf("verification email failed: %w", err)
|
||||
}
|
||||
|
||||
return &RegistrationResult{
|
||||
User: sanitizeUser(user),
|
||||
VerificationSent: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *RegistrationService) ConfirmEmail(token string) (*database.User, error) {
|
||||
trimmed := TrimString(token)
|
||||
if trimmed == "" {
|
||||
return nil, ErrInvalidVerificationToken
|
||||
}
|
||||
|
||||
hashed := HashVerificationToken(trimmed)
|
||||
user, err := s.userRepo.GetByVerificationToken(hashed)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return nil, ErrInvalidVerificationToken
|
||||
}
|
||||
return nil, fmt.Errorf("lookup verification token: %w", err)
|
||||
}
|
||||
|
||||
if user.EmailVerified {
|
||||
return sanitizeUser(user), nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
user.EmailVerified = true
|
||||
user.EmailVerifiedAt = &now
|
||||
user.EmailVerificationToken = ""
|
||||
user.EmailVerificationSentAt = nil
|
||||
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
return nil, fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
return sanitizeUser(user), nil
|
||||
}
|
||||
|
||||
func (s *RegistrationService) ResendVerificationEmail(email string) error {
|
||||
email = TrimString(email)
|
||||
if err := validation.ValidateEmail(email); err != nil {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(email)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
return fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
if user.EmailVerified {
|
||||
return fmt.Errorf("email already verified")
|
||||
}
|
||||
|
||||
if user.EmailVerificationSentAt != nil && time.Since(*user.EmailVerificationSentAt) < 5*time.Minute {
|
||||
return fmt.Errorf("verification email sent recently, please wait before requesting another")
|
||||
}
|
||||
|
||||
token, hash, err := generateVerificationToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
user.EmailVerificationToken = hash
|
||||
user.EmailVerificationSentAt = &now
|
||||
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
return fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
if err := s.emailService.SendResendVerificationEmail(user, token); err != nil {
|
||||
return fmt.Errorf("send verification email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
579
internal/services/registration_service_test.go
Normal file
579
internal/services/registration_service_test.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func TestNewRegistrationService(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
cfg := testutils.AppTestConfig
|
||||
|
||||
service := NewRegistrationService(userRepo, emailService, cfg)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("expected service to be created")
|
||||
}
|
||||
|
||||
if service.userRepo != userRepo {
|
||||
t.Error("expected userRepo to be set")
|
||||
}
|
||||
|
||||
if service.emailService != emailService {
|
||||
t.Error("expected emailService to be set")
|
||||
}
|
||||
|
||||
if service.config != cfg {
|
||||
t.Error("expected config to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationService_Register(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
setupMocks func() (*testutils.MockUserRepository, *EmailService, *config.Config)
|
||||
expectedError error
|
||||
checkResult func(*testing.T, *RegistrationResult)
|
||||
}{
|
||||
{
|
||||
name: "successful registration",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, result *RegistrationResult) {
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.User == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if result.User.Username != "testuser" {
|
||||
t.Errorf("expected username 'testuser', got %q", result.User.Username)
|
||||
}
|
||||
if result.User.Email != "test@example.com" {
|
||||
t.Errorf("expected email 'test@example.com', got %q", result.User.Email)
|
||||
}
|
||||
if result.User.Password != "" {
|
||||
t.Error("expected password to be sanitized")
|
||||
}
|
||||
if !result.VerificationSent {
|
||||
t.Error("expected VerificationSent to be true")
|
||||
}
|
||||
if result.User.EmailVerified {
|
||||
t.Error("expected EmailVerified to be false")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid username",
|
||||
username: "",
|
||||
email: "test@example.com",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid password",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "short",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
username: "testuser",
|
||||
email: "invalid-email",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "username already taken",
|
||||
username: "existinguser",
|
||||
email: "test@example.com",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
existingUser := &database.User{
|
||||
ID: 1,
|
||||
Username: "existinguser",
|
||||
Email: "existing@example.com",
|
||||
}
|
||||
userRepo.Create(existingUser)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: ErrUsernameTaken,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "email already taken",
|
||||
username: "testuser",
|
||||
email: "existing@example.com",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
existingUser := &database.User{
|
||||
ID: 1,
|
||||
Username: "existinguser",
|
||||
Email: "existing@example.com",
|
||||
}
|
||||
userRepo.Create(existingUser)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: ErrEmailTaken,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "email service error",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
errorSender := &errorEmailSender{err: errors.New("email service error")}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, errorSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "trims username whitespace",
|
||||
username: " testuser ",
|
||||
email: "test@example.com",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, result *RegistrationResult) {
|
||||
if result.User.Username != "testuser" {
|
||||
t.Errorf("expected trimmed username 'testuser', got %q", result.User.Username)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "normalizes email",
|
||||
username: "testuser",
|
||||
email: "TEST@EXAMPLE.COM",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, result *RegistrationResult) {
|
||||
if result.User.Email != "test@example.com" {
|
||||
t.Errorf("expected normalized email 'test@example.com', got %q", result.User.Email)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, emailService, cfg := tt.setupMocks()
|
||||
service := NewRegistrationService(userRepo, emailService, cfg)
|
||||
|
||||
result, err := service.Register(tt.username, tt.email, tt.password)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.checkResult == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkResult != nil {
|
||||
tt.checkResult(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationService_ConfirmEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
setupMocks func() (*testutils.MockUserRepository, *EmailService, *config.Config)
|
||||
expectedError error
|
||||
checkResult func(*testing.T, *database.User)
|
||||
}{
|
||||
{
|
||||
name: "successful confirmation",
|
||||
token: "valid-token",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
hashedToken := HashVerificationToken("valid-token")
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: false,
|
||||
EmailVerificationToken: hashedToken,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if !user.EmailVerified {
|
||||
t.Error("expected EmailVerified to be true")
|
||||
}
|
||||
if user.EmailVerificationToken != "" {
|
||||
t.Error("expected EmailVerificationToken to be cleared")
|
||||
}
|
||||
if user.EmailVerificationSentAt != nil {
|
||||
t.Error("expected EmailVerificationSentAt to be nil")
|
||||
}
|
||||
if user.EmailVerifiedAt == nil {
|
||||
t.Error("expected EmailVerifiedAt to be set")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: ErrInvalidVerificationToken,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only token",
|
||||
token: " ",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: ErrInvalidVerificationToken,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
token: "invalid-token",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: ErrInvalidVerificationToken,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "already verified",
|
||||
token: "valid-token",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
hashedToken := HashVerificationToken("valid-token")
|
||||
now := time.Now()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: true,
|
||||
EmailVerifiedAt: &now,
|
||||
EmailVerificationToken: hashedToken,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if !user.EmailVerified {
|
||||
t.Error("expected EmailVerified to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trims token whitespace",
|
||||
token: " valid-token ",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
hashedToken := HashVerificationToken("valid-token")
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: false,
|
||||
EmailVerificationToken: hashedToken,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if !user.EmailVerified {
|
||||
t.Error("expected EmailVerified to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, emailService, cfg := tt.setupMocks()
|
||||
service := NewRegistrationService(userRepo, emailService, cfg)
|
||||
|
||||
user, err := service.ConfirmEmail(tt.token)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.checkResult == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkResult != nil {
|
||||
tt.checkResult(t, user)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationService_ResendVerificationEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
setupMocks func() (*testutils.MockUserRepository, *EmailService, *config.Config)
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "successful resend",
|
||||
email: "test@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
oldTime := time.Now().Add(-10 * time.Minute)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: false,
|
||||
EmailVerificationSentAt: &oldTime,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
email: "invalid-email",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: ErrInvalidEmail,
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
email: "nonexistent@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: ErrInvalidCredentials,
|
||||
},
|
||||
{
|
||||
name: "email already verified",
|
||||
email: "test@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
now := time.Now()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: true,
|
||||
EmailVerifiedAt: &now,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "email sent too recently",
|
||||
email: "test@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
recentTime := time.Now().Add(-2 * time.Minute)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: false,
|
||||
EmailVerificationSentAt: &recentTime,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "email service error",
|
||||
email: "test@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
oldTime := time.Now().Add(-10 * time.Minute)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: false,
|
||||
EmailVerificationSentAt: &oldTime,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
errorSender := &errorEmailSender{err: errors.New("email service error")}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, errorSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "trims email whitespace",
|
||||
email: " test@example.com ",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
oldTime := time.Now().Add(-10 * time.Minute)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: false,
|
||||
EmailVerificationSentAt: &oldTime,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "no previous verification sent",
|
||||
email: "test@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *EmailService, *config.Config) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
EmailVerified: false,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, emailSender)
|
||||
return userRepo, emailService, testutils.AppTestConfig
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, emailService, cfg := tt.setupMocks()
|
||||
service := NewRegistrationService(userRepo, emailService, cfg)
|
||||
|
||||
err := service.ResendVerificationEmail(tt.email)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.name == "email already verified" || tt.name == "email sent too recently" || tt.name == "email service error" {
|
||||
if err.Error() == "" {
|
||||
t.Fatal("expected error message")
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
124
internal/services/session_service.go
Normal file
124
internal/services/session_service.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
type SessionService struct {
|
||||
jwtService *JWTService
|
||||
userRepo repositories.UserRepository
|
||||
}
|
||||
|
||||
func NewSessionService(jwtService *JWTService, userRepo repositories.UserRepository) *SessionService {
|
||||
return &SessionService{
|
||||
jwtService: jwtService,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionService) Login(username, password string) (*AuthResult, error) {
|
||||
trimmedUsername := TrimString(username)
|
||||
if trimmedUsername == "" {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByUsername(trimmedUsername)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
return nil, fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
if !user.EmailVerified {
|
||||
return nil, ErrEmailNotVerified
|
||||
}
|
||||
|
||||
if user.Locked {
|
||||
return nil, ErrAccountLocked
|
||||
}
|
||||
|
||||
if compareErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); compareErr != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return s.issueAuthResult(user)
|
||||
}
|
||||
|
||||
func (s *SessionService) VerifyToken(tokenString string) (uint, error) {
|
||||
return s.jwtService.VerifyAccessToken(tokenString)
|
||||
}
|
||||
|
||||
func (s *SessionService) issueAuthResult(user *database.User) (*AuthResult, error) {
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate access token: %w", err)
|
||||
}
|
||||
|
||||
refreshToken, err := s.jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &AuthResult{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
User: sanitizeUser(user),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
|
||||
accessToken, err := s.jwtService.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID, err := s.jwtService.VerifyAccessToken(accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verify new access token: %w", err)
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
return &AuthResult{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
User: sanitizeUser(user),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SessionService) RevokeRefreshToken(refreshToken string) error {
|
||||
return s.jwtService.RevokeRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
func (s *SessionService) RevokeAllUserTokens(userID uint) error {
|
||||
return s.jwtService.RevokeAllRefreshTokens(userID)
|
||||
}
|
||||
|
||||
func (s *SessionService) InvalidateAllSessions(userID uint) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load user: %w", err)
|
||||
}
|
||||
|
||||
user.SessionVersion++
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
return fmt.Errorf("update session version: %w", err)
|
||||
}
|
||||
|
||||
if err := s.jwtService.RevokeAllRefreshTokens(userID); err != nil {
|
||||
return fmt.Errorf("revoke refresh tokens: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SessionService) CleanupExpiredTokens() error {
|
||||
return s.jwtService.CleanupExpiredTokens()
|
||||
}
|
||||
563
internal/services/session_service_test.go
Normal file
563
internal/services/session_service_test.go
Normal file
@@ -0,0 +1,563 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/testutils"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type sessionMockRefreshTokenRepo struct {
|
||||
tokens map[string]*database.RefreshToken
|
||||
createErr error
|
||||
deleteByUserIDErr error
|
||||
deleteExpiredErr error
|
||||
getByTokenHashErr error
|
||||
}
|
||||
|
||||
func newSessionMockRefreshTokenRepo() *sessionMockRefreshTokenRepo {
|
||||
return &sessionMockRefreshTokenRepo{
|
||||
tokens: make(map[string]*database.RefreshToken),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *sessionMockRefreshTokenRepo) Create(token *database.RefreshToken) error {
|
||||
if m.createErr != nil {
|
||||
return m.createErr
|
||||
}
|
||||
if m.tokens == nil {
|
||||
m.tokens = make(map[string]*database.RefreshToken)
|
||||
}
|
||||
m.tokens[token.TokenHash] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *sessionMockRefreshTokenRepo) GetByTokenHash(tokenHash string) (*database.RefreshToken, error) {
|
||||
if m.getByTokenHashErr != nil {
|
||||
return nil, m.getByTokenHashErr
|
||||
}
|
||||
if token, ok := m.tokens[tokenHash]; ok {
|
||||
return token, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *sessionMockRefreshTokenRepo) DeleteByUserID(userID uint) error {
|
||||
if m.deleteByUserIDErr != nil {
|
||||
return m.deleteByUserIDErr
|
||||
}
|
||||
for hash, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
delete(m.tokens, hash)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *sessionMockRefreshTokenRepo) DeleteExpired() error {
|
||||
if m.deleteExpiredErr != nil {
|
||||
return m.deleteExpiredErr
|
||||
}
|
||||
now := time.Now()
|
||||
for hash, token := range m.tokens {
|
||||
if token.ExpiresAt.Before(now) {
|
||||
delete(m.tokens, hash)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *sessionMockRefreshTokenRepo) DeleteByID(id uint) error {
|
||||
for hash, token := range m.tokens {
|
||||
if token.ID == id {
|
||||
delete(m.tokens, hash)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *sessionMockRefreshTokenRepo) GetByUserID(userID uint) ([]database.RefreshToken, error) {
|
||||
var tokens []database.RefreshToken
|
||||
for _, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
tokens = append(tokens, *token)
|
||||
}
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (m *sessionMockRefreshTokenRepo) CountByUserID(userID uint) (int64, error) {
|
||||
var count int64
|
||||
for _, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func createSessionTestJWTService(userRepo *testutils.MockUserRepository) (*JWTService, *sessionMockRefreshTokenRepo) {
|
||||
cfg := &config.JWTConfig{
|
||||
Secret: "test-secret-key-that-is-long-enough-for-security",
|
||||
Expiration: 1,
|
||||
RefreshExpiration: 24,
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeyRotation: config.KeyRotationConfig{
|
||||
Enabled: false,
|
||||
CurrentKey: "",
|
||||
PreviousKey: "",
|
||||
KeyID: "",
|
||||
},
|
||||
}
|
||||
|
||||
refreshRepo := newSessionMockRefreshTokenRepo()
|
||||
jwtService := NewJWTService(cfg, userRepo, refreshRepo)
|
||||
return jwtService, refreshRepo
|
||||
}
|
||||
|
||||
func createTestUserWithPassword(password string) *database.User {
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: string(hashedPassword),
|
||||
EmailVerified: true,
|
||||
Locked: false,
|
||||
SessionVersion: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSessionService(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("expected service to be created")
|
||||
}
|
||||
|
||||
if service.jwtService != jwtService {
|
||||
t.Error("expected jwtService to be set")
|
||||
}
|
||||
|
||||
if service.userRepo != userRepo {
|
||||
t.Error("expected userRepo to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionService_Login(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
password string
|
||||
setupMocks func() (*JWTService, *testutils.MockUserRepository)
|
||||
expectedError error
|
||||
checkResult func(*testing.T, *AuthResult)
|
||||
}{
|
||||
{
|
||||
name: "successful login",
|
||||
username: "testuser",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, result *AuthResult) {
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.AccessToken == "" {
|
||||
t.Error("expected non-empty access token")
|
||||
}
|
||||
if result.RefreshToken == "" {
|
||||
t.Error("expected non-empty refresh token")
|
||||
}
|
||||
if result.User == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if result.User.Username != "testuser" {
|
||||
t.Errorf("expected username 'testuser', got %q", result.User.Username)
|
||||
}
|
||||
if result.User.Password != "" {
|
||||
t.Error("expected password to be sanitized")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty username",
|
||||
username: "",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: ErrInvalidCredentials,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only username",
|
||||
username: " ",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: ErrInvalidCredentials,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
username: "nonexistent",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: ErrInvalidCredentials,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "email not verified",
|
||||
username: "testuser",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
user.EmailVerified = false
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: ErrEmailNotVerified,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "account locked",
|
||||
username: "testuser",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
user.Locked = true
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: ErrAccountLocked,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid password",
|
||||
username: "testuser",
|
||||
password: "WrongPassword",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: ErrInvalidCredentials,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "trims username whitespace",
|
||||
username: " testuser ",
|
||||
password: "SecurePass123!",
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, result *AuthResult) {
|
||||
if result.User.Username != "testuser" {
|
||||
t.Errorf("expected trimmed username 'testuser', got %q", result.User.Username)
|
||||
}
|
||||
if result.AccessToken == "" {
|
||||
t.Error("expected non-empty access token")
|
||||
}
|
||||
if result.RefreshToken == "" {
|
||||
t.Error("expected non-empty refresh token")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jwtService, userRepo := tt.setupMocks()
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
|
||||
result, err := service.Login(tt.username, tt.password)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.checkResult == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkResult != nil {
|
||||
tt.checkResult(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionService_VerifyToken(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
|
||||
t.Run("successful verification", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
token, err := jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
userID, err := service.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if userID != user.ID {
|
||||
t.Errorf("expected user ID %d, got %d", user.ID, userID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid token", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
_, err := service.VerifyToken("invalid-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty token", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
_, err := service.VerifyToken("")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty token")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionService_RefreshAccessToken(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
|
||||
t.Run("successful refresh", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
result, err := service.RefreshAccessToken(refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.AccessToken == "" {
|
||||
t.Error("expected non-empty access token")
|
||||
}
|
||||
if result.RefreshToken != refreshToken {
|
||||
t.Errorf("expected refresh token to remain unchanged")
|
||||
}
|
||||
if result.User == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid refresh token", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
_, err := service.RefreshAccessToken("invalid-refresh-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid refresh token")
|
||||
}
|
||||
if !errors.Is(err, ErrRefreshTokenInvalid) {
|
||||
t.Errorf("expected ErrRefreshTokenInvalid, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionService_RevokeRefreshToken(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
|
||||
t.Run("successful revocation", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
err = service.RevokeRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.RefreshAccessToken(refreshToken)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when using revoked refresh token")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionService_RevokeAllUserTokens(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := createTestUserWithPassword("SecurePass123!")
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
|
||||
t.Run("successful revocation", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
refreshToken, err := jwtService.GenerateRefreshToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate refresh token: %v", err)
|
||||
}
|
||||
|
||||
err = service.RevokeAllUserTokens(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.RefreshAccessToken(refreshToken)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when using revoked refresh token")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionService_InvalidateAllSessions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
setupMocks func() (*JWTService, *testutils.MockUserRepository)
|
||||
expectedError error
|
||||
checkResult func(*testing.T, *testutils.MockUserRepository)
|
||||
}{
|
||||
{
|
||||
name: "successful invalidation",
|
||||
userID: 1,
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
SessionVersion: 1,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, userRepo *testutils.MockUserRepository) {
|
||||
user, err := userRepo.GetByID(1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get user: %v", err)
|
||||
}
|
||||
if user.SessionVersion != 2 {
|
||||
t.Errorf("expected SessionVersion to be 2, got %d", user.SessionVersion)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
userID: 999,
|
||||
setupMocks: func() (*JWTService, *testutils.MockUserRepository) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
return jwtService, userRepo
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jwtService, userRepo := tt.setupMocks()
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
|
||||
err := service.InvalidateAllSessions(tt.userID)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.name == "user not found" {
|
||||
if err.Error() == "" {
|
||||
t.Fatal("expected error message")
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkResult != nil {
|
||||
tt.checkResult(t, userRepo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionService_CleanupExpiredTokens(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
jwtService, _ := createSessionTestJWTService(userRepo)
|
||||
|
||||
t.Run("successful cleanup", func(t *testing.T) {
|
||||
service := NewSessionService(jwtService, userRepo)
|
||||
err := service.CleanupExpiredTokens()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
598
internal/services/url_metadata_service.go
Normal file
598
internal/services/url_metadata_service.go
Normal file
@@ -0,0 +1,598 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedScheme = errors.New("unsupported URL scheme")
|
||||
ErrTitleNotFound = errors.New("page title not found")
|
||||
ErrSSRFBlocked = errors.New("request blocked for security reasons")
|
||||
ErrTooManyRedirects = errors.New("too many redirects")
|
||||
)
|
||||
|
||||
const (
|
||||
maxTitleBodyBytes = 512 * 1024
|
||||
defaultUserAgent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
maxRedirects = 3
|
||||
requestTimeout = 10 * time.Second
|
||||
dialTimeout = 5 * time.Second
|
||||
tlsHandshakeTimeout = 5 * time.Second
|
||||
responseHeaderTimeout = 5 * time.Second
|
||||
maxContentLength = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
type TitleFetcher interface {
|
||||
FetchTitle(ctx context.Context, rawURL string) (string, error)
|
||||
}
|
||||
|
||||
type DNSResolver interface {
|
||||
LookupIP(hostname string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
type DefaultDNSResolver struct{}
|
||||
|
||||
func (d DefaultDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
|
||||
return net.LookupIP(hostname)
|
||||
}
|
||||
|
||||
type DNSCache struct {
|
||||
mu sync.RWMutex
|
||||
data map[string][]net.IP
|
||||
}
|
||||
|
||||
func NewDNSCache() *DNSCache {
|
||||
return &DNSCache{
|
||||
data: make(map[string][]net.IP),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DNSCache) Get(hostname string) ([]net.IP, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
ips, exists := c.data[hostname]
|
||||
return ips, exists
|
||||
}
|
||||
|
||||
func (c *DNSCache) Set(hostname string, ips []net.IP) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.data[hostname] = ips
|
||||
}
|
||||
|
||||
type CachedDNSResolver struct {
|
||||
resolver DNSResolver
|
||||
cache *DNSCache
|
||||
}
|
||||
|
||||
func NewCachedDNSResolver(resolver DNSResolver) *CachedDNSResolver {
|
||||
return &CachedDNSResolver{
|
||||
resolver: resolver,
|
||||
cache: NewDNSCache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CachedDNSResolver) LookupIP(hostname string) ([]net.IP, error) {
|
||||
if ips, exists := c.cache.Get(hostname); exists {
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
ips, err := c.resolver.LookupIP(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.cache.Set(hostname, ips)
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
type CustomDialer struct {
|
||||
cache *DNSCache
|
||||
fallback *net.Dialer
|
||||
}
|
||||
|
||||
func NewCustomDialer(cache *DNSCache) *CustomDialer {
|
||||
return &CustomDialer{
|
||||
cache: cache,
|
||||
fallback: &net.Dialer{
|
||||
Timeout: dialTimeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *CustomDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return d.fallback.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
if ips, exists := d.cache.Get(host); exists {
|
||||
for _, ip := range ips {
|
||||
ipAddr := net.JoinHostPort(ip.String(), port)
|
||||
if conn, err := d.fallback.DialContext(ctx, network, ipAddr); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return d.fallback.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
type URLMetadataService struct {
|
||||
client *http.Client
|
||||
resolver DNSResolver
|
||||
dnsCache *DNSCache
|
||||
approvedHosts map[string]bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewURLMetadataService() *URLMetadataService {
|
||||
dnsCache := NewDNSCache()
|
||||
cachedResolver := NewCachedDNSResolver(DefaultDNSResolver{})
|
||||
customDialer := NewCustomDialer(dnsCache)
|
||||
|
||||
svc := &URLMetadataService{
|
||||
resolver: cachedResolver,
|
||||
dnsCache: dnsCache,
|
||||
approvedHosts: make(map[string]bool),
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: customDialer.DialContext,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: tlsHandshakeTimeout,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
DisableKeepAlives: false,
|
||||
}
|
||||
|
||||
svc.client = &http.Client{
|
||||
Timeout: requestTimeout,
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return ErrTooManyRedirects
|
||||
}
|
||||
|
||||
hostname := req.URL.Hostname()
|
||||
svc.mu.RLock()
|
||||
approved := svc.approvedHosts[hostname]
|
||||
svc.mu.RUnlock()
|
||||
|
||||
if approved {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := svc.validateURLForSSRF(req.URL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svc.mu.Lock()
|
||||
svc.approvedHosts[hostname] = true
|
||||
svc.mu.Unlock()
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) FetchTitle(ctx context.Context, rawURL string) (string, error) {
|
||||
if rawURL == "" {
|
||||
return "", errors.New("empty URL")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", ErrUnsupportedScheme
|
||||
}
|
||||
|
||||
hostname := parsed.Hostname()
|
||||
s.mu.RLock()
|
||||
approved := s.approvedHosts[hostname]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !approved {
|
||||
if err := s.validateURLForSSRF(parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.approvedHosts[hostname] = true
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
|
||||
request.Header.Set("User-Agent", defaultUserAgent)
|
||||
request.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8")
|
||||
request.Header.Set("Accept-Language", "en-US,en;q=0.5")
|
||||
|
||||
resp, err := s.client.Do(request)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch url: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !strings.Contains(strings.ToLower(contentType), "text/html") {
|
||||
return "", ErrTitleNotFound
|
||||
}
|
||||
|
||||
contentLength := resp.ContentLength
|
||||
if contentLength > maxContentLength {
|
||||
return "", ErrTitleNotFound
|
||||
}
|
||||
|
||||
limited := io.LimitReader(resp.Body, maxTitleBodyBytes)
|
||||
body, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
title := s.ExtractTitleFromHTML(string(body))
|
||||
if title != "" {
|
||||
return title, nil
|
||||
}
|
||||
|
||||
return "", ErrTitleNotFound
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) ExtractTitleFromHTML(html string) string {
|
||||
|
||||
if title := s.ExtractFromTitleTag(html); title != "" {
|
||||
return title
|
||||
}
|
||||
|
||||
if title := s.ExtractFromOpenGraph(html); title != "" {
|
||||
return title
|
||||
}
|
||||
|
||||
if title := s.ExtractFromJSONLD(html); title != "" {
|
||||
return title
|
||||
}
|
||||
|
||||
if title := s.ExtractFromTwitterCard(html); title != "" {
|
||||
return title
|
||||
}
|
||||
|
||||
if title := s.extractFromMetaTags(html); title != "" {
|
||||
return title
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) ExtractFromTitleTag(htmlContent string) string {
|
||||
tokenizer := html.NewTokenizer(strings.NewReader(htmlContent))
|
||||
|
||||
for {
|
||||
tokenType := tokenizer.Next()
|
||||
switch tokenType {
|
||||
case html.ErrorToken:
|
||||
if errors.Is(tokenizer.Err(), io.EOF) {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
case html.StartTagToken, html.SelfClosingTagToken:
|
||||
token := tokenizer.Token()
|
||||
if strings.EqualFold(token.Data, "title") {
|
||||
textTokenType := tokenizer.Next()
|
||||
if textTokenType == html.TextToken {
|
||||
rawTitle := tokenizer.Token().Data
|
||||
cleaned := s.optimizedTitleClean(rawTitle)
|
||||
if cleaned != "" {
|
||||
return cleaned
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) ExtractFromOpenGraph(htmlContent string) string {
|
||||
|
||||
lines := strings.Split(htmlContent, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.Contains(strings.ToLower(line), `property="og:title"`) && strings.Contains(line, `content="`) {
|
||||
start := strings.Index(line, `content="`)
|
||||
if start != -1 {
|
||||
start += 9
|
||||
end := strings.Index(line[start:], `"`)
|
||||
if end != -1 {
|
||||
title := line[start : start+end]
|
||||
cleaned := s.optimizedTitleClean(title)
|
||||
if cleaned != "" {
|
||||
return cleaned
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) ExtractFromJSONLD(htmlContent string) string {
|
||||
|
||||
lines := strings.Split(htmlContent, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.Contains(line, `"@type":"VideoObject"`) || strings.Contains(line, `"@type":"WebPage"`) {
|
||||
|
||||
if strings.Contains(line, `"name":`) {
|
||||
start := strings.Index(line, `"name":`)
|
||||
if start != -1 {
|
||||
start += 7
|
||||
|
||||
for i := start; i < len(line); i++ {
|
||||
if line[i] == '"' {
|
||||
start = i + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
end := strings.Index(line[start:], `"`)
|
||||
if end != -1 {
|
||||
title := line[start : start+end]
|
||||
cleaned := s.optimizedTitleClean(title)
|
||||
if cleaned != "" {
|
||||
return cleaned
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) ExtractFromTwitterCard(htmlContent string) string {
|
||||
|
||||
lines := strings.Split(htmlContent, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.Contains(strings.ToLower(line), `name="twitter:title"`) && strings.Contains(line, `content="`) {
|
||||
start := strings.Index(line, `content="`)
|
||||
if start != -1 {
|
||||
start += 9
|
||||
end := strings.Index(line[start:], `"`)
|
||||
if end != -1 {
|
||||
title := line[start : start+end]
|
||||
cleaned := s.optimizedTitleClean(title)
|
||||
if cleaned != "" {
|
||||
return cleaned
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) extractFromMetaTags(htmlContent string) string {
|
||||
|
||||
lines := strings.Split(htmlContent, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if strings.Contains(strings.ToLower(line), `name="title"`) && strings.Contains(line, `content="`) {
|
||||
start := strings.Index(line, `content="`)
|
||||
if start != -1 {
|
||||
start += 9
|
||||
end := strings.Index(line[start:], `"`)
|
||||
if end != -1 {
|
||||
title := line[start : start+end]
|
||||
cleaned := s.optimizedTitleClean(title)
|
||||
if cleaned != "" {
|
||||
return cleaned
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) optimizedTitleClean(title string) string {
|
||||
if title == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.Grow(len(title))
|
||||
|
||||
inWhitespace := false
|
||||
started := false
|
||||
|
||||
for _, r := range title {
|
||||
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
|
||||
if started && !inWhitespace {
|
||||
result.WriteRune(' ')
|
||||
inWhitespace = true
|
||||
}
|
||||
} else {
|
||||
result.WriteRune(r)
|
||||
inWhitespace = false
|
||||
started = true
|
||||
}
|
||||
}
|
||||
|
||||
cleaned := result.String()
|
||||
|
||||
if len(cleaned) > 0 && cleaned[len(cleaned)-1] == ' ' {
|
||||
cleaned = cleaned[:len(cleaned)-1]
|
||||
}
|
||||
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func (s *URLMetadataService) validateURLForSSRF(u *url.URL) error {
|
||||
if u == nil {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
hostname := u.Hostname()
|
||||
if hostname == "" {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
if isLocalhost(hostname) {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
ips, err := s.resolver.LookupIP(hostname)
|
||||
if err != nil {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if isPrivateOrReservedIP(ip) {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isLocalhost(hostname string) bool {
|
||||
hostname = strings.ToLower(hostname)
|
||||
|
||||
localhostNames := []string{
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
"0.0.0.0",
|
||||
"0:0:0:0:0:0:0:1",
|
||||
"0:0:0:0:0:0:0:0",
|
||||
}
|
||||
|
||||
for _, name := range localhostNames {
|
||||
if hostname == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isPrivateOrReservedIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return isPrivateIPv6(ip)
|
||||
}
|
||||
|
||||
privateRanges := []struct {
|
||||
start, end net.IP
|
||||
}{
|
||||
{net.IPv4(10, 0, 0, 0), net.IPv4(10, 255, 255, 255)},
|
||||
{net.IPv4(172, 16, 0, 0), net.IPv4(172, 31, 255, 255)},
|
||||
{net.IPv4(192, 168, 0, 0), net.IPv4(192, 168, 255, 255)},
|
||||
{net.IPv4(127, 0, 0, 0), net.IPv4(127, 255, 255, 255)},
|
||||
{net.IPv4(169, 254, 0, 0), net.IPv4(169, 254, 255, 255)},
|
||||
{net.IPv4(224, 0, 0, 0), net.IPv4(239, 255, 255, 255)},
|
||||
{net.IPv4(240, 0, 0, 0), net.IPv4(255, 255, 255, 255)},
|
||||
}
|
||||
|
||||
for _, r := range privateRanges {
|
||||
if ipInRange(ipv4, r.start, r.end) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isPrivateIPv6(ip net.IP) bool {
|
||||
privateRanges := []struct {
|
||||
prefix []byte
|
||||
length int
|
||||
}{
|
||||
{[]byte{0xfc, 0x00}, 7},
|
||||
{[]byte{0xfe, 0x80}, 10},
|
||||
{[]byte{0xff, 0x00}, 8},
|
||||
{[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, 128},
|
||||
}
|
||||
|
||||
for _, r := range privateRanges {
|
||||
if ipv6InRange(ip, r.prefix, r.length) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func ipInRange(ip, start, end net.IP) bool {
|
||||
ipInt := ipToInt(ip)
|
||||
startInt := ipToInt(start)
|
||||
endInt := ipToInt(end)
|
||||
return ipInt >= startInt && ipInt <= endInt
|
||||
}
|
||||
|
||||
func ipToInt(ip net.IP) uint32 {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return 0
|
||||
}
|
||||
return uint32(ipv4[0])<<24 + uint32(ipv4[1])<<16 + uint32(ipv4[2])<<8 + uint32(ipv4[3])
|
||||
}
|
||||
|
||||
func ipv6InRange(ip net.IP, prefix []byte, length int) bool {
|
||||
ipBytes := ip.To16()
|
||||
if ipBytes == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
bytesToCompare := length / 8
|
||||
bitsToCompare := length % 8
|
||||
|
||||
for i := 0; i < bytesToCompare && i < len(prefix) && i < len(ipBytes); i++ {
|
||||
if ipBytes[i] != prefix[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if bitsToCompare > 0 && bytesToCompare < len(prefix) && bytesToCompare < len(ipBytes) {
|
||||
mask := byte(0xff) << (8 - bitsToCompare)
|
||||
if (ipBytes[bytesToCompare] & mask) != (prefix[bytesToCompare] & mask) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
1270
internal/services/url_metadata_service_test.go
Normal file
1270
internal/services/url_metadata_service_test.go
Normal file
File diff suppressed because it is too large
Load Diff
160
internal/services/user_management_service.go
Normal file
160
internal/services/user_management_service.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/validation"
|
||||
)
|
||||
|
||||
type UserManagementService struct {
|
||||
userRepo repositories.UserRepository
|
||||
postRepo repositories.PostRepository
|
||||
emailService *EmailService
|
||||
}
|
||||
|
||||
func NewUserManagementService(userRepo repositories.UserRepository, postRepo repositories.PostRepository, emailService *EmailService) *UserManagementService {
|
||||
return &UserManagementService{
|
||||
userRepo: userRepo,
|
||||
postRepo: postRepo,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserManagementService) UpdateUsername(userID uint, newUsername string) (*database.User, error) {
|
||||
trimmed := TrimString(newUsername)
|
||||
if err := validation.ValidateUsername(trimmed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := s.userRepo.GetByUsernameIncludingDeleted(trimmed)
|
||||
if err == nil && existing.ID != userID {
|
||||
return nil, ErrUsernameTaken
|
||||
}
|
||||
if err != nil && !IsRecordNotFound(err) {
|
||||
return nil, fmt.Errorf("lookup username: %w", err)
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load user: %w", err)
|
||||
}
|
||||
|
||||
if user.Username == trimmed {
|
||||
return sanitizeUser(user), nil
|
||||
}
|
||||
|
||||
user.Username = trimmed
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
if handled := HandleUniqueConstraintError(err); handled != err {
|
||||
return nil, handled
|
||||
}
|
||||
return nil, fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
return sanitizeUser(user), nil
|
||||
}
|
||||
|
||||
func (s *UserManagementService) UpdatePassword(userID uint, currentPassword, newPassword string) (*database.User, error) {
|
||||
if err := validation.ValidatePassword(newPassword); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load user: %w", err)
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentPassword)); err != nil {
|
||||
return nil, fmt.Errorf("current password is incorrect")
|
||||
}
|
||||
|
||||
hashedPassword, err := HashPassword(newPassword, DefaultBcryptCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user.Password = string(hashedPassword)
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
return nil, fmt.Errorf("update password: %w", err)
|
||||
}
|
||||
|
||||
return sanitizeUser(user), nil
|
||||
}
|
||||
|
||||
func (s *UserManagementService) UpdateEmail(userID uint, newEmail string) (*database.User, error) {
|
||||
normalized, err := normalizeEmail(newEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := s.userRepo.GetByEmail(normalized)
|
||||
if err == nil && existing.ID != userID {
|
||||
return nil, ErrEmailTaken
|
||||
}
|
||||
if err != nil && !IsRecordNotFound(err) {
|
||||
return nil, fmt.Errorf("lookup email: %w", err)
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load user: %w", err)
|
||||
}
|
||||
|
||||
if user.Email == normalized {
|
||||
return sanitizeUser(user), nil
|
||||
}
|
||||
|
||||
previousEmail := user.Email
|
||||
previousVerified := user.EmailVerified
|
||||
previousVerifiedAt := user.EmailVerifiedAt
|
||||
previousToken := user.EmailVerificationToken
|
||||
previousSentAt := user.EmailVerificationSentAt
|
||||
|
||||
token, hashed, err := generateVerificationToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
user.Email = normalized
|
||||
user.EmailVerified = false
|
||||
user.EmailVerifiedAt = nil
|
||||
user.EmailVerificationToken = hashed
|
||||
user.EmailVerificationSentAt = &now
|
||||
|
||||
if err := s.userRepo.Update(user); err != nil {
|
||||
if handled := HandleUniqueConstraintError(err); handled != err {
|
||||
return nil, handled
|
||||
}
|
||||
return nil, fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
if err := s.emailService.SendEmailChangeVerificationEmail(user, token); err != nil {
|
||||
user.Email = previousEmail
|
||||
user.EmailVerified = previousVerified
|
||||
user.EmailVerifiedAt = previousVerifiedAt
|
||||
user.EmailVerificationToken = previousToken
|
||||
user.EmailVerificationSentAt = previousSentAt
|
||||
_ = s.userRepo.Update(user)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sanitizeUser(user), nil
|
||||
}
|
||||
|
||||
func (s *UserManagementService) UserHasPosts(userID uint) (bool, int64, error) {
|
||||
if s.postRepo == nil {
|
||||
return false, 0, fmt.Errorf("post repository not configured")
|
||||
}
|
||||
|
||||
count, err := s.postRepo.CountByUserID(userID)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("count user posts: %w", err)
|
||||
}
|
||||
|
||||
return count > 0, count, nil
|
||||
}
|
||||
647
internal/services/user_management_service_test.go
Normal file
647
internal/services/user_management_service_test.go
Normal file
@@ -0,0 +1,647 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/testutils"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestNewUserManagementService(t *testing.T) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
|
||||
service := NewUserManagementService(userRepo, postRepo, emailService)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("expected service to be created")
|
||||
}
|
||||
|
||||
if service.userRepo != userRepo {
|
||||
t.Error("expected userRepo to be set")
|
||||
}
|
||||
|
||||
if service.postRepo != postRepo {
|
||||
t.Error("expected postRepo to be set")
|
||||
}
|
||||
|
||||
if service.emailService != emailService {
|
||||
t.Error("expected emailService to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserManagementService_UpdateUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
newUsername string
|
||||
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
|
||||
expectedError error
|
||||
checkResult func(*testing.T, *database.User)
|
||||
}{
|
||||
{
|
||||
name: "successful update",
|
||||
userID: 1,
|
||||
newUsername: "newusername",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "oldusername",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if user.Username != "newusername" {
|
||||
t.Errorf("expected username 'newusername', got %q", user.Username)
|
||||
}
|
||||
if user.Password != "" {
|
||||
t.Error("expected password to be sanitized")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid username",
|
||||
userID: 1,
|
||||
newUsername: "",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "oldusername",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "username already taken by different user",
|
||||
userID: 1,
|
||||
newUsername: "takenusername",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user1 := &database.User{
|
||||
ID: 1,
|
||||
Username: "oldusername",
|
||||
Email: "test1@example.com",
|
||||
}
|
||||
user2 := &database.User{
|
||||
ID: 2,
|
||||
Username: "takenusername",
|
||||
Email: "test2@example.com",
|
||||
}
|
||||
userRepo.Create(user1)
|
||||
userRepo.Create(user2)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: ErrUsernameTaken,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "same username",
|
||||
userID: 1,
|
||||
newUsername: "oldusername",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "oldusername",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user.Username != "oldusername" {
|
||||
t.Errorf("expected username 'oldusername', got %q", user.Username)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
userID: 999,
|
||||
newUsername: "newusername",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "trims username whitespace",
|
||||
userID: 1,
|
||||
newUsername: " newusername ",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "oldusername",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user.Username != "newusername" {
|
||||
t.Errorf("expected trimmed username 'newusername', got %q", user.Username)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, postRepo, emailService := tt.setupMocks()
|
||||
service := NewUserManagementService(userRepo, postRepo, emailService)
|
||||
|
||||
result, err := service.UpdateUsername(tt.userID, tt.newUsername)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.checkResult == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkResult != nil {
|
||||
tt.checkResult(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserManagementService_UpdatePassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
currentPassword string
|
||||
newPassword string
|
||||
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
|
||||
expectedError error
|
||||
checkResult func(*testing.T, *database.User)
|
||||
}{
|
||||
{
|
||||
name: "successful update",
|
||||
userID: 1,
|
||||
currentPassword: "OldPass123!",
|
||||
newPassword: "NewPass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("OldPass123!"), bcrypt.DefaultCost)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: string(hashedPassword),
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if user.Password != "" {
|
||||
t.Error("expected password to be sanitized")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid new password",
|
||||
userID: 1,
|
||||
currentPassword: "OldPass123!",
|
||||
newPassword: "short",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("OldPass123!"), bcrypt.DefaultCost)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: string(hashedPassword),
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "incorrect current password",
|
||||
userID: 1,
|
||||
currentPassword: "WrongPassword",
|
||||
newPassword: "NewPass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("OldPass123!"), bcrypt.DefaultCost)
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: string(hashedPassword),
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
userID: 999,
|
||||
currentPassword: "OldPass123!",
|
||||
newPassword: "NewPass123!",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, postRepo, emailService := tt.setupMocks()
|
||||
service := NewUserManagementService(userRepo, postRepo, emailService)
|
||||
|
||||
result, err := service.UpdatePassword(tt.userID, tt.currentPassword, tt.newPassword)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.checkResult == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkResult != nil {
|
||||
tt.checkResult(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserManagementService_UpdateEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
newEmail string
|
||||
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
|
||||
expectedError error
|
||||
checkResult func(*testing.T, *database.User)
|
||||
}{
|
||||
{
|
||||
name: "successful update",
|
||||
userID: 1,
|
||||
newEmail: "newemail@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
now := time.Now()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "oldemail@example.com",
|
||||
EmailVerified: true,
|
||||
EmailVerifiedAt: &now,
|
||||
EmailVerificationToken: "",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if user.Email != "newemail@example.com" {
|
||||
t.Errorf("expected email 'newemail@example.com', got %q", user.Email)
|
||||
}
|
||||
if user.EmailVerified {
|
||||
t.Error("expected EmailVerified to be false")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
userID: 1,
|
||||
newEmail: "invalid-email",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "oldemail@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "email already taken by different user",
|
||||
userID: 1,
|
||||
newEmail: "taken@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user1 := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser1",
|
||||
Email: "oldemail@example.com",
|
||||
}
|
||||
user2 := &database.User{
|
||||
ID: 2,
|
||||
Username: "testuser2",
|
||||
Email: "taken@example.com",
|
||||
}
|
||||
userRepo.Create(user1)
|
||||
userRepo.Create(user2)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: ErrEmailTaken,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "same email",
|
||||
userID: 1,
|
||||
newEmail: "oldemail@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "oldemail@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user.Email != "oldemail@example.com" {
|
||||
t.Errorf("expected email 'oldemail@example.com', got %q", user.Email)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
userID: 999,
|
||||
newEmail: "newemail@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
{
|
||||
name: "normalizes email",
|
||||
userID: 1,
|
||||
newEmail: "NEWEMAIL@EXAMPLE.COM",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "oldemail@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: func(t *testing.T, user *database.User) {
|
||||
if user.Email != "newemail@example.com" {
|
||||
t.Errorf("expected normalized email 'newemail@example.com', got %q", user.Email)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "email service error rolls back",
|
||||
userID: 1,
|
||||
newEmail: "newemail@example.com",
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
now := time.Now()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "oldemail@example.com",
|
||||
EmailVerified: true,
|
||||
EmailVerifiedAt: &now,
|
||||
EmailVerificationToken: "",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
errorSender := &errorEmailSender{err: errors.New("email service error")}
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, errorSender)
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedError: nil,
|
||||
checkResult: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, postRepo, emailService := tt.setupMocks()
|
||||
service := NewUserManagementService(userRepo, postRepo, emailService)
|
||||
|
||||
result, err := service.UpdateEmail(tt.userID, tt.newEmail)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if tt.checkResult == nil || tt.name == "email service error rolls back" {
|
||||
if tt.name == "email service error rolls back" {
|
||||
user, _ := userRepo.GetByID(1)
|
||||
if user.Email != "oldemail@example.com" {
|
||||
t.Error("expected email to be rolled back to original")
|
||||
}
|
||||
if !user.EmailVerified {
|
||||
t.Error("expected EmailVerified to be rolled back")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tt.checkResult != nil {
|
||||
tt.checkResult(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserManagementService_UserHasPosts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
setupMocks func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService)
|
||||
expectedHas bool
|
||||
expectedCount int64
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "user has posts",
|
||||
userID: 1,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
userID := uint(1)
|
||||
post1 := &database.Post{
|
||||
ID: 1,
|
||||
AuthorID: &userID,
|
||||
Title: "Post 1",
|
||||
URL: "https://example.com/1",
|
||||
}
|
||||
post2 := &database.Post{
|
||||
ID: 2,
|
||||
AuthorID: &userID,
|
||||
Title: "Post 2",
|
||||
URL: "https://example.com/2",
|
||||
}
|
||||
postRepo.Create(post1)
|
||||
postRepo.Create(post2)
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedHas: true,
|
||||
expectedCount: 2,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "user has no posts",
|
||||
userID: 1,
|
||||
setupMocks: func() (*testutils.MockUserRepository, *testutils.MockPostRepository, *EmailService) {
|
||||
userRepo := testutils.NewMockUserRepository()
|
||||
user := &database.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
postRepo := testutils.NewMockPostRepository()
|
||||
emailService, _ := NewEmailService(testutils.AppTestConfig, &testutils.MockEmailSender{})
|
||||
return userRepo, postRepo, emailService
|
||||
},
|
||||
expectedHas: false,
|
||||
expectedCount: 0,
|
||||
expectedError: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
userRepo, postRepo, emailService := tt.setupMocks()
|
||||
service := NewUserManagementService(userRepo, postRepo, emailService)
|
||||
|
||||
hasPosts, count, err := service.UserHasPosts(tt.userID)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if hasPosts != tt.expectedHas {
|
||||
t.Errorf("expected hasPosts %v, got %v", tt.expectedHas, hasPosts)
|
||||
}
|
||||
|
||||
if count != tt.expectedCount {
|
||||
t.Errorf("expected count %d, got %d", tt.expectedCount, count)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
376
internal/services/vote_service.go
Normal file
376
internal/services/vote_service.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
type VoteService struct {
|
||||
voteRepo repositories.VoteRepository
|
||||
postRepo repositories.PostRepository
|
||||
db *gorm.DB
|
||||
voteMutex sync.RWMutex
|
||||
}
|
||||
|
||||
type VoteRequest struct {
|
||||
UserID uint `json:"user_id,omitempty"`
|
||||
PostID uint `json:"post_id"`
|
||||
Type database.VoteType `json:"type"`
|
||||
IPAddress string `json:"-"`
|
||||
UserAgent string `json:"-"`
|
||||
}
|
||||
|
||||
type VoteResponse struct {
|
||||
PostID uint `json:"post_id"`
|
||||
Type database.VoteType `json:"type"`
|
||||
UpVotes int `json:"up_votes"`
|
||||
DownVotes int `json:"down_votes"`
|
||||
Score int `json:"score"`
|
||||
Message string `json:"message"`
|
||||
IsUnauthenticated bool `json:"is_unauthenticated"`
|
||||
}
|
||||
|
||||
func NewVoteService(voteRepo repositories.VoteRepository, postRepo repositories.PostRepository, db *gorm.DB) *VoteService {
|
||||
return &VoteService{
|
||||
voteRepo: voteRepo,
|
||||
postRepo: postRepo,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VoteService) GenerateVoteHash(ipAddress, userAgent string, postID uint) string {
|
||||
data := fmt.Sprintf("%s:%s:%d", ipAddress, userAgent, postID)
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (vs *VoteService) CastVote(req VoteRequest) (*VoteResponse, error) {
|
||||
if err := vs.validateVoteRequest(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vs.voteMutex.Lock()
|
||||
defer vs.voteMutex.Unlock()
|
||||
|
||||
var response *VoteResponse
|
||||
|
||||
if vs.db == nil {
|
||||
return vs.castVoteWithoutTransaction(req)
|
||||
}
|
||||
|
||||
err := vs.db.Transaction(func(tx *gorm.DB) error {
|
||||
txVoteRepo := vs.voteRepo.WithTx(tx)
|
||||
txPostRepo := vs.postRepo.WithTx(tx)
|
||||
|
||||
post, err := txPostRepo.GetByID(req.PostID)
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
return errors.New("post not found")
|
||||
}
|
||||
return fmt.Errorf("failed to get post: %w", err)
|
||||
}
|
||||
|
||||
isUnauthenticated := req.UserID == 0
|
||||
|
||||
if req.Type == database.VoteNone {
|
||||
|
||||
var existingVote *database.Vote
|
||||
var err error
|
||||
|
||||
if isUnauthenticated {
|
||||
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
|
||||
existingVote, err = txVoteRepo.GetByVoteHash(voteHash)
|
||||
} else {
|
||||
existingVote, err = txVoteRepo.GetByUserAndPost(req.UserID, req.PostID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if IsRecordNotFound(err) {
|
||||
response = vs.buildVoteResponse(post, database.VoteNone, isUnauthenticated)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to get existing vote: %w", err)
|
||||
}
|
||||
|
||||
if err := txVoteRepo.Delete(existingVote.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete vote: %w", err)
|
||||
}
|
||||
} else {
|
||||
|
||||
var vote *database.Vote
|
||||
|
||||
if isUnauthenticated {
|
||||
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
|
||||
vote = &database.Vote{
|
||||
PostID: req.PostID,
|
||||
Type: req.Type,
|
||||
VoteHash: &voteHash,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
} else {
|
||||
vote = &database.Vote{
|
||||
UserID: &req.UserID,
|
||||
PostID: req.PostID,
|
||||
Type: req.Type,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
if err := txVoteRepo.CreateOrUpdate(vote); err != nil {
|
||||
return fmt.Errorf("failed to create or update vote: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := vs.updatePostVoteCountsWithTx(tx, req.PostID); err != nil {
|
||||
return fmt.Errorf("failed to update post vote counts: %w", err)
|
||||
}
|
||||
|
||||
updatedPost, err := txPostRepo.GetByID(req.PostID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get updated post: %w", err)
|
||||
}
|
||||
|
||||
response = vs.buildVoteResponse(updatedPost, req.Type, isUnauthenticated)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (vs *VoteService) castVoteWithoutTransaction(req VoteRequest) (*VoteResponse, error) {
|
||||
post, err := vs.postRepo.GetByID(req.PostID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("post not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get post: %w", err)
|
||||
}
|
||||
|
||||
isUnauthenticated := req.UserID == 0
|
||||
|
||||
if req.Type == database.VoteNone {
|
||||
|
||||
var existingVote *database.Vote
|
||||
var err error
|
||||
|
||||
if isUnauthenticated {
|
||||
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
|
||||
existingVote, err = vs.voteRepo.GetByVoteHash(voteHash)
|
||||
} else {
|
||||
existingVote, err = vs.voteRepo.GetByUserAndPost(req.UserID, req.PostID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return vs.buildVoteResponse(post, database.VoteNone, isUnauthenticated), nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get existing vote: %w", err)
|
||||
}
|
||||
|
||||
if err := vs.voteRepo.Delete(existingVote.ID); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete vote: %w", err)
|
||||
}
|
||||
} else {
|
||||
|
||||
var vote *database.Vote
|
||||
|
||||
if isUnauthenticated {
|
||||
voteHash := vs.GenerateVoteHash(req.IPAddress, req.UserAgent, req.PostID)
|
||||
vote = &database.Vote{
|
||||
PostID: req.PostID,
|
||||
Type: req.Type,
|
||||
VoteHash: &voteHash,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
} else {
|
||||
vote = &database.Vote{
|
||||
UserID: &req.UserID,
|
||||
PostID: req.PostID,
|
||||
Type: req.Type,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
if err := vs.voteRepo.CreateOrUpdate(vote); err != nil {
|
||||
return nil, fmt.Errorf("failed to create or update vote: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := vs.updatePostVoteCounts(req.PostID); err != nil {
|
||||
return nil, fmt.Errorf("failed to update post vote counts: %w", err)
|
||||
}
|
||||
|
||||
updatedPost, err := vs.postRepo.GetByID(req.PostID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get updated post: %w", err)
|
||||
}
|
||||
|
||||
return vs.buildVoteResponse(updatedPost, req.Type, isUnauthenticated), nil
|
||||
}
|
||||
|
||||
func (vs *VoteService) GetUserVote(userID uint, postID uint, ipAddress, userAgent string) (*database.Vote, error) {
|
||||
|
||||
if userID > 0 {
|
||||
vote, err := vs.voteRepo.GetByUserAndPost(userID, postID)
|
||||
if err == nil && vote != nil {
|
||||
return vote, nil
|
||||
}
|
||||
}
|
||||
|
||||
voteHash := vs.GenerateVoteHash(ipAddress, userAgent, postID)
|
||||
vote, err := vs.voteRepo.GetByVoteHash(voteHash)
|
||||
if err == nil && vote != nil {
|
||||
return vote, nil
|
||||
}
|
||||
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (vs *VoteService) GetPostVotes(postID uint) ([]database.Vote, error) {
|
||||
votes, err := vs.voteRepo.GetByPostID(postID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return votes, nil
|
||||
}
|
||||
|
||||
func (vs *VoteService) DeleteVotesByPostID(postID uint) error {
|
||||
if vs.db != nil {
|
||||
if err := vs.db.Unscoped().Where("post_id = ?", postID).Delete(&database.Vote{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to delete votes for post: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
votes, err := vs.voteRepo.GetByPostID(postID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get votes: %w", err)
|
||||
}
|
||||
|
||||
for _, vote := range votes {
|
||||
if err := vs.voteRepo.Delete(vote.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete vote %d: %w", vote.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vs *VoteService) validateVoteRequest(req VoteRequest) error {
|
||||
if req.PostID == 0 {
|
||||
return errors.New("post ID is required")
|
||||
}
|
||||
if req.Type != database.VoteUp && req.Type != database.VoteDown && req.Type != database.VoteNone {
|
||||
return errors.New("invalid vote type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vs *VoteService) buildVoteResponse(post *database.Post, voteType database.VoteType, isUnauthenticated bool) *VoteResponse {
|
||||
message := "Vote updated successfully"
|
||||
if voteType == database.VoteNone {
|
||||
message = "Vote removed successfully"
|
||||
}
|
||||
|
||||
return &VoteResponse{
|
||||
PostID: post.ID,
|
||||
Type: voteType,
|
||||
UpVotes: post.UpVotes,
|
||||
DownVotes: post.DownVotes,
|
||||
Score: post.Score,
|
||||
Message: message,
|
||||
IsUnauthenticated: isUnauthenticated,
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VoteService) updatePostVoteCounts(postID uint) error {
|
||||
if vs.db == nil {
|
||||
|
||||
votes, err := vs.voteRepo.GetByPostID(postID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get votes: %w", err)
|
||||
}
|
||||
|
||||
upVotes, downVotes := vs.countVotes(votes)
|
||||
score := upVotes - downVotes
|
||||
|
||||
post, err := vs.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get post: %w", err)
|
||||
}
|
||||
|
||||
post.UpVotes = upVotes
|
||||
post.DownVotes = downVotes
|
||||
post.Score = score
|
||||
|
||||
return vs.postRepo.Update(post)
|
||||
}
|
||||
return vs.updatePostVoteCountsWithTx(vs.db, postID)
|
||||
}
|
||||
|
||||
func (vs *VoteService) updatePostVoteCountsWithTx(tx *gorm.DB, postID uint) error {
|
||||
txVoteRepo := vs.voteRepo.WithTx(tx)
|
||||
txPostRepo := vs.postRepo.WithTx(tx)
|
||||
|
||||
votes, err := txVoteRepo.GetByPostID(postID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get votes: %w", err)
|
||||
}
|
||||
|
||||
upVotes, downVotes := vs.countVotes(votes)
|
||||
score := upVotes - downVotes
|
||||
|
||||
post, err := txPostRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get post: %w", err)
|
||||
}
|
||||
|
||||
post.UpVotes = upVotes
|
||||
post.DownVotes = downVotes
|
||||
post.Score = score
|
||||
|
||||
return txPostRepo.Update(post)
|
||||
}
|
||||
|
||||
func (vs *VoteService) countVotes(votes []database.Vote) (int, int) {
|
||||
upVotes := 0
|
||||
downVotes := 0
|
||||
|
||||
for _, vote := range votes {
|
||||
switch vote.Type {
|
||||
case database.VoteUp:
|
||||
upVotes++
|
||||
case database.VoteDown:
|
||||
downVotes++
|
||||
}
|
||||
}
|
||||
|
||||
return upVotes, downVotes
|
||||
}
|
||||
|
||||
func (vs *VoteService) GetVoteStatistics() (int64, int64, error) {
|
||||
|
||||
totalCount, err := vs.voteRepo.Count()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to get vote count: %w", err)
|
||||
}
|
||||
|
||||
return totalCount, 0, nil
|
||||
}
|
||||
918
internal/services/vote_service_test.go
Normal file
918
internal/services/vote_service_test.go
Normal file
@@ -0,0 +1,918 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
)
|
||||
|
||||
type mockVoteRepo struct {
|
||||
votes map[uint]*database.Vote
|
||||
byUserPost map[string]*database.Vote
|
||||
byVoteHash map[string]*database.Vote
|
||||
nextID uint
|
||||
createErr error
|
||||
getByUserAndPostErr error
|
||||
getByVoteHashErr error
|
||||
getByPostIDErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
createCalls int
|
||||
updateCalls int
|
||||
deleteCalls int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockVoteRepo() *mockVoteRepo {
|
||||
return &mockVoteRepo{
|
||||
votes: make(map[uint]*database.Vote),
|
||||
byUserPost: make(map[string]*database.Vote),
|
||||
byVoteHash: make(map[string]*database.Vote),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) Create(vote *database.Vote) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.createErr != nil {
|
||||
return m.createErr
|
||||
}
|
||||
|
||||
var key string
|
||||
if vote.UserID != nil {
|
||||
key = m.key(*vote.UserID, vote.PostID)
|
||||
} else if vote.VoteHash != nil {
|
||||
key = *vote.VoteHash
|
||||
} else {
|
||||
return errors.New("vote must have either user_id or vote_hash")
|
||||
}
|
||||
|
||||
if existingVote, exists := m.byUserPost[key]; exists {
|
||||
existingVote.Type = vote.Type
|
||||
existingVote.UpdatedAt = vote.UpdatedAt
|
||||
vote.ID = existingVote.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
vote.ID = m.nextID
|
||||
m.nextID++
|
||||
|
||||
voteCopy := *vote
|
||||
m.votes[vote.ID] = &voteCopy
|
||||
m.byUserPost[key] = &voteCopy
|
||||
if vote.VoteHash != nil {
|
||||
m.byVoteHash[*vote.VoteHash] = &voteCopy
|
||||
}
|
||||
|
||||
m.createCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) CreateOrUpdate(vote *database.Vote) error {
|
||||
return m.Create(vote)
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) GetByID(id uint) (*database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if vote, ok := m.votes[id]; ok {
|
||||
voteCopy := *vote
|
||||
return &voteCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) GetByUserAndPost(userID, postID uint) (*database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByUserAndPostErr != nil {
|
||||
return nil, m.getByUserAndPostErr
|
||||
}
|
||||
|
||||
key := m.key(userID, postID)
|
||||
if vote, ok := m.byUserPost[key]; ok {
|
||||
voteCopy := *vote
|
||||
return &voteCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) GetByVoteHash(voteHash string) (*database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByVoteHashErr != nil {
|
||||
return nil, m.getByVoteHashErr
|
||||
}
|
||||
|
||||
if vote, ok := m.byVoteHash[voteHash]; ok {
|
||||
voteCopy := *vote
|
||||
return &voteCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) GetByPostID(postID uint) ([]database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getByPostIDErr != nil {
|
||||
return nil, m.getByPostIDErr
|
||||
}
|
||||
|
||||
var votes []database.Vote
|
||||
for _, vote := range m.votes {
|
||||
if vote.PostID == postID {
|
||||
votes = append(votes, *vote)
|
||||
}
|
||||
}
|
||||
return votes, nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) GetByUserID(userID uint) ([]database.Vote, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var votes []database.Vote
|
||||
for _, vote := range m.votes {
|
||||
if vote.UserID != nil && *vote.UserID == userID {
|
||||
votes = append(votes, *vote)
|
||||
}
|
||||
}
|
||||
return votes, nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) Update(vote *database.Vote) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.updateErr != nil {
|
||||
return m.updateErr
|
||||
}
|
||||
|
||||
if _, ok := m.votes[vote.ID]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
voteCopy := *vote
|
||||
m.votes[vote.ID] = &voteCopy
|
||||
|
||||
var key string
|
||||
if vote.UserID != nil {
|
||||
key = m.key(*vote.UserID, vote.PostID)
|
||||
m.byUserPost[key] = &voteCopy
|
||||
}
|
||||
if vote.VoteHash != nil {
|
||||
m.byVoteHash[*vote.VoteHash] = &voteCopy
|
||||
}
|
||||
|
||||
m.updateCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) Delete(id uint) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.deleteErr != nil {
|
||||
return m.deleteErr
|
||||
}
|
||||
|
||||
vote, ok := m.votes[id]
|
||||
if !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
delete(m.votes, id)
|
||||
|
||||
var key string
|
||||
if vote.UserID != nil {
|
||||
key = m.key(*vote.UserID, vote.PostID)
|
||||
delete(m.byUserPost, key)
|
||||
}
|
||||
if vote.VoteHash != nil {
|
||||
delete(m.byVoteHash, *vote.VoteHash)
|
||||
}
|
||||
|
||||
m.deleteCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) Count() (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return int64(len(m.votes)), nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) CountByPostID(postID uint) (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, vote := range m.votes {
|
||||
if vote.PostID == postID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) CountByUserID(userID uint) (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, vote := range m.votes {
|
||||
if vote.UserID != nil && *vote.UserID == userID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) WithTx(tx *gorm.DB) repositories.VoteRepository {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockVoteRepo) key(userID, postID uint) string {
|
||||
return fmt.Sprintf("%d:%d", userID, postID)
|
||||
}
|
||||
|
||||
type mockPostRepo struct {
|
||||
posts map[uint]*database.Post
|
||||
nextID uint
|
||||
getErr error
|
||||
updateErr error
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockPostRepo() *mockPostRepo {
|
||||
return &mockPostRepo{
|
||||
posts: make(map[uint]*database.Post),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetByID(id uint) (*database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.getErr != nil {
|
||||
return nil, m.getErr
|
||||
}
|
||||
|
||||
if post, ok := m.posts[id]; ok {
|
||||
postCopy := *post
|
||||
return &postCopy, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) Update(post *database.Post) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.updateErr != nil {
|
||||
return m.updateErr
|
||||
}
|
||||
|
||||
if _, ok := m.posts[post.ID]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
postCopy := *post
|
||||
m.posts[post.ID] = &postCopy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) Create(post *database.Post) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
post.ID = m.nextID
|
||||
m.nextID++
|
||||
|
||||
postCopy := *post
|
||||
m.posts[post.ID] = &postCopy
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetByURL(url string) (*database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, post := range m.posts {
|
||||
if post.URL == url {
|
||||
postCopy := *post
|
||||
return &postCopy, nil
|
||||
}
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetByAuthorID(authorID uint) ([]database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
for _, post := range m.posts {
|
||||
if post.AuthorID != nil && *post.AuthorID == authorID {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetAll(limit, offset int) ([]database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if count >= offset && count < offset+limit {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
count++
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) Count() (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return int64(len(m.posts)), nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) Delete(id uint) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.posts[id]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
delete(m.posts, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetByUserID(userID uint, limit, offset int) ([]database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if post.AuthorID != nil && *post.AuthorID == userID {
|
||||
if count >= offset && count < offset+limit {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) CountByUserID(userID uint) (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, post := range m.posts {
|
||||
if post.AuthorID != nil && *post.AuthorID == userID {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetTopPosts(limit int) ([]database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if count < limit {
|
||||
posts = append(posts, *post)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetNewestPosts(limit int) ([]database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if count < limit {
|
||||
posts = append(posts, *post)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) Search(query string, limit, offset int) ([]database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
count := 0
|
||||
for _, post := range m.posts {
|
||||
if strings.Contains(strings.ToLower(post.Title), strings.ToLower(query)) {
|
||||
if count >= offset && count < offset+limit {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) GetPostsByDeletedUsers() ([]database.Post, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var posts []database.Post
|
||||
for _, post := range m.posts {
|
||||
if post.AuthorID == nil {
|
||||
posts = append(posts, *post)
|
||||
}
|
||||
}
|
||||
return posts, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) HardDeletePostsByDeletedUsers() (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
count := int64(0)
|
||||
for id, post := range m.posts {
|
||||
if post.AuthorID == nil {
|
||||
delete(m.posts, id)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) HardDeleteAll() (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
count := int64(len(m.posts))
|
||||
m.posts = make(map[uint]*database.Post)
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *mockPostRepo) WithTx(tx *gorm.DB) repositories.PostRepository {
|
||||
return m
|
||||
}
|
||||
|
||||
func TestVoteService_CastVote_Authenticated(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
post := &database.Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
URL: "https://example.com",
|
||||
AuthorID: &[]uint{1}[0],
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
postRepo.posts[1] = post
|
||||
|
||||
req := VoteRequest{
|
||||
UserID: 1,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
|
||||
result, err := service.CastVote(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected result, got nil")
|
||||
}
|
||||
|
||||
if result.Type != database.VoteUp {
|
||||
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
|
||||
}
|
||||
|
||||
if result.UpVotes != 1 {
|
||||
t.Errorf("Expected up votes to be 1, got %d", result.UpVotes)
|
||||
}
|
||||
|
||||
if result.DownVotes != 0 {
|
||||
t.Errorf("Expected down votes to be 0, got %d", result.DownVotes)
|
||||
}
|
||||
|
||||
if result.Score != 1 {
|
||||
t.Errorf("Expected score to be 1, got %d", result.Score)
|
||||
}
|
||||
|
||||
if result.IsUnauthenticated {
|
||||
t.Error("Expected IsUnauthenticated to be false for authenticated vote")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_CastVote_Unauthenticated(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
post := &database.Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
URL: "https://example.com",
|
||||
AuthorID: &[]uint{1}[0],
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
postRepo.posts[1] = post
|
||||
|
||||
req := VoteRequest{
|
||||
UserID: 0,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
|
||||
result, err := service.CastVote(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected result, got nil")
|
||||
}
|
||||
|
||||
if result.Type != database.VoteUp {
|
||||
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
|
||||
}
|
||||
|
||||
if result.UpVotes != 1 {
|
||||
t.Errorf("Expected up votes to be 1, got %d", result.UpVotes)
|
||||
}
|
||||
|
||||
if result.DownVotes != 0 {
|
||||
t.Errorf("Expected down votes to be 0, got %d", result.DownVotes)
|
||||
}
|
||||
|
||||
if result.Score != 1 {
|
||||
t.Errorf("Expected score to be 1, got %d", result.Score)
|
||||
}
|
||||
|
||||
if !result.IsUnauthenticated {
|
||||
t.Error("Expected IsUnauthenticated to be true for unauthenticated vote")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_CastVote_UpdateExisting(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
post := &database.Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
URL: "https://example.com",
|
||||
AuthorID: &[]uint{1}[0],
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
postRepo.posts[1] = post
|
||||
|
||||
req := VoteRequest{
|
||||
UserID: 1,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
|
||||
_, err := service.CastVote(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
req.Type = database.VoteDown
|
||||
result, err := service.CastVote(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if result.UpVotes != 0 {
|
||||
t.Errorf("Expected up votes to be 0, got %d", result.UpVotes)
|
||||
}
|
||||
|
||||
if result.DownVotes != 1 {
|
||||
t.Errorf("Expected down votes to be 1, got %d", result.DownVotes)
|
||||
}
|
||||
|
||||
if result.Score != -1 {
|
||||
t.Errorf("Expected score to be -1, got %d", result.Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_CastVote_RemoveVote(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
post := &database.Post{
|
||||
ID: 1,
|
||||
Title: "Test Post",
|
||||
URL: "https://example.com",
|
||||
AuthorID: &[]uint{1}[0],
|
||||
UpVotes: 0,
|
||||
DownVotes: 0,
|
||||
Score: 0,
|
||||
}
|
||||
postRepo.posts[1] = post
|
||||
|
||||
req := VoteRequest{
|
||||
UserID: 1,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
|
||||
_, err := service.CastVote(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
req.Type = database.VoteNone
|
||||
result, err := service.CastVote(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if result.UpVotes != 0 {
|
||||
t.Errorf("Expected up votes to be 0, got %d", result.UpVotes)
|
||||
}
|
||||
|
||||
if result.DownVotes != 0 {
|
||||
t.Errorf("Expected down votes to be 0, got %d", result.DownVotes)
|
||||
}
|
||||
|
||||
if result.Score != 0 {
|
||||
t.Errorf("Expected score to be 0, got %d", result.Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_GetUserVote_Authenticated(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
userID := uint(1)
|
||||
vote := &database.Vote{
|
||||
ID: 1,
|
||||
UserID: &userID,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
}
|
||||
voteRepo.votes[1] = vote
|
||||
voteRepo.byUserPost["1:1"] = vote
|
||||
|
||||
result, err := service.GetUserVote(1, 1, "127.0.0.1", "test-agent")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected vote, got nil")
|
||||
}
|
||||
|
||||
if result.Type != database.VoteUp {
|
||||
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
|
||||
}
|
||||
|
||||
if result.UserID == nil || *result.UserID != 1 {
|
||||
t.Errorf("Expected user ID 1, got %v", result.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_GetUserVote_Unauthenticated(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
voteHash := service.GenerateVoteHash("127.0.0.1", "test-agent", 1)
|
||||
vote := &database.Vote{
|
||||
ID: 1,
|
||||
UserID: nil,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
VoteHash: &voteHash,
|
||||
}
|
||||
voteRepo.votes[1] = vote
|
||||
voteRepo.byVoteHash[voteHash] = vote
|
||||
|
||||
result, err := service.GetUserVote(0, 1, "127.0.0.1", "test-agent")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected vote, got nil")
|
||||
}
|
||||
|
||||
if result.Type != database.VoteUp {
|
||||
t.Errorf("Expected vote type 'up', got '%v'", result.Type)
|
||||
}
|
||||
|
||||
if result.UserID != nil {
|
||||
t.Error("Expected UserID to be nil for unauthenticated vote")
|
||||
}
|
||||
|
||||
if result.VoteHash == nil || *result.VoteHash != voteHash {
|
||||
t.Errorf("Expected vote hash '%s', got %v", voteHash, result.VoteHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_GetPostVotes(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
userID1 := uint(1)
|
||||
userID2 := uint(2)
|
||||
voteHash := "test-hash"
|
||||
|
||||
vote1 := &database.Vote{
|
||||
ID: 1,
|
||||
UserID: &userID1,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
}
|
||||
vote2 := &database.Vote{
|
||||
ID: 2,
|
||||
UserID: &userID2,
|
||||
PostID: 1,
|
||||
Type: database.VoteDown,
|
||||
}
|
||||
vote3 := &database.Vote{
|
||||
ID: 3,
|
||||
UserID: nil,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
VoteHash: &voteHash,
|
||||
}
|
||||
|
||||
voteRepo.votes[1] = vote1
|
||||
voteRepo.votes[2] = vote2
|
||||
voteRepo.votes[3] = vote3
|
||||
|
||||
votes, err := service.GetPostVotes(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(votes) != 3 {
|
||||
t.Errorf("Expected 3 votes, got %d", len(votes))
|
||||
}
|
||||
|
||||
hasAuthenticated := false
|
||||
hasUnauthenticated := false
|
||||
for _, vote := range votes {
|
||||
if vote.UserID != nil {
|
||||
hasAuthenticated = true
|
||||
}
|
||||
if vote.VoteHash != nil {
|
||||
hasUnauthenticated = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAuthenticated {
|
||||
t.Error("Expected to find authenticated votes")
|
||||
}
|
||||
|
||||
if !hasUnauthenticated {
|
||||
t.Error("Expected to find unauthenticated votes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_GetVoteStatistics(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
userID1 := uint(1)
|
||||
userID2 := uint(2)
|
||||
voteHash := "test-hash"
|
||||
|
||||
vote1 := &database.Vote{
|
||||
ID: 1,
|
||||
UserID: &userID1,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
}
|
||||
vote2 := &database.Vote{
|
||||
ID: 2,
|
||||
UserID: &userID2,
|
||||
PostID: 1,
|
||||
Type: database.VoteDown,
|
||||
}
|
||||
vote3 := &database.Vote{
|
||||
ID: 3,
|
||||
UserID: nil,
|
||||
PostID: 1,
|
||||
Type: database.VoteUp,
|
||||
VoteHash: &voteHash,
|
||||
}
|
||||
|
||||
voteRepo.votes[1] = vote1
|
||||
voteRepo.votes[2] = vote2
|
||||
voteRepo.votes[3] = vote3
|
||||
|
||||
authenticatedCount, anonymousCount, err := service.GetVoteStatistics()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if authenticatedCount != 3 {
|
||||
t.Errorf("Expected total count to be 3, got %d", authenticatedCount)
|
||||
}
|
||||
|
||||
if anonymousCount != 0 {
|
||||
t.Errorf("Expected unauthenticated count to be 0, got %d", anonymousCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_Validation(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
req := VoteRequest{
|
||||
UserID: 1,
|
||||
PostID: 0,
|
||||
Type: database.VoteUp,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
|
||||
_, err := service.CastVote(req)
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing post ID")
|
||||
}
|
||||
|
||||
req.PostID = 1
|
||||
req.Type = "invalid"
|
||||
|
||||
_, err = service.CastVote(req)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid vote type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteService_PostNotFound(t *testing.T) {
|
||||
voteRepo := newMockVoteRepo()
|
||||
postRepo := newMockPostRepo()
|
||||
service := NewVoteService(voteRepo, postRepo, nil)
|
||||
|
||||
req := VoteRequest{
|
||||
UserID: 1,
|
||||
PostID: 999,
|
||||
Type: database.VoteUp,
|
||||
IPAddress: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
}
|
||||
|
||||
_, err := service.CastVote(req)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent post")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "post not found") {
|
||||
t.Errorf("Expected 'post not found' error, got %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user