Compare commits

..

6 Commits

9 changed files with 205 additions and 101 deletions

View File

@@ -225,6 +225,11 @@ func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Req
email := strings.TrimSpace(req.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return
}
err := h.authService.ResendVerificationEmail(email)
if err != nil {
switch {
@@ -293,6 +298,11 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return
}
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
}
@@ -319,6 +329,11 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
token := strings.TrimSpace(req.Token)
newPassword := strings.TrimSpace(req.NewPassword)
if token == "" {
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
@@ -467,6 +482,11 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
currentPassword := strings.TrimSpace(req.CurrentPassword)
newPassword := strings.TrimSpace(req.NewPassword)
if currentPassword == "" {
SendErrorResponse(w, "Current password is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
@@ -538,6 +558,11 @@ func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Requ
token := strings.TrimSpace(req.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
}
if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil {
switch {
case errors.Is(err, services.ErrInvalidDeletionToken):
@@ -591,6 +616,11 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
return
}
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
result, err := h.authService.RefreshAccessToken(req.RefreshToken)
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
return
@@ -618,6 +648,11 @@ func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
return
}
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
err := h.authService.RevokeRefreshToken(req.RefreshToken)
if err != nil {
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)

View File

@@ -252,8 +252,8 @@ func TestAuthHandlerLoginSuccess(t *testing.T) {
}
handler := newAuthHandler(repo)
body := bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", body)
bodyStr := `{"username":"user","password":"Password123!"}`
request := createLoginRequest(bodyStr)
recorder := httptest.NewRecorder()
handler.Login(recorder, request)
@@ -274,17 +274,17 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler := newAuthHandler(&testutils.UserRepositoryStub{})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString("invalid"))
request := createLoginRequest("invalid")
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":" ","password":""}`))
request = createLoginRequest(`{"username":" ","password":""}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"WrongPass123!"}`))
request = createLoginRequest(`{"username":"user","password":"WrongPass123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
@@ -294,7 +294,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
@@ -304,7 +304,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`))
request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
@@ -330,8 +330,7 @@ func TestAuthHandlerRegisterSuccess(t *testing.T) {
return nil
}})
body := bytes.NewBufferString(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", body)
request := createRegisterRequest(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.Register(recorder, request)
@@ -354,12 +353,12 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString("invalid"))
request := createRegisterRequest("invalid")
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -368,7 +367,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"new","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
@@ -382,7 +381,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}
handler = newAuthHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"another","email":"taken@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -477,7 +476,7 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user@example.com"}`))
request := createForgotPasswordRequest(`{"username_or_email":"user@example.com"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
@@ -495,19 +494,19 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}})
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user"}`))
request = createForgotPasswordRequest(`{"username_or_email":"user"}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":""}`))
request = createForgotPasswordRequest(`{"username_or_email":""}`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`invalid json`))
request = createForgotPasswordRequest(`invalid json`)
handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -518,25 +517,25 @@ func TestAuthHandlerResetPassword(t *testing.T) {
handler := newAuthHandler(repo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"new_password":"NewPassword123!"}`))
request := createResetPasswordRequest(`{"new_password":"NewPassword123!"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token"}`))
request = createResetPasswordRequest(`{"token":"valid_token"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token","new_password":"short"}`))
request = createResetPasswordRequest(`{"token":"valid_token","new_password":"short"}`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`invalid json`))
request = createResetPasswordRequest(`invalid json`)
handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -602,7 +601,7 @@ func TestAuthHandlerResetPasswordServiceOutcomes(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"abc","new_password":"Password123!"}`))
request := createResetPasswordRequest(`{"token":"abc","new_password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.ResetPassword(recorder, request)
@@ -664,7 +663,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "empty email",
@@ -702,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/email", bytes.NewBufferString(tt.requestBody))
request := createUpdateEmailRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -789,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/username", bytes.NewBufferString(tt.requestBody))
request := createUpdateUsernameRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -886,7 +885,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
tt.mockSetup(repo)
handler := newAuthHandler(repo)
request := httptest.NewRequest(http.MethodPut, "/api/auth/password", bytes.NewBufferString(tt.requestBody))
request := createUpdatePasswordRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
@@ -984,8 +983,7 @@ func TestAuthHandlerDeleteAccount(t *testing.T) {
func TestAuthHandlerResendVerificationEmail(t *testing.T) {
makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) {
request := httptest.NewRequest(http.MethodPost, "/api/auth/resend-verification", bytes.NewBufferString(body))
request = request.WithContext(context.Background())
request := createResendVerificationRequest(body)
repo := &testutils.UserRepositoryStub{}
mockService := &mockAuthService{}
@@ -1014,7 +1012,7 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing email",
@@ -1139,7 +1137,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json",
body: "not-json",
expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body",
expectedError: "Invalid request",
},
{
name: "missing token",
@@ -1209,7 +1207,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/account/confirm", bytes.NewBufferString(tt.body))
request := createConfirmAccountDeletionRequest(tt.body)
recorder := httptest.NewRecorder()
handler.ConfirmAccountDeletion(recorder, request)
@@ -1338,9 +1336,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
for i := 0; i < concurrency; i++ {
go func() {
body := bytes.NewBufferString(`{"username":"testuser","password":"Password123!"}`)
req := httptest.NewRequest("POST", "/api/auth/login", body)
req.Header.Set("Content-Type", "application/json")
req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
w := httptest.NewRecorder()
handler.Login(w, req)
@@ -1370,8 +1366,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}, nil
}
body := bytes.NewBufferString(`{"refresh_token":"valid_refresh_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"valid_refresh_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1381,8 +1376,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1392,8 +1386,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1407,8 +1400,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenExpired
}
body := bytes.NewBufferString(`{"refresh_token":"expired_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"expired_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1422,8 +1414,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenInvalid
}
body := bytes.NewBufferString(`{"refresh_token":"invalid_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"invalid_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1437,8 +1428,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrAccountLocked
}
body := bytes.NewBufferString(`{"refresh_token":"locked_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"locked_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1452,8 +1442,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, fmt.Errorf("internal error")
}
body := bytes.NewBufferString(`{"refresh_token":"error_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req := createRefreshTokenRequest(`{"refresh_token":"error_token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1473,8 +1462,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return nil
}
body := bytes.NewBufferString(`{"refresh_token":"token_to_revoke"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token_to_revoke"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1484,8 +1472,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`invalid json`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1495,8 +1482,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
})
t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":""}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -1510,8 +1496,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return fmt.Errorf("revoke failed")
}
body := bytes.NewBufferString(`{"refresh_token":"token"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req := createRevokeTokenRequest(`{"refresh_token":"token"}`)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

View File

@@ -1,6 +1,7 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -11,6 +12,7 @@ import (
"testing"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
@@ -721,6 +723,74 @@ func TestDecodeJSONRequest(t *testing.T) {
}
}
func createRequestWithDTO[T any](method, url string, body []byte) *http.Request {
r := httptest.NewRequest(method, url, bytes.NewReader(body))
var dto T
if len(body) > 0 {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&dto); err != nil {
return r
}
}
ctx := middleware.SetValidatedDTOInContext(r.Context(), &dto)
return r.WithContext(ctx)
}
func createLoginRequest(body string) *http.Request {
return createRequestWithDTO[dto.LoginRequest](http.MethodPost, "/api/auth/login", []byte(body))
}
func createRegisterRequest(body string) *http.Request {
return createRequestWithDTO[dto.RegisterRequest](http.MethodPost, "/api/auth/register", []byte(body))
}
func createResendVerificationRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResendVerificationRequest](http.MethodPost, "/api/auth/resend-verification", []byte(body))
}
func createForgotPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ForgotPasswordRequest](http.MethodPost, "/api/auth/forgot-password", []byte(body))
}
func createResetPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResetPasswordRequest](http.MethodPost, "/api/auth/reset-password", []byte(body))
}
func createUpdateEmailRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateEmailRequest](http.MethodPut, "/api/auth/email", []byte(body))
}
func createUpdateUsernameRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateUsernameRequest](http.MethodPut, "/api/auth/username", []byte(body))
}
func createUpdatePasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePasswordRequest](http.MethodPut, "/api/auth/password", []byte(body))
}
func createConfirmAccountDeletionRequest(body string) *http.Request {
return createRequestWithDTO[dto.ConfirmAccountDeletionRequest](http.MethodPost, "/api/auth/account/confirm", []byte(body))
}
func createRefreshTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RefreshTokenRequest](http.MethodPost, "/api/auth/refresh", []byte(body))
}
func createRevokeTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RevokeTokenRequest](http.MethodPost, "/api/auth/revoke", []byte(body))
}
func createCreatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.CreatePostRequest](http.MethodPost, "/api/posts", []byte(body))
}
func createUpdatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePostRequest](http.MethodPut, "/api/posts/1", []byte(body))
}
func createVoteRequest(body string) *http.Request {
return createRequestWithDTO[dto.VoteRequest](http.MethodPost, "/api/posts/1/vote", []byte(body))
}
func TestParsePagination(t *testing.T) {
tests := []struct {
name string

View File

@@ -130,6 +130,11 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
return
}
if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel()
@@ -160,6 +165,16 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return
}
if len(title) > 200 {
SendErrorResponse(w, "Title must be at most 200 characters", http.StatusBadRequest)
return
}
if len(content) > 10000 {
SendErrorResponse(w, "Content must be at most 10000 characters", http.StatusBadRequest)
return
}
post := &database.Post{
Title: title,
URL: url,

View File

@@ -69,9 +69,8 @@ func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`))
request := createCreatePostRequest(`{"url":"https://example.com","content":"Test content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
@@ -171,7 +170,7 @@ func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`))
request := createUpdatePostRequest(`{"title":"Updated Title","content":"Updated content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json")
@@ -278,8 +277,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx)
@@ -297,7 +295,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`))
request := createCreatePostRequest(`{"title":"","url":"","content":""}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
@@ -305,14 +303,14 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`))
request = createCreatePostRequest(`invalid json`)
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`))
request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
}
@@ -336,8 +334,7 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
return "", tc.err
}}
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -495,7 +492,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
}
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody))
request := createUpdatePostRequest(tt.requestBody)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)

View File

@@ -41,7 +41,7 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -123,7 +123,7 @@ func TestPostHandler_InputValidation(t *testing.T) {
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
@@ -230,7 +230,7 @@ func TestAuthHandler_PasswordValidation(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
@@ -290,7 +290,7 @@ func TestAuthHandler_UsernameSanitization(t *testing.T) {
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@@ -103,7 +102,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
return nil
}})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request := createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
@@ -126,14 +125,14 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid"))
request = createRegisterRequest("invalid")
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
@@ -144,7 +143,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
}
handler = newUserHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
request = createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
@@ -350,7 +349,7 @@ func TestUserHandler_PasswordValidation(t *testing.T) {
handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody))
request := createRegisterRequest(requestBody)
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()

View File

@@ -1,7 +1,6 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -59,13 +58,13 @@ func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -73,7 +72,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`))
request = createVoteRequest(`invalid`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -83,7 +82,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`))
request = createVoteRequest(`{"type":"maybe"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -93,7 +92,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -101,7 +100,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -111,7 +110,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -125,7 +124,7 @@ func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -164,7 +163,7 @@ func TestVoteHandlerRemoveVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -202,7 +201,7 @@ func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -257,7 +256,7 @@ func TestVoteHandlerGetUserVote(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -301,7 +300,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -311,7 +310,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -345,7 +344,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -363,7 +362,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -373,7 +372,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
@@ -404,7 +403,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -414,7 +413,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
@@ -424,7 +423,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
@@ -452,7 +451,7 @@ func TestVoteFlowRegression(t *testing.T) {
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``))
request := createVoteRequest(``)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -460,7 +459,7 @@ func TestVoteFlowRegression(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`))
request = createVoteRequest(`{}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
@@ -470,7 +469,7 @@ func TestVoteFlowRegression(t *testing.T) {
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`))
request = createVoteRequest(`{"type":"invalid"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)

View File

@@ -77,3 +77,7 @@ func GetDTOTypeFromContext(ctx context.Context) reflect.Type {
func GetValidatedDTOFromContext(ctx context.Context) any {
return ctx.Value(validatedDTOKey)
}
func SetValidatedDTOInContext(ctx context.Context, dto any) context.Context {
return context.WithValue(ctx, validatedDTOKey, dto)
}