Files
goyco/internal/handlers/common_test.go

1159 lines
30 KiB
Go

package handlers
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
"github.com/go-chi/chi/v5"
"gorm.io/gorm"
)
// TestSendSuccessResponse checks that SendSuccessResponse answers with
// 200 OK, that the decoded payload passes the helper's success assertion, and
// that the "message" field echoes the message passed in.
func TestSendSuccessResponse(t *testing.T) {
	h := testutils.NewHandlerTestHelper(t)
	rec := httptest.NewRecorder()
	payload := map[string]string{"key": "value"}

	SendSuccessResponse(rec, "Test message", payload)

	h.AssertStatusCode(t, rec, http.StatusOK)
	body := h.DecodeResponse(t, rec)
	h.AssertResponseSuccess(t, body)
	if got := body["message"]; got != "Test message" {
		t.Errorf("Expected message 'Test message', got %v", got)
	}
}
// TestSendCreatedResponse checks that SendCreatedResponse answers with
// 201 Created, that the decoded payload passes the helper's success assertion,
// and that the "message" field echoes the message passed in.
func TestSendCreatedResponse(t *testing.T) {
	h := testutils.NewHandlerTestHelper(t)
	rec := httptest.NewRecorder()
	payload := map[string]string{"id": "123"}

	SendCreatedResponse(rec, "Created message", payload)

	h.AssertStatusCode(t, rec, http.StatusCreated)
	body := h.DecodeResponse(t, rec)
	h.AssertResponseSuccess(t, body)
	if got := body["message"]; got != "Created message" {
		t.Errorf("Expected message 'Created message', got %v", got)
	}
}
// TestSendErrorResponse checks that SendErrorResponse uses the status code it
// is given, that the decoded payload passes the helper's error assertion, and
// that the "error" field carries the supplied message.
func TestSendErrorResponse(t *testing.T) {
	h := testutils.NewHandlerTestHelper(t)
	rec := httptest.NewRecorder()

	SendErrorResponse(rec, "Error message", http.StatusBadRequest)

	h.AssertStatusCode(t, rec, http.StatusBadRequest)
	body := h.DecodeResponse(t, rec)
	h.AssertResponseError(t, body)
	if got := body["error"]; got != "Error message" {
		t.Errorf("Expected error 'Error message', got %v", got)
	}
}
// TestGetClientIP verifies where GetClientIP takes the client address from,
// depending on middleware.TrustProxyHeaders:
//   - when proxy headers are not trusted, the host part of RemoteAddr is used
//     even if X-Forwarded-For is present;
//   - when they are trusted, X-Forwarded-For (leftmost entry) or X-Real-IP is
//     used, falling back to RemoteAddr when neither header is set.
func TestGetClientIP(t *testing.T) {
	// TrustProxyHeaders is package-level state in middleware. Restore it once
	// this test and all of its subtests have finished, so other tests see the
	// original value. A single cleanup is enough; no trailing reset is needed.
	originalTrust := middleware.TrustProxyHeaders
	t.Cleanup(func() {
		middleware.TrustProxyHeaders = originalTrust
	})

	tests := []struct {
		name              string
		headers           map[string]string
		remoteAddr        string
		trustProxyHeaders bool
		expected          string
	}{
		{
			name:              "Default: RemoteAddr when TrustProxyHeaders is false",
			headers:           map[string]string{"X-Forwarded-For": "192.168.1.1"},
			remoteAddr:        "127.0.0.1:8080",
			trustProxyHeaders: false,
			expected:          "127.0.0.1",
		},
		{
			name:              "X-Forwarded-For header when TrustProxyHeaders is true",
			headers:           map[string]string{"X-Forwarded-For": "192.168.1.1"},
			remoteAddr:        "127.0.0.1:8080",
			trustProxyHeaders: true,
			expected:          "192.168.1.1",
		},
		{
			name:              "X-Real-IP header when TrustProxyHeaders is true",
			headers:           map[string]string{"X-Real-IP": "10.0.0.1"},
			remoteAddr:        "127.0.0.1:8080",
			trustProxyHeaders: true,
			expected:          "10.0.0.1",
		},
		{
			name:              "RemoteAddr fallback",
			headers:           map[string]string{},
			remoteAddr:        "127.0.0.1:8080",
			trustProxyHeaders: false,
			expected:          "127.0.0.1",
		},
		{
			// Trusting proxies must still fall back to RemoteAddr when no
			// proxy header was sent.
			name:              "RemoteAddr fallback when proxy headers are trusted but absent",
			headers:           map[string]string{},
			remoteAddr:        "127.0.0.1:8080",
			trustProxyHeaders: true,
			expected:          "127.0.0.1",
		},
		{
			name:              "X-Forwarded-For with multiple IPs uses leftmost",
			headers:           map[string]string{"X-Forwarded-For": "203.0.113.1, 198.51.100.1"},
			remoteAddr:        "127.0.0.1:8080",
			trustProxyHeaders: true,
			expected:          "203.0.113.1",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			middleware.TrustProxyHeaders = tt.trustProxyHeaders
			req := httptest.NewRequest("GET", "/", nil)
			req.RemoteAddr = tt.remoteAddr
			for key, value := range tt.headers {
				req.Header.Set(key, value)
			}
			if result := GetClientIP(req); result != tt.expected {
				t.Errorf("Expected %s, got %s", tt.expected, result)
			}
		})
	}
}
// TestIsHTTPS checks that IsHTTPS reports true for a request arriving over TLS
// and for requests whose X-Forwarded-Proto, X-Forwarded-Ssl or
// X-Forwarded-Scheme header signals HTTPS. It reports false for a plain HTTP
// request with none of those headers.
func TestIsHTTPS(t *testing.T) {
	tests := []struct {
		name     string
		headers  map[string]string
		tls      bool
		expected bool
	}{
		{
			name:     "TLS connection",
			headers:  map[string]string{},
			tls:      true,
			expected: true,
		},
		{
			name:     "X-Forwarded-Proto https",
			headers:  map[string]string{"X-Forwarded-Proto": "https"},
			tls:      false,
			expected: true,
		},
		{
			name:     "X-Forwarded-Ssl on",
			headers:  map[string]string{"X-Forwarded-Ssl": "on"},
			tls:      false,
			expected: true,
		},
		{
			name:     "X-Forwarded-Scheme https",
			headers:  map[string]string{"X-Forwarded-Scheme": "https"},
			tls:      false,
			expected: true,
		},
		{
			name:     "HTTP connection",
			headers:  map[string]string{},
			tls:      false,
			expected: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// httptest.NewRequest sets req.TLS to a non-nil dummy value when
			// the target is an absolute https URL. That lets the TLS case run
			// without a real handshake instead of being skipped.
			target := "/"
			if tt.tls {
				target = "https://example.com/"
			}
			req := httptest.NewRequest("GET", target, nil)
			if tt.tls && req.TLS == nil {
				t.Fatal("test setup: expected httptest.NewRequest to populate req.TLS for an https target")
			}
			for key, value := range tt.headers {
				req.Header.Set(key, value)
			}
			if result := IsHTTPS(req); result != tt.expected {
				t.Errorf("Expected %v, got %v", tt.expected, result)
			}
		})
	}
}
// TestSanitizeUser checks that SanitizeUser carries the ID, Username,
// CreatedAt and UpdatedAt fields over from the source user. Email and
// Password are populated on the input but not asserted on here.
func TestSanitizeUser(t *testing.T) {
	src := &database.User{ID: 1, Username: "testuser", Email: "test@example.com", Password: "hashedpassword"}

	out := SanitizeUser(src)

	if out.ID != src.ID {
		t.Errorf("Expected ID %d, got %d", src.ID, out.ID)
	}
	if out.Username != src.Username {
		t.Errorf("Expected username %s, got %s", src.Username, out.Username)
	}
	if out.CreatedAt != src.CreatedAt {
		t.Errorf("Expected CreatedAt to match")
	}
	if out.UpdatedAt != src.UpdatedAt {
		t.Errorf("Expected UpdatedAt to match")
	}
}
// TestSanitizeUserNil checks that SanitizeUser accepts a nil user without
// panicking and yields a result whose ID is zero.
func TestSanitizeUserNil(t *testing.T) {
	if out := SanitizeUser(nil); out.ID != 0 {
		t.Errorf("Expected zero value for nil user, got %v", out)
	}
}
// TestSanitizeUsers checks that SanitizeUsers returns one entry per input user,
// in the same order, preserving each user's ID and Username.
func TestSanitizeUsers(t *testing.T) {
	users := []database.User{
		{ID: 1, Username: "user1", Email: "user1@example.com", Password: "hash1"},
		{ID: 2, Username: "user2", Email: "user2@example.com", Password: "hash2"},
	}
	sanitized := SanitizeUsers(users)
	// Stop on a length mismatch. The loop below indexes users[i] for every
	// sanitized entry and would panic with index out of range if the result
	// were longer than the input.
	if len(sanitized) != len(users) {
		t.Fatalf("Expected %d users, got %d", len(users), len(sanitized))
	}
	for i, user := range sanitized {
		if user.ID != users[i].ID {
			t.Errorf("User %d: Expected ID %d, got %d", i, users[i].ID, user.ID)
		}
		if user.Username != users[i].Username {
			t.Errorf("User %d: Expected username %s, got %s", i, users[i].Username, user.Username)
		}
	}
}
// TestSetVoteCookie checks that SetVoteCookie writes exactly one cookie named
// "vote_123" for post 123. The cookie must have MaxAge of 30 days, be HttpOnly,
// use SameSite=Lax, and set Secure only when the request is seen as HTTPS via
// X-Forwarded-Proto, X-Forwarded-Ssl or X-Forwarded-Scheme.
func TestSetVoteCookie(t *testing.T) {
	const thirtyDays = 86400 * 30

	// postVote builds a POST /vote request, optionally carrying one header.
	postVote := func(header, value string) *http.Request {
		r := httptest.NewRequest("POST", "/vote", nil)
		if header != "" {
			r.Header.Set(header, value)
		}
		return r
	}

	cases := []struct {
		name         string
		request      *http.Request
		expectSecure bool
	}{
		{"HTTP request - Secure flag false", postVote("", ""), false},
		{"HTTPS via X-Forwarded-Proto - Secure flag true", postVote("X-Forwarded-Proto", "https"), true},
		{"HTTPS via X-Forwarded-Ssl - Secure flag true", postVote("X-Forwarded-Ssl", "on"), true},
		{"HTTPS via X-Forwarded-Scheme - Secure flag true", postVote("X-Forwarded-Scheme", "https"), true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			SetVoteCookie(rec, tc.request, 123, database.VoteUp)

			written := rec.Result().Cookies()
			if len(written) != 1 {
				t.Fatalf("Expected 1 cookie, got %d", len(written))
			}
			c := written[0]
			if c.Name != "vote_123" {
				t.Errorf("Expected cookie name 'vote_123', got %s", c.Name)
			}
			if c.MaxAge != thirtyDays {
				t.Errorf("Expected MaxAge %d, got %d", thirtyDays, c.MaxAge)
			}
			if c.Secure != tc.expectSecure {
				t.Errorf("Expected Secure flag %v, got %v", tc.expectSecure, c.Secure)
			}
			if !c.HttpOnly {
				t.Error("Expected HttpOnly flag to be true")
			}
			if c.SameSite != http.SameSiteLaxMode {
				t.Errorf("Expected SameSite to be LaxMode, got %v", c.SameSite)
			}
		})
	}
}
// TestGetVoteCookie checks that GetVoteCookie returns the raw value of the
// "vote_<postID>" cookie when it exists, and an empty string when it does not.
func TestGetVoteCookie(t *testing.T) {
	const stored = "up:1234567890"
	req := httptest.NewRequest("GET", "/", nil)
	req.AddCookie(&http.Cookie{Name: "vote_123", Value: stored})

	if got := GetVoteCookie(req, 123); got != stored {
		t.Errorf("Expected 'up:1234567890', got %s", got)
	}
	// The request carries no cookie for post 456.
	if got := GetVoteCookie(req, 456); got != "" {
		t.Errorf("Expected empty string, got %s", got)
	}
}
// TestClearVoteCookie checks that ClearVoteCookie writes a single "vote_123"
// cookie with an empty value and MaxAge -1.
func TestClearVoteCookie(t *testing.T) {
	rec := httptest.NewRecorder()
	ClearVoteCookie(rec, 123)

	written := rec.Result().Cookies()
	if len(written) != 1 {
		t.Fatalf("Expected 1 cookie, got %d", len(written))
	}
	c := written[0]
	if c.Name != "vote_123" {
		t.Errorf("Expected cookie name 'vote_123', got %s", c.Name)
	}
	if c.Value != "" {
		t.Errorf("Expected empty cookie value, got %s", c.Value)
	}
	if c.MaxAge != -1 {
		t.Errorf("Expected MaxAge -1, got %d", c.MaxAge)
	}
}
// TestValidateRedirectURL table-tests ValidateRedirectURL. Site-relative paths,
// including those with query strings and fragments, are returned unchanged.
// Absolute URLs, protocol-relative URLs, paths without a leading slash,
// empty/whitespace input, overly long input and input with a NUL byte are
// rejected with "".
func TestValidateRedirectURL(t *testing.T) {
	cases := []struct {
		name, in, want string
	}{
		{"Valid simple path", "/posts/123", "/posts/123"},
		{"Valid path with query", "/posts/123?sort=date", "/posts/123?sort=date"},
		{"Valid path with fragment", "/posts/123#comments", "/posts/123#comments"},
		{"Valid path with query and fragment", "/posts/123?sort=date#comments", "/posts/123?sort=date#comments"},
		{"Root path", "/", "/"},
		{"Path with multiple segments", "/api/posts/123/comments", "/api/posts/123/comments"},
		{"Absolute URL with scheme", "https://evil.com", ""},
		{"Absolute URL with http", "http://evil.com", ""},
		{"Protocol-relative URL", "//evil.com", ""},
		{"URL without leading slash", "posts/123", ""},
		{"Empty string", "", ""},
		{"Whitespace only", " ", ""},
		// A scheme-looking segment inside a path is still a local path.
		{"URL with scheme in path", "/https://evil.com", "/https://evil.com"},
		{"Too long URL", "/" + strings.Repeat("a", 512), ""},
		{"Path with encoded characters", "/posts/123%20test", "/posts/123%20test"},
		{"Malformed URL", "/posts/\x00", ""},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := ValidateRedirectURL(c.in); got != c.want {
				t.Errorf("ValidateRedirectURL(%q) = %q, want %q", c.in, got, c.want)
			}
		})
	}
}
// TestParseUintParam table-tests ParseUintParam against chi URL parameters.
// Valid unsigned values are returned with ok=true. For a missing, non-numeric
// or negative value it must return (0, false) and write a 400 JSON response
// whose "error" field contains the expected message.
func TestParseUintParam(t *testing.T) {
	tests := []struct {
		name           string
		paramValue     string
		paramName      string
		entityName     string
		expectedID     uint
		expectedOK     bool
		expectedStatus int
		expectedError  string
	}{
		{
			name:           "valid ID",
			paramValue:     "123",
			paramName:      "id",
			entityName:     "Post",
			expectedID:     123,
			expectedOK:     true,
			expectedStatus: 0,
		},
		{
			name:           "missing parameter",
			paramValue:     "",
			paramName:      "id",
			entityName:     "Post",
			expectedID:     0,
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Post ID is required",
		},
		{
			name:           "invalid ID - not a number",
			paramValue:     "abc",
			paramName:      "id",
			entityName:     "Post",
			expectedID:     0,
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid Post ID",
		},
		{
			name:           "invalid ID - negative number",
			paramValue:     "-1",
			paramName:      "id",
			entityName:     "User",
			expectedID:     0,
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid User ID",
		},
		{
			name:           "large valid ID",
			paramValue:     "4294967295",
			paramName:      "id",
			entityName:     "Post",
			expectedID:     4294967295,
			expectedOK:     true,
			expectedStatus: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			r := httptest.NewRequest("GET", "/", nil)
			ctx := chi.NewRouteContext()
			// Leaving the parameter unset simulates a route without it.
			if tt.paramValue != "" {
				ctx.URLParams.Add(tt.paramName, tt.paramValue)
			}
			r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
			id, ok := ParseUintParam(w, r, tt.paramName, tt.entityName)
			if ok != tt.expectedOK {
				t.Errorf("ParseUintParam() ok = %v, want %v", ok, tt.expectedOK)
			}
			if id != tt.expectedID {
				t.Errorf("ParseUintParam() id = %v, want %v", id, tt.expectedID)
			}
			if tt.expectedOK {
				return
			}
			if w.Code != tt.expectedStatus {
				t.Errorf("ParseUintParam() status = %v, want %v", w.Code, tt.expectedStatus)
			}
			var response map[string]any
			if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
				t.Fatalf("ParseUintParam() wrote a body that is not valid JSON: %v", err)
			}
			// Use a comma-ok assertion. A missing or non-string "error" field
			// should fail the test with a message, not panic.
			errMsg, _ := response["error"].(string)
			if tt.expectedError != "" && !strings.Contains(errMsg, tt.expectedError) {
				t.Errorf("ParseUintParam() error = %v, want to contain %v", response["error"], tt.expectedError)
			}
		})
	}
}
// TestRequireAuth table-tests RequireAuth using the user ID stored under
// middleware.UserIDKey in the request context. A non-zero ID is returned with
// ok=true. A zero ID must yield (0, false) and a 401 JSON response whose
// "error" field equals "Authentication required".
func TestRequireAuth(t *testing.T) {
	tests := []struct {
		name           string
		userID         uint
		expectedID     uint
		expectedOK     bool
		expectedStatus int
		expectedError  string
	}{
		{
			name:           "authenticated user",
			userID:         123,
			expectedID:     123,
			expectedOK:     true,
			expectedStatus: 0,
		},
		{
			name:           "unauthenticated user (no userID)",
			userID:         0,
			expectedID:     0,
			expectedOK:     false,
			expectedStatus: http.StatusUnauthorized,
			expectedError:  "Authentication required",
		},
		{
			name:           "authenticated user with large ID",
			userID:         4294967295,
			expectedID:     4294967295,
			expectedOK:     true,
			expectedStatus: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			r := httptest.NewRequest("GET", "/", nil)
			ctx := context.WithValue(r.Context(), middleware.UserIDKey, tt.userID)
			r = r.WithContext(ctx)
			userID, ok := RequireAuth(w, r)
			if ok != tt.expectedOK {
				t.Errorf("RequireAuth() ok = %v, want %v", ok, tt.expectedOK)
			}
			if userID != tt.expectedID {
				t.Errorf("RequireAuth() userID = %v, want %v", userID, tt.expectedID)
			}
			if tt.expectedOK {
				return
			}
			if w.Code != tt.expectedStatus {
				t.Errorf("RequireAuth() status = %v, want %v", w.Code, tt.expectedStatus)
			}
			var response map[string]any
			if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
				t.Fatalf("RequireAuth() wrote a body that is not valid JSON: %v", err)
			}
			if tt.expectedError != "" && response["error"] != tt.expectedError {
				t.Errorf("RequireAuth() error = %v, want %v", response["error"], tt.expectedError)
			}
		})
	}
}
// TestDecodeJSONRequest table-tests DecodeJSONRequest. A well-formed body is
// decoded into the target and ok=true is returned. Invalid, empty or malformed
// bodies must return false and write a 400 JSON response whose "error" field
// equals "Invalid request body".
func TestDecodeJSONRequest(t *testing.T) {
	// Named local target types, so verifiers can type-assert without
	// re-spelling an anonymous struct literal.
	type credentials struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	type usernameOnly struct {
		Username string `json:"username"`
	}

	// verify, when non-nil, inspects target after a successful decode. This
	// replaces branching on the subtest name.
	tests := []struct {
		name           string
		body           string
		target         any
		verify         func(t *testing.T, target any)
		expectedOK     bool
		expectedStatus int
		expectedError  string
	}{
		{
			name:   "valid JSON",
			body:   `{"username": "test", "password": "pass123"}`,
			target: &credentials{},
			verify: func(t *testing.T, target any) {
				t.Helper()
				decoded, ok := target.(*credentials)
				if !ok {
					t.Fatalf("unexpected target type %T", target)
				}
				if decoded.Username != "test" || decoded.Password != "pass123" {
					t.Errorf("DecodeJSONRequest() failed to decode data correctly")
				}
			},
			expectedOK:     true,
			expectedStatus: 0,
		},
		{
			name:           "invalid JSON",
			body:           `{"username": "test", "password":}`,
			target:         &credentials{},
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid request body",
		},
		{
			name:           "empty body",
			body:           "",
			target:         &usernameOnly{},
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid request body",
		},
		{
			name:           "malformed JSON",
			body:           `{username: test}`,
			target:         &usernameOnly{},
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid request body",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			r := httptest.NewRequest("POST", "/", strings.NewReader(tt.body))
			r.Header.Set("Content-Type", "application/json")
			ok := DecodeJSONRequest(w, r, tt.target)
			if ok != tt.expectedOK {
				t.Errorf("DecodeJSONRequest() ok = %v, want %v", ok, tt.expectedOK)
			}
			if tt.expectedOK {
				if tt.verify != nil {
					tt.verify(t, tt.target)
				}
				return
			}
			if w.Code != tt.expectedStatus {
				t.Errorf("DecodeJSONRequest() status = %v, want %v", w.Code, tt.expectedStatus)
			}
			var response map[string]any
			if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
				t.Fatalf("DecodeJSONRequest() wrote a body that is not valid JSON: %v", err)
			}
			if tt.expectedError != "" && response["error"] != tt.expectedError {
				t.Errorf("DecodeJSONRequest() error = %v, want %v", response["error"], tt.expectedError)
			}
		})
	}
}
// TestParsePagination table-tests parsePagination. With no query parameters it
// yields limit=20, offset=0. A limit that is non-numeric, zero or negative falls
// back to 20. An offset that is non-numeric or negative falls back to 0. Other
// values, including offset=0 and large values, are used as given.
func TestParsePagination(t *testing.T) {
	cases := []struct {
		name       string
		query      map[string]string
		wantLimit  int
		wantOffset int
	}{
		{"default values - no params", map[string]string{}, 20, 0},
		{"valid limit and offset", map[string]string{"limit": "10", "offset": "5"}, 10, 5},
		{"only limit", map[string]string{"limit": "50"}, 50, 0},
		{"only offset", map[string]string{"offset": "100"}, 20, 100},
		{"invalid limit - not a number", map[string]string{"limit": "abc", "offset": "5"}, 20, 5},
		{"invalid limit - zero", map[string]string{"limit": "0", "offset": "5"}, 20, 5},
		{"invalid limit - negative", map[string]string{"limit": "-5", "offset": "5"}, 20, 5},
		{"invalid offset - not a number", map[string]string{"limit": "10", "offset": "abc"}, 10, 0},
		{"invalid offset - negative", map[string]string{"limit": "10", "offset": "-5"}, 10, 0},
		{"offset zero is valid", map[string]string{"limit": "10", "offset": "0"}, 10, 0},
		{"large valid values", map[string]string{"limit": "1000", "offset": "500"}, 1000, 500},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			r := httptest.NewRequest("GET", "/", nil)
			values := r.URL.Query()
			for k, v := range c.query {
				values.Set(k, v)
			}
			r.URL.RawQuery = values.Encode()

			gotLimit, gotOffset := parsePagination(r)
			if gotLimit != c.wantLimit {
				t.Errorf("parsePagination() limit = %v, want %v", gotLimit, c.wantLimit)
			}
			if gotOffset != c.wantOffset {
				t.Errorf("parsePagination() offset = %v, want %v", gotOffset, c.wantOffset)
			}
		})
	}
}
// TestNewVoteContext checks that NewVoteContext fills UserID from
// middleware.UserIDKey in the request context, IPAddress from the client IP
// (honouring X-Forwarded-For only when proxy headers are trusted), and
// UserAgent from the User-Agent header.
func TestNewVoteContext(t *testing.T) {
	tests := []struct {
		name              string
		userID            uint
		headers           map[string]string
		trustProxyHeaders bool
		remoteAddr        string
		userAgent         string
		expectedUserID    uint
		expectedIP        string
		expectedAgent     string
	}{
		{
			name:              "authenticated user with all fields",
			userID:            123,
			headers:           map[string]string{"X-Forwarded-For": "192.168.1.1"},
			trustProxyHeaders: true,
			remoteAddr:        "127.0.0.1:8080",
			userAgent:         "Mozilla/5.0",
			expectedUserID:    123,
			expectedIP:        "192.168.1.1",
			expectedAgent:     "Mozilla/5.0",
		},
		{
			name:              "unauthenticated user",
			userID:            0,
			headers:           map[string]string{},
			trustProxyHeaders: false,
			remoteAddr:        "127.0.0.1:8080",
			userAgent:         "Go-http-client/1.1",
			expectedUserID:    0,
			expectedIP:        "127.0.0.1",
			expectedAgent:     "Go-http-client/1.1",
		},
		{
			name:              "missing user agent",
			userID:            456,
			headers:           map[string]string{},
			trustProxyHeaders: false,
			remoteAddr:        "10.0.0.1:8080",
			userAgent:         "",
			expectedUserID:    456,
			expectedIP:        "10.0.0.1",
			expectedAgent:     "",
		},
	}

	// TrustProxyHeaders is package-level state in middleware. Restore it once
	// this test and all subtests finish. One cleanup is enough.
	originalTrust := middleware.TrustProxyHeaders
	t.Cleanup(func() {
		middleware.TrustProxyHeaders = originalTrust
	})

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			middleware.TrustProxyHeaders = tt.trustProxyHeaders
			req := httptest.NewRequest("GET", "/", nil)
			req.RemoteAddr = tt.remoteAddr
			if tt.userAgent != "" {
				req.Header.Set("User-Agent", tt.userAgent)
			}
			for key, value := range tt.headers {
				req.Header.Set(key, value)
			}
			ctx := context.WithValue(req.Context(), middleware.UserIDKey, tt.userID)
			req = req.WithContext(ctx)
			voteCtx := NewVoteContext(req)
			if voteCtx.UserID != tt.expectedUserID {
				t.Errorf("NewVoteContext() UserID = %v, want %v", voteCtx.UserID, tt.expectedUserID)
			}
			if voteCtx.IPAddress != tt.expectedIP {
				t.Errorf("NewVoteContext() IPAddress = %v, want %v", voteCtx.IPAddress, tt.expectedIP)
			}
			if voteCtx.UserAgent != tt.expectedAgent {
				t.Errorf("NewVoteContext() UserAgent = %v, want %v", voteCtx.UserAgent, tt.expectedAgent)
			}
		})
	}
}
// TestHandleRepoError table-tests HandleRepoError. A nil error returns true.
// gorm.ErrRecordNotFound, wrapped or not, yields 404 "<Entity> not found".
// Any other error yields 500 "Failed to retrieve <Entity>". Failure cases must
// return false and write a JSON body whose "error" field carries that message.
func TestHandleRepoError(t *testing.T) {
	tests := []struct {
		name           string
		err            error
		entityName     string
		expectedOK     bool
		expectedStatus int
		expectedError  string
	}{
		{
			name:           "nil error",
			err:            nil,
			entityName:     "Post",
			expectedOK:     true,
			expectedStatus: 0,
		},
		{
			name:           "gorm.ErrRecordNotFound",
			err:            gorm.ErrRecordNotFound,
			entityName:     "Post",
			expectedOK:     false,
			expectedStatus: http.StatusNotFound,
			expectedError:  "Post not found",
		},
		{
			name:           "wrapped gorm.ErrRecordNotFound",
			err:            fmt.Errorf("database error: %w", gorm.ErrRecordNotFound),
			entityName:     "User",
			expectedOK:     false,
			expectedStatus: http.StatusNotFound,
			expectedError:  "User not found",
		},
		{
			name:           "other error",
			err:            errors.New("database connection failed"),
			entityName:     "Post",
			expectedOK:     false,
			expectedStatus: http.StatusInternalServerError,
			expectedError:  "Failed to retrieve Post",
		},
		{
			name:           "generic error with custom entity",
			err:            errors.New("timeout"),
			entityName:     "Comment",
			expectedOK:     false,
			expectedStatus: http.StatusInternalServerError,
			expectedError:  "Failed to retrieve Comment",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			ok := HandleRepoError(w, tt.err, tt.entityName)
			if ok != tt.expectedOK {
				t.Errorf("HandleRepoError() ok = %v, want %v", ok, tt.expectedOK)
			}
			if tt.expectedOK {
				return
			}
			if w.Code != tt.expectedStatus {
				t.Errorf("HandleRepoError() status = %v, want %v", w.Code, tt.expectedStatus)
			}
			var response map[string]any
			if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
				t.Fatalf("HandleRepoError() wrote a body that is not valid JSON: %v", err)
			}
			if tt.expectedError != "" && response["error"] != tt.expectedError {
				t.Errorf("HandleRepoError() error = %v, want %v", response["error"], tt.expectedError)
			}
		})
	}
}
// TestHandleServiceError table-tests HandleServiceError. A nil error returns
// true. Each known services.Err* sentinel, also when wrapped with %w, maps to
// a specific status code and user-facing message. Unmapped errors fall back to
// the caller-supplied default message and status code. Every failure case must
// return false and write a JSON body whose "error" field equals the message.
func TestHandleServiceError(t *testing.T) {
	tests := []struct {
		name           string
		err            error
		defaultMsg     string
		defaultCode    int
		expectedOK     bool
		expectedStatus int
		expectedError  string
	}{
		{
			name:           "nil error",
			err:            nil,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     true,
			expectedStatus: 0,
		},
		{
			name:           "ErrInvalidCredentials",
			err:            services.ErrInvalidCredentials,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusUnauthorized,
			expectedError:  "Invalid username or password",
		},
		{
			name:           "ErrEmailNotVerified",
			err:            services.ErrEmailNotVerified,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusForbidden,
			expectedError:  "Please confirm your email before logging in",
		},
		{
			name:           "ErrAccountLocked",
			err:            services.ErrAccountLocked,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusForbidden,
			expectedError:  "Your account has been locked. Please contact us for assistance.",
		},
		{
			name:           "ErrUsernameTaken",
			err:            services.ErrUsernameTaken,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusConflict,
			expectedError:  "Username is already taken",
		},
		{
			name:           "ErrEmailTaken",
			err:            services.ErrEmailTaken,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusConflict,
			expectedError:  "Email is already registered",
		},
		{
			name:           "ErrInvalidEmail",
			err:            services.ErrInvalidEmail,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid email address",
		},
		{
			name:           "ErrPasswordTooShort",
			err:            services.ErrPasswordTooShort,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Password must be at least 8 characters",
		},
		{
			name:           "ErrInvalidVerificationToken",
			err:            services.ErrInvalidVerificationToken,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Invalid or expired verification token",
		},
		{
			name:           "ErrRefreshTokenExpired",
			err:            services.ErrRefreshTokenExpired,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusUnauthorized,
			expectedError:  "Refresh token has expired",
		},
		{
			name:           "ErrRefreshTokenInvalid",
			err:            services.ErrRefreshTokenInvalid,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusUnauthorized,
			expectedError:  "Invalid refresh token",
		},
		{
			name:           "ErrInvalidDeletionToken",
			err:            services.ErrInvalidDeletionToken,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusBadRequest,
			expectedError:  "This deletion link is invalid or has expired.",
		},
		{
			name:           "ErrEmailSenderUnavailable",
			err:            services.ErrEmailSenderUnavailable,
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusServiceUnavailable,
			expectedError:  "Email service is unavailable. Please try again later.",
		},
		{
			name:           "wrapped ErrInvalidCredentials",
			err:            fmt.Errorf("auth failed: %w", services.ErrInvalidCredentials),
			defaultMsg:     "Default error",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusUnauthorized,
			expectedError:  "Invalid username or password",
		},
		{
			name:           "unmapped error - uses default",
			err:            errors.New("unknown error"),
			defaultMsg:     "Something went wrong",
			defaultCode:    http.StatusInternalServerError,
			expectedOK:     false,
			expectedStatus: http.StatusInternalServerError,
			expectedError:  "Something went wrong",
		},
		{
			name:           "unmapped error with custom default",
			err:            errors.New("timeout"),
			defaultMsg:     "Request timeout",
			defaultCode:    http.StatusGatewayTimeout,
			expectedOK:     false,
			expectedStatus: http.StatusGatewayTimeout,
			expectedError:  "Request timeout",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			ok := HandleServiceError(w, tt.err, tt.defaultMsg, tt.defaultCode)
			if ok != tt.expectedOK {
				t.Errorf("HandleServiceError() ok = %v, want %v", ok, tt.expectedOK)
			}
			if tt.expectedOK {
				return
			}
			if w.Code != tt.expectedStatus {
				t.Errorf("HandleServiceError() status = %v, want %v", w.Code, tt.expectedStatus)
			}
			var response map[string]any
			if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
				t.Fatalf("HandleServiceError() wrote a body that is not valid JSON: %v", err)
			}
			if tt.expectedError != "" && response["error"] != tt.expectedError {
				t.Errorf("HandleServiceError() error = %v, want %v", response["error"], tt.expectedError)
			}
		})
	}
}