Compare commits

...

6 Commits

9 changed files with 205 additions and 101 deletions

View File

@@ -225,6 +225,11 @@ func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Req
email := strings.TrimSpace(req.Email) email := strings.TrimSpace(req.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return
}
err := h.authService.ResendVerificationEmail(email) err := h.authService.ResendVerificationEmail(email)
if err != nil { if err != nil {
switch { switch {
@@ -293,6 +298,11 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail) usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return
}
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil { if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
} }
@@ -319,6 +329,11 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
token := strings.TrimSpace(req.Token) token := strings.TrimSpace(req.Token)
newPassword := strings.TrimSpace(req.NewPassword) newPassword := strings.TrimSpace(req.NewPassword)
if token == "" {
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil { if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return return
@@ -467,6 +482,11 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
currentPassword := strings.TrimSpace(req.CurrentPassword) currentPassword := strings.TrimSpace(req.CurrentPassword)
newPassword := strings.TrimSpace(req.NewPassword) newPassword := strings.TrimSpace(req.NewPassword)
if currentPassword == "" {
SendErrorResponse(w, "Current password is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil { if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return return
@@ -538,6 +558,11 @@ func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Requ
token := strings.TrimSpace(req.Token) token := strings.TrimSpace(req.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
}
if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil { if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil {
switch { switch {
case errors.Is(err, services.ErrInvalidDeletionToken): case errors.Is(err, services.ErrInvalidDeletionToken):
@@ -591,6 +616,11 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
return return
} }
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
result, err := h.authService.RefreshAccessToken(req.RefreshToken) result, err := h.authService.RefreshAccessToken(req.RefreshToken)
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) { if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
return return
@@ -618,6 +648,11 @@ func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
return return
} }
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
err := h.authService.RevokeRefreshToken(req.RefreshToken) err := h.authService.RevokeRefreshToken(req.RefreshToken)
if err != nil { if err != nil {
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError) SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)

View File

@@ -252,8 +252,8 @@ func TestAuthHandlerLoginSuccess(t *testing.T) {
} }
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
body := bytes.NewBufferString(`{"username":"user","password":"Password123!"}`) bodyStr := `{"username":"user","password":"Password123!"}`
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", body) request := createLoginRequest(bodyStr)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.Login(recorder, request) handler.Login(recorder, request)
@@ -274,17 +274,17 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler := newAuthHandler(&testutils.UserRepositoryStub{}) handler := newAuthHandler(&testutils.UserRepositoryStub{})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString("invalid")) request := createLoginRequest("invalid")
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":" ","password":""}`)) request = createLoginRequest(`{"username":" ","password":""}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"WrongPass123!"}`)) request = createLoginRequest(`{"username":"user","password":"WrongPass123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
@@ -294,7 +294,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
}} }}
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)) request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden) testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
@@ -304,7 +304,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)) request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError) testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
@@ -330,8 +330,7 @@ func TestAuthHandlerRegisterSuccess(t *testing.T) {
return nil return nil
}}) }})
body := bytes.NewBufferString(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`) request := createRegisterRequest(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", body)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.Register(recorder, request) handler.Register(recorder, request)
@@ -354,12 +353,12 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString("invalid")) request := createRegisterRequest("invalid")
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"","email":"","password":""}`)) request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -368,7 +367,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}} }}
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
@@ -382,7 +381,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
} }
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
} }
@@ -477,7 +476,7 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}}) }})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user@example.com"}`)) request := createForgotPasswordRequest(`{"username_or_email":"user@example.com"}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
@@ -495,19 +494,19 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}}) }})
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user"}`)) request = createForgotPasswordRequest(`{"username_or_email":"user"}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":""}`)) request = createForgotPasswordRequest(`{"username_or_email":""}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`invalid json`)) request = createForgotPasswordRequest(`invalid json`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -518,25 +517,25 @@ func TestAuthHandlerResetPassword(t *testing.T) {
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"new_password":"NewPassword123!"}`)) request := createResetPasswordRequest(`{"new_password":"NewPassword123!"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token"}`)) request = createResetPasswordRequest(`{"token":"valid_token"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token","new_password":"short"}`)) request = createResetPasswordRequest(`{"token":"valid_token","new_password":"short"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`invalid json`)) request = createResetPasswordRequest(`invalid json`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -602,7 +601,7 @@ func TestAuthHandlerResetPasswordServiceOutcomes(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"abc","new_password":"Password123!"}`)) request := createResetPasswordRequest(`{"token":"abc","new_password":"Password123!"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
@@ -664,7 +663,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1, userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {}, mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "empty email", name: "empty email",
@@ -702,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/email", bytes.NewBufferString(tt.requestBody)) request := createUpdateEmailRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -789,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/username", bytes.NewBufferString(tt.requestBody)) request := createUpdateUsernameRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -886,7 +885,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
tt.mockSetup(repo) tt.mockSetup(repo)
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
request := httptest.NewRequest(http.MethodPut, "/api/auth/password", bytes.NewBufferString(tt.requestBody)) request := createUpdatePasswordRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -984,8 +983,7 @@ func TestAuthHandlerDeleteAccount(t *testing.T) {
func TestAuthHandlerResendVerificationEmail(t *testing.T) { func TestAuthHandlerResendVerificationEmail(t *testing.T) {
makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) { makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) {
request := httptest.NewRequest(http.MethodPost, "/api/auth/resend-verification", bytes.NewBufferString(body)) request := createResendVerificationRequest(body)
request = request.WithContext(context.Background())
repo := &testutils.UserRepositoryStub{} repo := &testutils.UserRepositoryStub{}
mockService := &mockAuthService{} mockService := &mockAuthService{}
@@ -1014,7 +1012,7 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json", name: "invalid json",
body: "not-json", body: "not-json",
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "missing email", name: "missing email",
@@ -1139,7 +1137,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json", name: "invalid json",
body: "not-json", body: "not-json",
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "missing token", name: "missing token",
@@ -1209,7 +1207,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/account/confirm", bytes.NewBufferString(tt.body)) request := createConfirmAccountDeletionRequest(tt.body)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.ConfirmAccountDeletion(recorder, request) handler.ConfirmAccountDeletion(recorder, request)
@@ -1338,9 +1336,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
for i := 0; i < concurrency; i++ { for i := 0; i < concurrency; i++ {
go func() { go func() {
body := bytes.NewBufferString(`{"username":"testuser","password":"Password123!"}`) req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
req := httptest.NewRequest("POST", "/api/auth/login", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.Login(w, req) handler.Login(w, req)
@@ -1370,8 +1366,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}, nil }, nil
} }
body := bytes.NewBufferString(`{"refresh_token":"valid_refresh_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"valid_refresh_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1381,8 +1376,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}) })
t.Run("Invalid_Request_Body", func(t *testing.T) { t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`) req := createRefreshTokenRequest(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1392,8 +1386,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}) })
t.Run("Missing_Refresh_Token", func(t *testing.T) { t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`) req := createRefreshTokenRequest(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1407,8 +1400,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenExpired return nil, services.ErrRefreshTokenExpired
} }
body := bytes.NewBufferString(`{"refresh_token":"expired_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"expired_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1422,8 +1414,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenInvalid return nil, services.ErrRefreshTokenInvalid
} }
body := bytes.NewBufferString(`{"refresh_token":"invalid_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"invalid_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1437,8 +1428,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrAccountLocked return nil, services.ErrAccountLocked
} }
body := bytes.NewBufferString(`{"refresh_token":"locked_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"locked_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1452,8 +1442,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, fmt.Errorf("internal error") return nil, fmt.Errorf("internal error")
} }
body := bytes.NewBufferString(`{"refresh_token":"error_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"error_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1473,8 +1462,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return nil return nil
} }
body := bytes.NewBufferString(`{"refresh_token":"token_to_revoke"}`) req := createRevokeTokenRequest(`{"refresh_token":"token_to_revoke"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1484,8 +1472,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
}) })
t.Run("Invalid_Request_Body", func(t *testing.T) { t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`) req := createRevokeTokenRequest(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1495,8 +1482,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
}) })
t.Run("Missing_Refresh_Token", func(t *testing.T) { t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`) req := createRevokeTokenRequest(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1510,8 +1496,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return fmt.Errorf("revoke failed") return fmt.Errorf("revoke failed")
} }
body := bytes.NewBufferString(`{"refresh_token":"token"}`) req := createRevokeTokenRequest(`{"refresh_token":"token"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -11,6 +12,7 @@ import (
"testing" "testing"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware" "goyco/internal/middleware"
"goyco/internal/services" "goyco/internal/services"
"goyco/internal/testutils" "goyco/internal/testutils"
@@ -721,6 +723,74 @@ func TestDecodeJSONRequest(t *testing.T) {
} }
} }
func createRequestWithDTO[T any](method, url string, body []byte) *http.Request {
r := httptest.NewRequest(method, url, bytes.NewReader(body))
var dto T
if len(body) > 0 {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&dto); err != nil {
return r
}
}
ctx := middleware.SetValidatedDTOInContext(r.Context(), &dto)
return r.WithContext(ctx)
}
func createLoginRequest(body string) *http.Request {
return createRequestWithDTO[dto.LoginRequest](http.MethodPost, "/api/auth/login", []byte(body))
}
func createRegisterRequest(body string) *http.Request {
return createRequestWithDTO[dto.RegisterRequest](http.MethodPost, "/api/auth/register", []byte(body))
}
func createResendVerificationRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResendVerificationRequest](http.MethodPost, "/api/auth/resend-verification", []byte(body))
}
func createForgotPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ForgotPasswordRequest](http.MethodPost, "/api/auth/forgot-password", []byte(body))
}
func createResetPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResetPasswordRequest](http.MethodPost, "/api/auth/reset-password", []byte(body))
}
func createUpdateEmailRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateEmailRequest](http.MethodPut, "/api/auth/email", []byte(body))
}
func createUpdateUsernameRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateUsernameRequest](http.MethodPut, "/api/auth/username", []byte(body))
}
func createUpdatePasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePasswordRequest](http.MethodPut, "/api/auth/password", []byte(body))
}
func createConfirmAccountDeletionRequest(body string) *http.Request {
return createRequestWithDTO[dto.ConfirmAccountDeletionRequest](http.MethodPost, "/api/auth/account/confirm", []byte(body))
}
func createRefreshTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RefreshTokenRequest](http.MethodPost, "/api/auth/refresh", []byte(body))
}
func createRevokeTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RevokeTokenRequest](http.MethodPost, "/api/auth/revoke", []byte(body))
}
func createCreatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.CreatePostRequest](http.MethodPost, "/api/posts", []byte(body))
}
func createUpdatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePostRequest](http.MethodPut, "/api/posts/1", []byte(body))
}
func createVoteRequest(body string) *http.Request {
return createRequestWithDTO[dto.VoteRequest](http.MethodPost, "/api/posts/1/vote", []byte(body))
}
func TestParsePagination(t *testing.T) { func TestParsePagination(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -130,6 +130,11 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
url := security.SanitizeURL(req.URL) url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content) content := security.SanitizePostContent(req.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
return
}
if title == "" && h.titleFetcher != nil { if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second) titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel() defer cancel()
@@ -160,6 +165,16 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(title) > 200 {
SendErrorResponse(w, "Title must be at most 200 characters", http.StatusBadRequest)
return
}
if len(content) > 10000 {
SendErrorResponse(w, "Content must be at most 10000 characters", http.StatusBadRequest)
return
}
post := &database.Post{ post := &database.Post{
Title: title, Title: title,
URL: url, URL: url,

View File

@@ -69,9 +69,8 @@ func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
handler := NewPostHandler(repo, titleFetcher, nil) handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`)) request := createCreatePostRequest(`{"url":"https://example.com","content":"Test content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1)) request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
@@ -171,7 +170,7 @@ func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
handler := NewPostHandler(repo, nil, nil) handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`)) request := createUpdatePostRequest(`{"title":"Updated Title","content":"Updated content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1)) request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
@@ -278,8 +277,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
handler := NewPostHandler(repo, fetcher, nil) handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`) request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -297,7 +295,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil) handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`)) request := createCreatePostRequest(`{"title":"","url":"","content":""}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
@@ -305,14 +303,14 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`)) request = createCreatePostRequest(`invalid json`)
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`)) request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
} }
@@ -329,15 +327,14 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
{name: "Generic", err: errors.New("timeout"), wantStatus: http.StatusBadGateway, wantMsg: "Failed to fetch title"}, {name: "Generic", err: errors.New("timeout"), wantStatus: http.StatusBadGateway, wantMsg: "Failed to fetch title"},
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
repo := testutils.NewPostRepositoryStub() repo := testutils.NewPostRepositoryStub()
fetcher := &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) { fetcher := &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) {
return "", tc.err return "", tc.err
}} }}
handler := NewPostHandler(repo, fetcher, nil) handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`) request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -495,7 +492,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
} }
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil) handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody)) request := createUpdatePostRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)

View File

@@ -41,7 +41,7 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
} }
body, _ := json.Marshal(postData) body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -123,7 +123,7 @@ func TestPostHandler_InputValidation(t *testing.T) {
} }
body, _ := json.Marshal(postData) body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -230,7 +230,7 @@ func TestAuthHandler_PasswordValidation(t *testing.T) {
} }
body, _ := json.Marshal(registerData) body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -290,7 +290,7 @@ func TestAuthHandler_UsernameSanitization(t *testing.T) {
} }
body, _ := json.Marshal(registerData) body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -103,7 +102,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
return nil return nil
}}) }})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`)) request := createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated) testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
@@ -126,14 +125,14 @@ func TestUserHandlerCreateUser(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid")) request = createRegisterRequest("invalid")
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`)) request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
@@ -144,7 +143,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
} }
handler = newUserHandler(repo) handler = newUserHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
} }
@@ -350,7 +349,7 @@ func TestUserHandler_PasswordValidation(t *testing.T) {
handler := NewUserHandler(repo, authService) handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password) requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody)) request := createRegisterRequest(requestBody)
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -59,13 +58,13 @@ func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos() handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request) handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"}) request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -73,7 +72,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`)) request = createVoteRequest(`invalid`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -83,7 +82,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`)) request = createVoteRequest(`{"type":"maybe"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -93,7 +92,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -101,7 +100,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -111,7 +110,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`)) request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -125,7 +124,7 @@ func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs() handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1) delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -164,7 +163,7 @@ func TestVoteHandlerRemoveVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -202,7 +201,7 @@ func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) { func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs() handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -257,7 +256,7 @@ func TestVoteHandlerGetUserVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -301,7 +300,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -311,7 +310,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -345,7 +344,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1" postID := "1"
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -363,7 +362,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -373,7 +372,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`)) request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -404,7 +403,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1" postID := "1"
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -414,7 +413,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -424,7 +423,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -452,7 +451,7 @@ func TestVoteFlowRegression(t *testing.T) {
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) { t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``)) request := createVoteRequest(``)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -460,7 +459,7 @@ func TestVoteFlowRegression(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`)) request = createVoteRequest(`{}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -470,7 +469,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`)) request = createVoteRequest(`{"type":"invalid"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)

View File

@@ -77,3 +77,7 @@ func GetDTOTypeFromContext(ctx context.Context) reflect.Type {
func GetValidatedDTOFromContext(ctx context.Context) any { func GetValidatedDTOFromContext(ctx context.Context) any {
return ctx.Value(validatedDTOKey) return ctx.Value(validatedDTOKey)
} }
func SetValidatedDTOInContext(ctx context.Context, dto any) context.Context {
return context.WithValue(ctx, validatedDTOKey, dto)
}