To gitea and beyond, let's go(-yco)

This commit is contained in:
2025-11-10 19:12:09 +01:00
parent 8f6133392d
commit 71a031342b
245 changed files with 83994 additions and 0 deletions

View File

@@ -0,0 +1,238 @@
package handlers
import (
"fmt"
"net/http"
"time"
"goyco/internal/config"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/version"
"github.com/go-chi/chi/v5"
"gorm.io/gorm"
)
type APIHandler struct {
config *config.Config
postRepo repositories.PostRepository
userRepo repositories.UserRepository
voteService *services.VoteService
dbMonitor middleware.DBMonitor
healthChecker *middleware.DatabaseHealthChecker
metricsCollector *middleware.MetricsCollector
}
func NewAPIHandler(config *config.Config, postRepo repositories.PostRepository, userRepo repositories.UserRepository, voteService *services.VoteService) *APIHandler {
return &APIHandler{
config: config,
postRepo: postRepo,
userRepo: userRepo,
voteService: voteService,
}
}
func NewAPIHandlerWithMonitoring(config *config.Config, postRepo repositories.PostRepository, userRepo repositories.UserRepository, voteService *services.VoteService, db *gorm.DB, dbMonitor middleware.DBMonitor) *APIHandler {
if db == nil {
return NewAPIHandler(config, postRepo, userRepo, voteService)
}
sqlDB, err := db.DB()
if err != nil {
return NewAPIHandler(config, postRepo, userRepo, voteService)
}
healthChecker := middleware.NewDatabaseHealthChecker(sqlDB, dbMonitor)
metricsCollector := middleware.NewMetricsCollector(dbMonitor)
return &APIHandler{
config: config,
postRepo: postRepo,
userRepo: userRepo,
voteService: voteService,
dbMonitor: dbMonitor,
healthChecker: healthChecker,
metricsCollector: metricsCollector,
}
}
type APIInfo = CommonResponse
func (h *APIHandler) GetAPIInfo(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api" {
http.NotFound(w, r)
return
}
apiInfo := map[string]any{
"name": fmt.Sprintf("%s API", h.config.App.Title),
"version": version.Version,
"description": "Y Combinator-style news board API",
"endpoints": map[string]any{
"authentication": map[string]any{
"POST /api/auth/register": "Register new user",
"POST /api/auth/login": "Login user",
"GET /api/auth/confirm": "Confirm email address",
"POST /api/auth/resend-verification": "Resend verification email",
"POST /api/auth/forgot-password": "Request password reset",
"POST /api/auth/reset-password": "Reset password",
"POST /api/auth/account/confirm": "Confirm account deletion",
"GET /api/auth/me": "Get current user profile",
"POST /api/auth/logout": "Logout user",
"PUT /api/auth/email": "Update email address",
"PUT /api/auth/username": "Update username",
"PUT /api/auth/password": "Update password",
"DELETE /api/auth/account": "Request account deletion",
},
"posts": map[string]any{
"GET /api/posts": "List all posts",
"GET /api/posts/search": "Search posts",
"GET /api/posts/title": "Fetch title from URL",
"GET /api/posts/{id}": "Get specific post",
"POST /api/posts": "Create new post",
"PUT /api/posts/{id}": "Update post",
"DELETE /api/posts/{id}": "Delete post",
},
"votes": map[string]any{
"POST /api/posts/{id}/vote": "Cast a vote",
"DELETE /api/posts/{id}/vote": "Remove vote",
"GET /api/posts/{id}/vote": "Get user's vote",
"GET /api/posts/{id}/votes": "Get all votes for post",
},
"users": map[string]any{
"GET /api/users": "List all users",
"POST /api/users": "Create new user",
"GET /api/users/{id}": "Get specific user",
"GET /api/users/{id}/posts": "Get user's posts",
},
"system": map[string]any{
"GET /health": "Health check",
"GET /metrics": "Service metrics",
},
},
"authentication": map[string]any{
"type": "Bearer Token (JWT)",
"note": "Include Authorization header with 'Bearer <token>' for protected endpoints",
},
"response_format": map[string]any{
"success": "boolean",
"message": "string",
"data": "object or array",
"error": "string (on error)",
},
}
SendSuccessResponse(w, "API information retrieved successfully", apiInfo)
}
func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
if h.healthChecker != nil {
health := h.healthChecker.CheckHealth()
health["version"] = version.Version
SendSuccessResponse(w, "Health check successful", health)
return
}
currentTimestamp := time.Now().UTC().Format(time.RFC3339)
health := map[string]any{
"status": "healthy",
"timestamp": currentTimestamp,
"version": version.Version,
"services": map[string]any{
"database": "connected",
"api": "running",
},
}
SendSuccessResponse(w, "Health check successful", health)
}
func (h *APIHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
postCount, err := h.postRepo.Count()
if err != nil {
SendErrorResponse(w, "Failed to get post count", http.StatusInternalServerError)
return
}
userCount, err := h.userRepo.Count()
if err != nil {
SendErrorResponse(w, "Failed to get user count", http.StatusInternalServerError)
return
}
totalVoteCount, _, err := h.voteService.GetVoteStatistics()
if err != nil {
SendErrorResponse(w, "Failed to get vote statistics", http.StatusInternalServerError)
return
}
topPosts, err := h.postRepo.GetTopPosts(5)
if err != nil {
SendErrorResponse(w, "Failed to get top posts", http.StatusInternalServerError)
return
}
var avgVotesPerPost float64
if postCount > 0 {
avgVotesPerPost = float64(totalVoteCount) / float64(postCount)
}
var totalScore int
for _, post := range topPosts {
totalScore += post.Score
}
var avgScore float64
if len(topPosts) > 0 {
avgScore = float64(totalScore) / float64(len(topPosts))
}
metrics := map[string]any{
"posts": map[string]any{
"total_count": postCount,
"top_posts_count": len(topPosts),
"total_score": totalScore,
"average_score": avgScore,
},
"users": map[string]any{
"total_count": userCount,
},
"votes": map[string]any{
"total_count": totalVoteCount,
"average_per_post": avgVotesPerPost,
"note": "All votes are counted together",
},
"system": map[string]any{
"timestamp": time.Now().UTC().Format(time.RFC3339),
"version": version.Version,
},
}
if h.metricsCollector != nil {
performanceMetrics := h.metricsCollector.GetMetrics()
metrics["database"] = map[string]any{
"total_queries": performanceMetrics.DBStats.TotalQueries,
"slow_queries": performanceMetrics.DBStats.SlowQueries,
"average_duration": performanceMetrics.DBStats.AverageDuration.String(),
"max_duration": performanceMetrics.DBStats.MaxDuration.String(),
"error_count": performanceMetrics.DBStats.ErrorCount,
"last_query_time": performanceMetrics.DBStats.LastQueryTime.Format(time.RFC3339),
}
metrics["performance"] = map[string]any{
"request_count": performanceMetrics.RequestCount,
"average_response": performanceMetrics.AverageResponse.String(),
"max_response": performanceMetrics.MaxResponse.String(),
"error_count": performanceMetrics.ErrorCount,
}
}
SendSuccessResponse(w, "Metrics retrieved successfully", metrics)
}
func (h *APIHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
}

View File

@@ -0,0 +1,280 @@
package handlers
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
)
func TestAPIHandlerGetAPIInfo(t *testing.T) {
mockPostRepo := testutils.NewPostRepositoryStub()
mockUserRepo := testutils.NewUserRepositoryStub()
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api", nil)
handler.GetAPIInfo(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var resp APIInfo
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if !resp.Success || resp.Message == "" {
t.Fatalf("expected success response, got %+v", resp)
}
data, ok := resp.Data.(map[string]any)
if !ok || data["name"] != fmt.Sprintf("%s API", testutils.AppTestConfig.App.Title) {
t.Fatalf("unexpected data payload: %#v", resp.Data)
}
endpoints, ok := data["endpoints"].(map[string]any)
if !ok {
t.Fatalf("expected endpoints map, got %#v", data["endpoints"])
}
authEndpoints := endpoints["authentication"].(map[string]any)
for _, route := range []string{
"POST /api/auth/resend-verification",
"POST /api/auth/account/confirm",
} {
if _, found := authEndpoints[route]; !found {
t.Fatalf("expected authentication catalogue to include %s", route)
}
}
systemEndpoints := endpoints["system"].(map[string]any)
if _, found := systemEndpoints["GET /metrics"]; !found {
t.Fatalf("expected system catalogue to include GET /metrics")
}
}
func TestAPIHandlerGetHealth(t *testing.T) {
mockPostRepo := testutils.NewPostRepositoryStub()
mockUserRepo := testutils.NewUserRepositoryStub()
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/health", nil)
handler.GetHealth(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var resp APIInfo
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
t.Fatalf("decode error: %v", err)
}
if !resp.Success || resp.Message == "" {
t.Fatalf("expected success message, got %+v", resp)
}
data := resp.Data.(map[string]any)
if data["status"] != "healthy" {
t.Fatalf("expected health status, got %+v", data)
}
}
func TestAPIHandlerGetMetrics(t *testing.T) {
mockPostRepo := testutils.NewPostRepositoryStub()
mockPostRepo.CountFn = func() (int64, error) { return 10, nil }
mockPostRepo.GetTopPostsFn = func(limit int) ([]database.Post, error) {
return []database.Post{
{ID: 1, Score: 100},
{ID: 2, Score: 50},
{ID: 3, Score: 25},
}, nil
}
mockUserRepo := testutils.NewUserRepositoryStub()
mockUserRepo.CountFn = func() (int64, error) { return 5, nil }
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/metrics", nil)
handler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var resp APIInfo
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
t.Fatalf("decode error: %v", err)
}
if !resp.Success || resp.Message == "" {
t.Fatalf("expected success response, got %+v", resp)
}
data, ok := resp.Data.(map[string]any)
if !ok {
t.Fatalf("expected metrics data map, got %T", resp.Data)
}
if data["posts"] == nil {
t.Fatalf("expected metrics payload to include posts")
}
if data["users"] == nil {
t.Fatalf("expected metrics payload to include users")
}
if data["votes"] == nil {
t.Fatalf("expected metrics payload to include votes")
}
if data["system"] == nil {
t.Fatalf("expected metrics payload to include system")
}
posts, ok := data["posts"].(map[string]any)
if !ok {
t.Fatalf("expected posts to be a map, got %T", data["posts"])
}
if posts["total_count"] != float64(10) {
t.Fatalf("expected posts total_count to be 10, got %v", posts["total_count"])
}
}
func newAPIHandlerForTest(postRepo repositories.PostRepository, userRepo repositories.UserRepository) *APIHandler {
voteRepo := testutils.NewMockVoteRepository()
voteService := services.NewVoteService(voteRepo, postRepo, nil)
return NewAPIHandler(testutils.AppTestConfig, postRepo, userRepo, voteService)
}
func TestAPIHandlerGetMetricsErrorHandling(t *testing.T) {
mockPostRepo := testutils.NewPostRepositoryStub()
mockPostRepo.CountFn = func() (int64, error) { return 0, errors.New("database error") }
mockUserRepo := testutils.NewUserRepositoryStub()
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/metrics", nil)
handler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
var resp APIInfo
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
t.Fatalf("decode error: %v", err)
}
if resp.Success {
t.Fatalf("expected error response, got %+v", resp)
}
}
func TestAPIHandlerGetMetricsWithDatabaseMonitoring(t *testing.T) {
mockPostRepo := testutils.NewPostRepositoryStub()
mockPostRepo.CountFn = func() (int64, error) { return 10, nil }
mockPostRepo.GetTopPostsFn = func(limit int) ([]database.Post, error) {
return []database.Post{
{ID: 1, Score: 100},
{ID: 2, Score: 50},
}, nil
}
mockUserRepo := testutils.NewUserRepositoryStub()
mockUserRepo.CountFn = func() (int64, error) { return 5, nil }
voteRepo := testutils.NewMockVoteRepository()
voteService := services.NewVoteService(voteRepo, mockPostRepo, nil)
handler := NewAPIHandler(testutils.AppTestConfig, mockPostRepo, mockUserRepo, voteService)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/metrics", nil)
handler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var resp APIInfo
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
t.Fatalf("decode error: %v", err)
}
if !resp.Success {
t.Fatalf("expected success response, got %+v", resp)
}
data, ok := resp.Data.(map[string]any)
if !ok {
t.Fatalf("expected metrics data map, got %T", resp.Data)
}
expectedSections := []string{"posts", "users", "votes", "system"}
for _, section := range expectedSections {
if data[section] == nil {
t.Fatalf("expected metrics payload to include %s", section)
}
}
}
func TestNewAPIHandlerWithMonitoring(t *testing.T) {
mockPostRepo := testutils.NewPostRepositoryStub()
mockUserRepo := testutils.NewUserRepositoryStub()
voteRepo := testutils.NewMockVoteRepository()
voteService := services.NewVoteService(voteRepo, mockPostRepo, nil)
monitor := middleware.NewInMemoryDBMonitor()
db := testutils.NewTestDB(t)
defer func() {
sqlDB, _ := db.DB()
sqlDB.Close()
}()
handler := NewAPIHandlerWithMonitoring(testutils.AppTestConfig, mockPostRepo, mockUserRepo, voteService, db, monitor)
if handler == nil {
t.Fatal("Expected handler to be created")
}
if handler.dbMonitor == nil {
t.Error("Expected dbMonitor to be set")
}
if handler.healthChecker == nil {
t.Error("Expected healthChecker to be set")
}
if handler.metricsCollector == nil {
t.Error("Expected metricsCollector to be set")
}
}
func TestNewAPIHandlerWithMonitoring_NilDB(t *testing.T) {
mockPostRepo := testutils.NewPostRepositoryStub()
mockUserRepo := testutils.NewUserRepositoryStub()
voteRepo := testutils.NewMockVoteRepository()
voteService := services.NewVoteService(voteRepo, mockPostRepo, nil)
handler := NewAPIHandlerWithMonitoring(testutils.AppTestConfig, mockPostRepo, mockUserRepo, voteService, nil, nil)
if handler == nil {
t.Fatal("Expected handler to be created")
}
if handler.dbMonitor != nil {
t.Error("Expected dbMonitor to be nil when db is nil")
}
if handler.healthChecker != nil {
t.Error("Expected healthChecker to be nil when db is nil")
}
if handler.metricsCollector != nil {
t.Error("Expected metricsCollector to be nil when db is nil")
}
}

View File

@@ -0,0 +1,825 @@
package handlers
import (
"errors"
"net/http"
"strings"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/repositories"
"goyco/internal/security"
"goyco/internal/services"
"goyco/internal/validation"
"github.com/go-chi/chi/v5"
)
type AuthServiceInterface interface {
Login(username, password string) (*services.AuthResult, error)
Register(username, email, password string) (*services.RegistrationResult, error)
ConfirmEmail(token string) (*database.User, error)
ResendVerificationEmail(email string) error
RequestPasswordReset(usernameOrEmail string) error
ResetPassword(token, newPassword string) error
UpdateEmail(userID uint, email string) (*database.User, error)
UpdateUsername(userID uint, username string) (*database.User, error)
UpdatePassword(userID uint, currentPassword, newPassword string) (*database.User, error)
RequestAccountDeletion(userID uint) error
ConfirmAccountDeletionWithPosts(token string, deletePosts bool) error
RefreshAccessToken(refreshToken string) (*services.AuthResult, error)
RevokeRefreshToken(refreshToken string) error
RevokeAllUserTokens(userID uint) error
InvalidateAllSessions(userID uint) error
GetAdminEmail() string
VerifyToken(tokenString string) (uint, error)
GetUserIDFromDeletionToken(token string) (uint, error)
UserHasPosts(userID uint) (bool, int64, error)
}
type AuthHandler struct {
authService AuthServiceInterface
userRepo repositories.UserRepository
}
type AuthResponse = CommonResponse
type AuthTokensResponse struct {
Success bool `json:"success" example:"true"`
Message string `json:"message" example:"Authentication successful"`
Data AuthTokensDetail `json:"data"`
}
type AuthTokensDetail struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"f94d4ddc7d9b4fcb9d3a2c44c400b780"`
User AuthUserSummary `json:"user"`
}
type AuthUserSummary struct {
ID uint `json:"id" example:"42"`
Username string `json:"username" example:"janedoe"`
Email string `json:"email" example:"jane@example.com"`
EmailVerified bool `json:"email_verified" example:"true"`
Locked bool `json:"locked" example:"false"`
}
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type RegisterRequest struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
type CreatePostRequest struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
type ResendVerificationRequest struct {
Email string `json:"email"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email"`
}
type ResetPasswordRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type UpdateEmailRequest struct {
Email string `json:"email"`
}
type UpdateUsernameRequest struct {
Username string `json:"username"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
return &AuthHandler{
authService: authService,
userRepo: userRepo,
}
}
// @Summary Login user
// @Description Authenticate user with username and password
// @Tags auth
// @Accept json
// @Produce json
// @Param request body LoginRequest true "Login credentials"
// @Success 200 {object} AuthTokensResponse "Authentication successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 401 {object} AuthResponse "Invalid credentials"
// @Failure 403 {object} AuthResponse "Account is locked"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/login [post]
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
username := security.SanitizeUsername(req.Username)
password := strings.TrimSpace(req.Password)
if username == "" || password == "" {
SendErrorResponse(w, "Username and password are required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
result, err := h.authService.Login(username, password)
if !HandleServiceError(w, err, "Authentication failed", http.StatusInternalServerError) {
return
}
SendSuccessResponse(w, "Authentication successful", result)
}
// @Summary Register a new user
// @Description Register a new user with username, email and password
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RegisterRequest true "Registration data"
// @Success 201 {object} AuthResponse "Registration successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 409 {object} AuthResponse "Username or email already exists"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/register [post]
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
username := strings.TrimSpace(req.Username)
email := strings.TrimSpace(req.Email)
password := strings.TrimSpace(req.Password)
if username == "" || email == "" || password == "" {
SendErrorResponse(w, "Username, email, and password are required", http.StatusBadRequest)
return
}
username = security.SanitizeUsername(username)
if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateEmail(email); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
result, err := h.authService.Register(username, email, password)
if err != nil {
var validationErr *validation.ValidationError
if errors.As(err, &validationErr) {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if !HandleServiceError(w, err, "Registration failed", http.StatusInternalServerError) {
return
}
}
userData := map[string]any{
"id": result.User.ID,
"username": result.User.Username,
"email": result.User.Email,
"email_verified": result.User.EmailVerified,
"created_at": result.User.CreatedAt,
"updated_at": result.User.UpdatedAt,
"deleted_at": result.User.DeletedAt,
}
responseData := map[string]any{
"user": userData,
"verification_sent": result.VerificationSent,
}
SendCreatedResponse(w, "Registration successful. Check your email to confirm your account.", responseData)
}
// @Summary Confirm email address
// @Description Confirm user email with verification token
// @Tags auth
// @Accept json
// @Produce json
// @Param token query string true "Email verification token"
// @Success 200 {object} AuthResponse "Email confirmed successfully"
// @Failure 400 {object} AuthResponse "Invalid or missing token"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/confirm [get]
func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
token := strings.TrimSpace(r.URL.Query().Get("token"))
if token == "" {
SendErrorResponse(w, "Verification token is required", http.StatusBadRequest)
return
}
user, err := h.authService.ConfirmEmail(token)
if !HandleServiceError(w, err, "Unable to verify email", http.StatusInternalServerError) {
return
}
userDTO := dto.ToUserDTO(user)
SendSuccessResponse(w, "Email confirmed successfully", map[string]any{
"user": userDTO,
})
}
// @Summary Resend verification email
// @Description Send a new verification email to the provided address
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResendVerificationRequest true "Email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 404 {object} AuthResponse
// @Failure 409 {object} AuthResponse
// @Failure 429 {object} AuthResponse
// @Failure 503 {object} AuthResponse
// @Failure 500 {object} AuthResponse
// @Router /auth/resend-verification [post]
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
email := strings.TrimSpace(req.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return
}
err := h.authService.ResendVerificationEmail(email)
if err != nil {
switch {
case errors.Is(err, services.ErrInvalidCredentials):
SendErrorResponse(w, "No account found with this email address", http.StatusNotFound)
case errors.Is(err, services.ErrInvalidEmail):
SendErrorResponse(w, "Invalid email address format", http.StatusBadRequest)
case errors.Is(err, services.ErrEmailSenderUnavailable):
SendErrorResponse(w, "We couldn't send the verification email. Try again later.", http.StatusServiceUnavailable)
case err.Error() == "email already verified":
SendErrorResponse(w, "This email address is already verified", http.StatusConflict)
case err.Error() == "verification email sent recently, please wait before requesting another":
SendErrorResponse(w, "Please wait 5 minutes before requesting another verification email", http.StatusTooManyRequests)
default:
SendErrorResponse(w, "Unable to resend verification email", http.StatusInternalServerError)
}
return
}
SendSuccessResponse(w, "Verification email sent successfully", map[string]any{
"message": "Check your inbox for the verification link",
})
}
// @Summary Get current user profile
// @Description Retrieve the authenticated user's profile information
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {object} AuthResponse "User profile retrieved successfully"
// @Failure 401 {object} AuthResponse "Authentication required"
// @Failure 404 {object} AuthResponse "User not found"
// @Router /auth/me [get]
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
user, err := h.userRepo.GetByID(userID)
if err != nil {
SendErrorResponse(w, "User not found", http.StatusNotFound)
return
}
userDTO := dto.ToUserDTO(user)
SendSuccessResponse(w, "User profile fetched", userDTO)
}
// @Summary Request a password reset
// @Description Send a password reset email using a username or email
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ForgotPasswordRequest true "Username or email"
// @Success 200 {object} AuthResponse "Password reset email sent if account exists"
// @Failure 400 {object} AuthResponse "Invalid request data"
// @Router /auth/forgot-password [post]
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
var req struct {
UsernameOrEmail string `json:"username_or_email"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return
}
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
}
SendSuccessResponse(w, "If an account with that username or email exists, we've sent a password reset link.", nil)
}
// @Summary Reset password
// @Description Reset a user's password using a reset token
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ResetPasswordRequest true "Password reset data"
// @Success 200 {object} AuthResponse "Password reset successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token, or validation failed"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/reset-password [post]
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
token := strings.TrimSpace(req.Token)
newPassword := strings.TrimSpace(req.NewPassword)
if token == "" {
SendErrorResponse(w, "Reset token is required", http.StatusBadRequest)
return
}
if newPassword == "" {
SendErrorResponse(w, "New password is required", http.StatusBadRequest)
return
}
if len(newPassword) < 8 {
SendErrorResponse(w, "Password must be at least 8 characters long", http.StatusBadRequest)
return
}
if err := h.authService.ResetPassword(token, newPassword); err != nil {
switch {
case strings.Contains(err.Error(), "expired"):
SendErrorResponse(w, "The reset link has expired. Please request a new one.", http.StatusBadRequest)
case strings.Contains(err.Error(), "invalid"):
SendErrorResponse(w, "The reset link is invalid. Please request a new one.", http.StatusBadRequest)
default:
SendErrorResponse(w, "Unable to reset password. Please try again later.", http.StatusInternalServerError)
}
return
}
SendSuccessResponse(w, "Password reset successfully. You can now sign in with your new password.", nil)
}
// @Summary Update email address
// @Description Update the authenticated user's email address
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateEmailRequest true "New email address"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
// @Failure 409 {object} AuthResponse
// @Failure 503 {object} AuthResponse
// @Failure 500 {object} AuthResponse
// @Router /auth/email [put]
func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
var req struct {
Email string `json:"email"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
email := strings.TrimSpace(req.Email)
if err := validation.ValidateEmail(email); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
user, err := h.authService.UpdateEmail(userID, email)
if err != nil {
switch {
case errors.Is(err, services.ErrEmailTaken):
SendErrorResponse(w, "That email is already in use. Choose another one.", http.StatusConflict)
case errors.Is(err, services.ErrEmailSenderUnavailable):
SendErrorResponse(w, "We couldn't send the confirmation email. Try again later.", http.StatusServiceUnavailable)
case errors.Is(err, services.ErrInvalidEmail):
SendErrorResponse(w, "Invalid email address", http.StatusBadRequest)
default:
SendErrorResponse(w, "We couldn't update your email right now.", http.StatusInternalServerError)
}
return
}
userDTO := dto.ToUserDTO(user)
SendSuccessResponse(w, "Email updated. Check your inbox to confirm the new address.", map[string]any{
"user": userDTO,
})
}
// @Summary Update username
// @Description Update the authenticated user's username
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdateUsernameRequest true "New username"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
// @Failure 409 {object} AuthResponse
// @Failure 500 {object} AuthResponse
// @Router /auth/username [put]
func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
var req struct {
Username string `json:"username"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
username := strings.TrimSpace(req.Username)
if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
user, err := h.authService.UpdateUsername(userID, username)
if err != nil {
switch {
case errors.Is(err, services.ErrUsernameTaken):
SendErrorResponse(w, "That username is already taken. Try another one.", http.StatusConflict)
default:
SendErrorResponse(w, "We couldn't update your username right now.", http.StatusInternalServerError)
}
return
}
userDTO := dto.ToUserDTO(user)
SendSuccessResponse(w, "Username updated successfully.", map[string]any{
"user": userDTO,
})
}
// @Summary Update password
// @Description Update the authenticated user's password
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body UpdatePasswordRequest true "Password update data"
// @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse
// @Failure 500 {object} AuthResponse
// @Router /auth/password [put]
func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
var req struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
currentPassword := strings.TrimSpace(req.CurrentPassword)
newPassword := strings.TrimSpace(req.NewPassword)
if currentPassword == "" {
SendErrorResponse(w, "Current password is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
user, err := h.authService.UpdatePassword(userID, currentPassword, newPassword)
if err != nil {
if strings.Contains(err.Error(), "current password is incorrect") {
SendErrorResponse(w, "Current password is incorrect", http.StatusBadRequest)
} else {
SendErrorResponse(w, "We couldn't update your password right now.", http.StatusInternalServerError)
}
return
}
userDTO := dto.ToUserDTO(user)
SendSuccessResponse(w, "Password updated successfully.", map[string]any{
"user": userDTO,
})
}
// @Summary Request account deletion
// @Description Initiate the deletion process for the authenticated user's account
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {object} AuthResponse "Deletion email sent"
// @Failure 401 {object} AuthResponse "Authentication required"
// @Failure 503 {object} AuthResponse "Email delivery unavailable"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/account [delete]
func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
err := h.authService.RequestAccountDeletion(userID)
if err != nil {
if errors.Is(err, services.ErrEmailSenderUnavailable) {
SendErrorResponse(w, "Account deletion isn't available right now because email delivery is disabled.", http.StatusServiceUnavailable)
} else {
SendErrorResponse(w, "We couldn't start the deletion process right now.", http.StatusInternalServerError)
}
return
}
SendSuccessResponse(w, "Check your inbox for a confirmation link to finish deleting your account.", nil)
}
// @Summary Confirm account deletion
// @Description Confirm account deletion using the provided token
// @Tags auth
// @Accept json
// @Produce json
// @Param request body ConfirmAccountDeletionRequest true "Account deletion data"
// @Success 200 {object} AuthResponse "Account deleted successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token"
// @Failure 503 {object} AuthResponse "Email delivery unavailable"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/account/confirm [post]
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
token := strings.TrimSpace(req.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
}
if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil {
switch {
case errors.Is(err, services.ErrInvalidDeletionToken):
SendErrorResponse(w, "This deletion link is invalid or has expired.", http.StatusBadRequest)
case errors.Is(err, services.ErrEmailSenderUnavailable):
SendErrorResponse(w, "Account deletion isn't available right now because email delivery is disabled.", http.StatusServiceUnavailable)
case errors.Is(err, services.ErrDeletionEmailFailed):
SendSuccessResponse(w, "Your account has been deleted, but we couldn't send the confirmation email.", map[string]any{
"posts_deleted": req.DeletePosts,
})
default:
SendErrorResponse(w, "We couldn't confirm the deletion right now.", http.StatusInternalServerError)
}
return
}
SendSuccessResponse(w, "Your account has been deleted.", map[string]any{
"posts_deleted": req.DeletePosts,
})
}
// @Summary Logout user
// @Description Logout the authenticated user and invalidate their session
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {object} AuthResponse "Logged out successfully"
// @Failure 401 {object} AuthResponse "Authentication required"
// @Router /auth/logout [post]
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
SendSuccessResponse(w, "Logged out successfully", nil)
}
// @Summary Refresh access token
// @Description Use a refresh token to get a new access token. This endpoint allows clients to obtain a new access token using a valid refresh token without requiring user credentials.
// @Tags auth
// @Accept json
// @Produce json
// @Param request body RefreshTokenRequest true "Refresh token data"
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
// @Failure 403 {object} AuthResponse "Account is locked"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/refresh [post]
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
var req RefreshTokenRequest
if !DecodeJSONRequest(w, r, &req) {
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
result, err := h.authService.RefreshAccessToken(req.RefreshToken)
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
return
}
SendSuccessResponse(w, "Token refreshed successfully", result)
}
// @Summary Revoke refresh token
// @Description Revoke a specific refresh token. This endpoint allows authenticated users to invalidate a specific refresh token, preventing its future use.
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RevokeTokenRequest true "Token revocation data"
// @Success 200 {object} AuthResponse "Token revoked successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired access token"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/revoke [post]
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
var req RevokeTokenRequest
if !DecodeJSONRequest(w, r, &req) {
return
}
if strings.TrimSpace(req.RefreshToken) == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
err := h.authService.RevokeRefreshToken(req.RefreshToken)
if err != nil {
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)
return
}
SendSuccessResponse(w, "Token revoked successfully", nil)
}
// @Summary Revoke all user tokens
// @Description Revoke all refresh tokens for the authenticated user. This endpoint allows users to invalidate all their refresh tokens at once, effectively logging them out from all devices.
// @Tags auth
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {object} AuthResponse "All tokens revoked successfully"
// @Failure 401 {object} AuthResponse "Invalid or expired access token"
// @Failure 500 {object} AuthResponse "Internal server error"
// @Router /auth/revoke-all [post]
func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
err := h.authService.RevokeAllUserTokens(userID)
if err != nil {
SendErrorResponse(w, "Failed to revoke tokens", http.StatusInternalServerError)
return
}
SendSuccessResponse(w, "All tokens revoked successfully", nil)
}
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil {
rateLimited := config.GeneralRateLimit(r)
rateLimited.Post("/auth/refresh", h.RefreshToken)
rateLimited.Get("/auth/confirm", h.ConfirmEmail)
rateLimited.Post("/auth/resend-verification", h.ResendVerificationEmail)
} else {
r.Post("/auth/refresh", h.RefreshToken)
r.Get("/auth/confirm", h.ConfirmEmail)
r.Post("/auth/resend-verification", h.ResendVerificationEmail)
}
if config.AuthRateLimit != nil {
rateLimited := config.AuthRateLimit(r)
rateLimited.Post("/auth/register", h.Register)
rateLimited.Post("/auth/login", h.Login)
rateLimited.Post("/auth/forgot-password", h.RequestPasswordReset)
rateLimited.Post("/auth/reset-password", h.ResetPassword)
rateLimited.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
} else {
r.Post("/auth/register", h.Register)
r.Post("/auth/login", h.Login)
r.Post("/auth/forgot-password", h.RequestPasswordReset)
r.Post("/auth/reset-password", h.ResetPassword)
r.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
}
protected := r
if config.AuthMiddleware != nil {
protected = r.With(config.AuthMiddleware)
}
if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected)
}
protected.Get("/auth/me", h.Me)
protected.Post("/auth/logout", h.Logout)
protected.Post("/auth/revoke", h.RevokeToken)
protected.Post("/auth/revoke-all", h.RevokeAllTokens)
protected.Put("/auth/email", h.UpdateEmail)
protected.Put("/auth/username", h.UpdateUsername)
protected.Put("/auth/password", h.UpdatePassword)
protected.Delete("/auth/account", h.DeleteAccount)
}

File diff suppressed because it is too large Load Diff

292
internal/handlers/common.go Normal file
View File

@@ -0,0 +1,292 @@
package handlers
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware"
"goyco/internal/services"
"github.com/go-chi/chi/v5"
"gorm.io/gorm"
)
type CommonResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
type PaginationData struct {
Count int `json:"count"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type VoteCookieData struct {
Type database.VoteType `json:"type"`
Timestamp int64 `json:"timestamp"`
}
func sendResponse(w http.ResponseWriter, statusCode int, success bool, message string, data any, errMsg string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
response := CommonResponse{
Success: success,
Message: message,
Data: data,
Error: errMsg,
}
json.NewEncoder(w).Encode(response)
}
func SendSuccessResponse(w http.ResponseWriter, message string, data any) {
sendResponse(w, http.StatusOK, true, message, data, "")
}
func SendCreatedResponse(w http.ResponseWriter, message string, data any) {
sendResponse(w, http.StatusCreated, true, message, data, "")
}
func SendErrorResponse(w http.ResponseWriter, message string, statusCode int) {
sendResponse(w, statusCode, false, "", nil, message)
}
func DecodeJSONRequest(w http.ResponseWriter, r *http.Request, req any) bool {
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
return false
}
return true
}
func GetClientIP(r *http.Request) string {
return middleware.GetSecureClientIP(r)
}
const (
CookieMaxAgeDays = 30
SecondsPerDay = 86400
DefaultPaginationLimit = 20
DefaultPaginationOffset = 0
)
func SetVoteCookie(w http.ResponseWriter, r *http.Request, postID uint, voteType database.VoteType) {
cookieName := fmt.Sprintf("vote_%d", postID)
cookieValue := fmt.Sprintf("%s:%d", voteType, time.Now().Unix())
cookie := &http.Cookie{
Name: cookieName,
Value: cookieValue,
Path: "/",
MaxAge: SecondsPerDay * CookieMaxAgeDays,
HttpOnly: true,
Secure: IsHTTPS(r),
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, cookie)
}
func GetVoteCookie(r *http.Request, postID uint) string {
cookieName := fmt.Sprintf("vote_%d", postID)
cookie, err := r.Cookie(cookieName)
if err != nil {
return ""
}
return cookie.Value
}
func ClearVoteCookie(w http.ResponseWriter, postID uint) {
cookieName := fmt.Sprintf("vote_%d", postID)
cookie := &http.Cookie{
Name: cookieName,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
}
http.SetCookie(w, cookie)
}
func IsHTTPS(r *http.Request) bool {
if r.TLS != nil {
return true
}
if proto := r.Header.Get("X-Forwarded-Proto"); proto == "https" {
return true
}
if proto := r.Header.Get("X-Forwarded-Ssl"); proto == "on" {
return true
}
if proto := r.Header.Get("X-Forwarded-Scheme"); proto == "https" {
return true
}
return false
}
func SanitizeUser(user *database.User) dto.SanitizedUserDTO {
if user == nil {
return dto.SanitizedUserDTO{}
}
return dto.ToSanitizedUserDTO(user)
}
func SanitizeUsers(users []database.User) []dto.SanitizedUserDTO {
return dto.ToSanitizedUserDTOs(users)
}
func parsePagination(r *http.Request) (limit, offset int) {
limit = DefaultPaginationLimit
offset = DefaultPaginationOffset
limitStr := r.URL.Query().Get("limit")
if limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
offsetStr := r.URL.Query().Get("offset")
if offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
offset = o
}
}
return limit, offset
}
func ValidateRedirectURL(redirectURL string) string {
redirectURL = strings.TrimSpace(redirectURL)
if redirectURL == "" || len(redirectURL) > 512 {
return ""
}
if !strings.HasPrefix(redirectURL, "/") || strings.HasPrefix(redirectURL, "//") {
return ""
}
parsed, err := url.Parse(redirectURL)
if err != nil || parsed.Scheme != "" || parsed.Host != "" || parsed.User != nil || parsed.Path == "" {
return ""
}
path := parsed.EscapedPath()
if path == "" {
path = parsed.Path
}
validated := path
if parsed.RawQuery != "" {
validated += "?" + parsed.RawQuery
}
if parsed.Fragment != "" {
validated += "#" + parsed.Fragment
}
return validated
}
func ParseUintParam(w http.ResponseWriter, r *http.Request, paramName, entityName string) (uint, bool) {
str := chi.URLParam(r, paramName)
if str == "" {
SendErrorResponse(w, entityName+" ID is required", http.StatusBadRequest)
return 0, false
}
id, err := strconv.ParseUint(str, 10, 32)
if err != nil {
SendErrorResponse(w, "Invalid "+entityName+" ID", http.StatusBadRequest)
return 0, false
}
return uint(id), true
}
func RequireAuth(w http.ResponseWriter, r *http.Request) (uint, bool) {
userID := middleware.GetUserIDFromContext(r.Context())
if userID == 0 {
SendErrorResponse(w, "Authentication required", http.StatusUnauthorized)
return 0, false
}
return userID, true
}
func NewVoteContext(r *http.Request) services.VoteContext {
return services.VoteContext{
UserID: middleware.GetUserIDFromContext(r.Context()),
IPAddress: GetClientIP(r),
UserAgent: r.UserAgent(),
}
}
func HandleRepoError(w http.ResponseWriter, err error, entityName string) bool {
if err == nil {
return true
}
if errors.Is(err, gorm.ErrRecordNotFound) {
SendErrorResponse(w, entityName+" not found", http.StatusNotFound)
} else {
SendErrorResponse(w, "Failed to retrieve "+entityName, http.StatusInternalServerError)
}
return false
}
var AuthErrorMapping = []struct {
err error
msg string
code int
}{
{services.ErrInvalidCredentials, "Invalid username or password", http.StatusUnauthorized},
{services.ErrEmailNotVerified, "Please confirm your email before logging in", http.StatusForbidden},
{services.ErrAccountLocked, "Your account has been locked. Please contact us for assistance.", http.StatusForbidden},
{services.ErrUsernameTaken, "Username is already taken", http.StatusConflict},
{services.ErrEmailTaken, "Email is already registered", http.StatusConflict},
{services.ErrInvalidEmail, "Invalid email address", http.StatusBadRequest},
{services.ErrPasswordTooShort, "Password must be at least 8 characters", http.StatusBadRequest},
{services.ErrInvalidVerificationToken, "Invalid or expired verification token", http.StatusBadRequest},
{services.ErrRefreshTokenExpired, "Refresh token has expired", http.StatusUnauthorized},
{services.ErrRefreshTokenInvalid, "Invalid refresh token", http.StatusUnauthorized},
{services.ErrInvalidDeletionToken, "This deletion link is invalid or has expired.", http.StatusBadRequest},
{services.ErrDeletionRequestNotFound, "Deletion request not found", http.StatusBadRequest},
{services.ErrUserNotFound, "User not found", http.StatusNotFound},
{services.ErrEmailSenderUnavailable, "Email service is unavailable. Please try again later.", http.StatusServiceUnavailable},
}
func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, defaultCode int) bool {
if err == nil {
return true
}
for _, mapping := range AuthErrorMapping {
if err == mapping.err || errors.Is(err, mapping.err) {
SendErrorResponse(w, mapping.msg, mapping.code)
return false
}
}
errMsg := err.Error()
for _, mapping := range AuthErrorMapping {
if mapping.err.Error() == errMsg {
SendErrorResponse(w, mapping.msg, mapping.code)
return false
}
}
SendErrorResponse(w, defaultMsg, defaultCode)
return false
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,146 @@
package handlers
import (
"net/http/httptest"
"strings"
"testing"
"unicode/utf8"
"goyco/internal/fuzz"
)
func FuzzJSONParsing(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
testCases := []map[string]any{
{
"name": "auth_login",
"body": `{"username":"FUZZED_INPUT","password":"test"}`,
},
{
"name": "auth_register",
"body": `{"username":"FUZZED_INPUT","email":"test@example.com","password":"test123"}`,
},
{
"name": "post_create",
"body": `{"title":"FUZZED_INPUT","url":"https://example.com","content":"test"}`,
},
{
"name": "vote_cast",
"body": `{"type":"FUZZED_INPUT"}`,
},
}
helper.RunJSONFuzzTest(f, testCases)
}
func FuzzURLParsing(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
sanitized := ""
for _, char := range input {
if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
(char >= '0' && char <= '9') || char == '-' || char == '_' {
sanitized += string(char)
}
}
if len(sanitized) > 20 {
sanitized = sanitized[:20]
}
if len(sanitized) == 0 {
return
}
url := "/api/posts/" + sanitized
req := httptest.NewRequest("GET", url, nil)
pathParts := strings.Split(req.URL.Path, "/")
if len(pathParts) >= 4 {
idStr := pathParts[3]
_ = idStr
}
})
}
func FuzzQueryParameters(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
if !utf8.ValidString(input) {
return
}
sanitized := ""
for _, char := range input {
if char >= 32 && char <= 126 {
switch char {
case ' ', '\n', '\r', '\t':
continue
case '&':
sanitized += "%26"
case '=':
sanitized += "%3D"
case '?':
sanitized += "%3F"
case '#':
sanitized += "%23"
case '/':
sanitized += "%2F"
case '\\':
sanitized += "%5C"
default:
sanitized += string(char)
}
}
}
if len(sanitized) > 100 {
sanitized = sanitized[:100]
}
if len(sanitized) == 0 {
return
}
query := "?q=" + sanitized + "&limit=10&offset=0"
req := httptest.NewRequest("GET", "/api/posts/search"+query, nil)
q := req.URL.Query().Get("q")
limit := req.URL.Query().Get("limit")
offset := req.URL.Query().Get("offset")
if !utf8.ValidString(q) {
return
}
_ = limit
_ = offset
})
}
func FuzzHTTPHeaders(f *testing.F) {
helper := fuzz.NewFuzzTestHelper()
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", "Bearer "+input)
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("User-Agent", input)
req.Header.Set("X-Forwarded-For", input)
for name, values := range req.Header {
if !utf8.ValidString(name) {
t.Fatal("Header name contains invalid UTF-8")
}
for _, value := range values {
if !utf8.ValidString(value) {
t.Fatal("Header value contains invalid UTF-8")
}
}
}
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,464 @@
package handlers
import (
"context"
"errors"
"net/http"
"strings"
"time"
"goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/repositories"
"goyco/internal/security"
"goyco/internal/services"
"goyco/internal/validation"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgconn"
)
type PostHandler struct {
postRepo repositories.PostRepository
titleFetcher services.TitleFetcher
voteService *services.VoteService
postQueries *services.PostQueries
}
func NewPostHandler(postRepo repositories.PostRepository, titleFetcher services.TitleFetcher, voteService *services.VoteService) *PostHandler {
return &PostHandler{
postRepo: postRepo,
titleFetcher: titleFetcher,
voteService: voteService,
postQueries: services.NewPostQueries(postRepo, voteService),
}
}
type PostResponse = CommonResponse
type UpdatePostRequest struct {
Title string `json:"title"`
Content string `json:"content"`
}
// @Summary Get posts
// @Description Get a list of posts with pagination. Posts include vote statistics (up_votes, down_votes, score) and current user's vote status.
// @Tags posts
// @Accept json
// @Produce json
// @Param limit query int false "Number of posts to return" default(20)
// @Param offset query int false "Number of posts to skip" default(0)
// @Success 200 {object} PostResponse "Posts retrieved successfully with vote statistics"
// @Failure 400 {object} PostResponse "Invalid pagination parameters"
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /posts [get]
func (h *PostHandler) GetPosts(w http.ResponseWriter, r *http.Request) {
limit, offset := parsePagination(r)
opts := services.QueryOptions{
Limit: limit,
Offset: offset,
}
ctx := NewVoteContext(r)
posts, err := h.postQueries.GetAll(opts, ctx)
if err != nil {
SendErrorResponse(w, "Failed to fetch posts", http.StatusInternalServerError)
return
}
postDTOs := dto.ToPostDTOs(posts)
SendSuccessResponse(w, "Posts retrieved successfully", map[string]any{
"posts": postDTOs,
"count": len(postDTOs),
"limit": limit,
"offset": offset,
})
}
// @Summary Get a single post
// @Description Get a post by ID with vote statistics and current user's vote status
// @Tags posts
// @Accept json
// @Produce json
// @Param id path int true "Post ID"
// @Success 200 {object} PostResponse "Post retrieved successfully with vote statistics"
// @Failure 400 {object} PostResponse "Invalid post ID"
// @Failure 404 {object} PostResponse "Post not found"
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /posts/{id} [get]
func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
postID, ok := ParseUintParam(w, r, "id", "Post")
if !ok {
return
}
ctx := NewVoteContext(r)
post, err := h.postQueries.GetByID(postID, ctx)
if !HandleRepoError(w, err, "Post") {
return
}
postDTO := dto.ToPostDTO(post)
SendSuccessResponse(w, "Post retrieved successfully", postDTO)
}
// @Summary Create a new post
// @Description Create a new post with URL and optional title
// @Tags posts
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body CreatePostRequest true "Post data"
// @Success 201 {object} PostResponse
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
// @Failure 409 {object} PostResponse "URL already submitted"
// @Failure 502 {object} PostResponse "Failed to fetch title from URL"
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /posts [post]
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
var req struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
req.Title = security.SanitizeInput(req.Title)
req.URL = security.SanitizeURL(req.URL)
req.Content = security.SanitizePostContent(req.Content)
if req.URL == "" {
SendErrorResponse(w, "URL is required", http.StatusBadRequest)
return
}
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
return
}
userID, ok := RequireAuth(w, r)
if !ok {
return
}
title := req.Title
if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel()
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, req.URL)
if err != nil {
switch {
case errors.Is(err, services.ErrUnsupportedScheme):
SendErrorResponse(w, "Only HTTP and HTTPS URLs are supported", http.StatusBadRequest)
case errors.Is(err, services.ErrTitleNotFound):
SendErrorResponse(w, "Title could not be extracted from the provided URL", http.StatusBadRequest)
default:
SendErrorResponse(w, "Failed to fetch title from URL", http.StatusBadGateway)
}
return
}
title = fetchedTitle
}
if title == "" {
SendErrorResponse(w, "Title is required", http.StatusBadRequest)
return
}
if len(title) < 3 {
SendErrorResponse(w, "Title must be at least 3 characters", http.StatusBadRequest)
return
}
post := &database.Post{
Title: title,
URL: req.URL,
Content: req.Content,
AuthorID: &userID,
}
if err := h.postRepo.Create(post); err != nil {
if errMsg, status := translatePostCreateError(err); status != 0 {
SendErrorResponse(w, errMsg, status)
return
}
SendErrorResponse(w, "Failed to create post", http.StatusInternalServerError)
return
}
postDTO := dto.ToPostDTO(post)
SendCreatedResponse(w, "Post created successfully", postDTO)
}
// @Summary Search posts
// @Description Search posts by title or content keywords. Results include vote statistics and current user's vote status.
// @Tags posts
// @Accept json
// @Produce json
// @Param q query string false "Search term"
// @Param limit query int false "Number of posts to return" default(20)
// @Param offset query int false "Number of posts to skip" default(0)
// @Success 200 {object} PostResponse "Search results with vote statistics"
// @Failure 400 {object} PostResponse "Invalid search parameters"
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /posts/search [get]
func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
query := strings.TrimSpace(r.URL.Query().Get("q"))
limit, offset := parsePagination(r)
opts := services.QueryOptions{
Limit: limit,
Offset: offset,
}
ctx := NewVoteContext(r)
posts, err := h.postQueries.GetSearch(query, opts, ctx)
if err != nil {
if searchErr, ok := err.(*repositories.SearchError); ok {
SendErrorResponse(w, "Invalid search query: "+searchErr.Message, http.StatusBadRequest)
return
}
SendErrorResponse(w, "Failed to search posts", http.StatusInternalServerError)
return
}
postDTOs := dto.ToPostDTOs(posts)
SendSuccessResponse(w, "Search results retrieved successfully", map[string]any{
"posts": postDTOs,
"count": len(postDTOs),
"query": query,
"limit": limit,
"offset": offset,
})
}
// @Summary Update a post
// @Description Update the title and content of a post owned by the authenticated user
// @Tags posts
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body UpdatePostRequest true "Post update data"
// @Success 200 {object} PostResponse "Post updated successfully"
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required"
// @Failure 403 {object} PostResponse "Not authorized to update this post"
// @Failure 404 {object} PostResponse "Post not found"
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /posts/{id} [put]
func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
postID, ok := ParseUintParam(w, r, "id", "Post")
if !ok {
return
}
post, err := h.postRepo.GetByID(postID)
if !HandleRepoError(w, err, "Post") {
return
}
if post.AuthorID == nil || *post.AuthorID != userID {
SendErrorResponse(w, "You can only edit your own posts", http.StatusForbidden)
return
}
var req struct {
Title string `json:"title"`
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
req.Title = security.SanitizeInput(req.Title)
req.Content = security.SanitizePostContent(req.Content)
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
return
}
if err := validation.ValidateTitle(req.Title); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateContent(req.Content); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
post.Title = req.Title
post.Content = req.Content
if err := h.postRepo.Update(post); err != nil {
SendErrorResponse(w, "Failed to update post", http.StatusInternalServerError)
return
}
postDTO := dto.ToPostDTO(post)
SendSuccessResponse(w, "Post updated successfully", postDTO)
}
// @Summary Delete a post
// @Description Delete a post owned by the authenticated user
// @Tags posts
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Success 200 {object} PostResponse "Post deleted successfully"
// @Failure 400 {object} PostResponse "Invalid post ID"
// @Failure 401 {object} PostResponse "Authentication required"
// @Failure 403 {object} PostResponse "Not authorized to delete this post"
// @Failure 404 {object} PostResponse "Post not found"
// @Failure 500 {object} PostResponse "Internal server error"
// @Router /posts/{id} [delete]
func (h *PostHandler) DeletePost(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
postID, ok := ParseUintParam(w, r, "id", "Post")
if !ok {
return
}
post, err := h.postRepo.GetByID(postID)
if !HandleRepoError(w, err, "Post") {
return
}
if post.AuthorID == nil || *post.AuthorID != userID {
SendErrorResponse(w, "You can only delete your own posts", http.StatusForbidden)
return
}
if err := h.voteService.DeleteVotesByPostID(postID); err != nil {
SendErrorResponse(w, "Failed to delete post votes", http.StatusInternalServerError)
return
}
if err := h.postRepo.Delete(postID); err != nil {
SendErrorResponse(w, "Failed to delete post", http.StatusInternalServerError)
return
}
SendSuccessResponse(w, "Post deleted successfully", nil)
}
// @Summary Fetch title from URL
// @Description Fetch the HTML title for the provided URL
// @Tags posts
// @Accept json
// @Produce json
// @Param url query string true "URL to inspect"
// @Success 200 {object} PostResponse "Title fetched successfully"
// @Failure 400 {object} PostResponse "Invalid URL or URL parameter missing"
// @Failure 501 {object} PostResponse "Title fetching is not available"
// @Failure 502 {object} PostResponse "Failed to fetch title from URL"
// @Router /posts/title [get]
func (h *PostHandler) FetchTitleFromURL(w http.ResponseWriter, r *http.Request) {
if h.titleFetcher == nil {
SendErrorResponse(w, "Title fetching is not available", http.StatusNotImplemented)
return
}
requestedURL := strings.TrimSpace(r.URL.Query().Get("url"))
if requestedURL == "" {
SendErrorResponse(w, "URL query parameter is required", http.StatusBadRequest)
return
}
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel()
title, err := h.titleFetcher.FetchTitle(titleCtx, requestedURL)
if err != nil {
switch {
case errors.Is(err, services.ErrUnsupportedScheme):
SendErrorResponse(w, "Only HTTP and HTTPS URLs are supported", http.StatusBadRequest)
case errors.Is(err, services.ErrTitleNotFound):
SendErrorResponse(w, "Title could not be extracted from the provided URL", http.StatusBadRequest)
default:
SendErrorResponse(w, "Failed to fetch title from URL", http.StatusBadGateway)
}
return
}
SendSuccessResponse(w, "Title fetched successfully", map[string]string{
"title": title,
})
}
func translatePostCreateError(err error) (string, int) {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
switch pgErr.Code {
case "23505":
return "This URL has already been submitted.", http.StatusConflict
case "23503":
return "Author account not found. Please sign in again.", http.StatusUnauthorized
}
}
errStr := err.Error()
if strings.Contains(errStr, "UNIQUE constraint") || strings.Contains(errStr, "duplicate") {
return "This URL has already been submitted.", http.StatusConflict
}
return "", 0
}
func (h *PostHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
public := r
if config.GeneralRateLimit != nil {
public = config.GeneralRateLimit(r)
}
public.Get("/posts", h.GetPosts)
public.Get("/posts/search", h.SearchPosts)
public.Get("/posts/title", h.FetchTitleFromURL)
public.Get("/posts/{id}", h.GetPost)
protected := r
if config.AuthMiddleware != nil {
protected = r.With(config.AuthMiddleware)
}
if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts", h.CreatePost)
protected.Put("/posts/{id}", h.UpdatePost)
protected.Delete("/posts/{id}", h.DeletePost)
}

View File

@@ -0,0 +1,711 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgconn"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
)
func decodeHandlerResponse(t *testing.T, rr *httptest.ResponseRecorder) map[string]any {
t.Helper()
var payload map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
return payload
}
func TestPostHandlerGetPostsWithVoteService(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
repo.GetAllFn = func(limit, offset int) ([]database.Post, error) {
return []database.Post{
{ID: 1, Title: "Test Post 1"},
{ID: 2, Title: "Test Post 2"},
}, nil
}
voteRepo := testutils.NewMockVoteRepository()
voteService := services.NewVoteService(voteRepo, repo, nil)
handler := NewPostHandler(repo, nil, voteService)
request := httptest.NewRequest(http.MethodGet, "/api/posts", nil)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
recorder := httptest.NewRecorder()
handler.GetPosts(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
payload := decodeHandlerResponse(t, recorder)
if !payload["success"].(bool) {
t.Fatalf("expected success response, got %v", payload)
}
}
func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
var storedPost *database.Post
repo.CreateFn = func(post *database.Post) error {
storedPost = post
return nil
}
titleFetcher := &testutils.MockTitleFetcher{}
titleFetcher.SetTitle("Fetched Title")
handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`))
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
if storedPost == nil {
t.Fatal("expected post to be created")
}
if storedPost.Title != "Fetched Title" {
t.Errorf("expected title 'Fetched Title', got %s", storedPost.Title)
}
}
func TestPostHandlerCreatePostTitleFetcherError(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
titleFetcher := &testutils.MockTitleFetcher{}
titleFetcher.SetError(services.ErrUnsupportedScheme)
handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"ftp://example.com"}`))
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
payload := decodeHandlerResponse(t, recorder)
if payload["success"].(bool) {
t.Fatalf("expected error response, got %v", payload)
}
}
func TestPostHandlerSearchPosts(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
repo.SearchFn = func(query string, limit, offset int) ([]database.Post, error) {
return []database.Post{
{ID: 1, Title: "Search Result 1"},
{ID: 2, Title: "Search Result 2"},
}, nil
}
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodGet, "/api/posts/search?q=test", nil)
recorder := httptest.NewRecorder()
handler.SearchPosts(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
payload := decodeHandlerResponse(t, recorder)
if !payload["success"].(bool) {
t.Fatalf("expected success response, got %v", payload)
}
}
func TestPostHandlerFetchTitleFromURL(t *testing.T) {
titleFetcher := &testutils.MockTitleFetcher{}
titleFetcher.SetTitle("Test Title")
handler := NewPostHandler(nil, titleFetcher, nil)
request := httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
recorder := httptest.NewRecorder()
handler.FetchTitleFromURL(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
payload := decodeHandlerResponse(t, recorder)
if !payload["success"].(bool) {
t.Fatalf("expected success response, got %v", payload)
}
}
func TestPostHandlerFetchTitleFromURLNoFetcher(t *testing.T) {
handler := NewPostHandler(nil, nil, nil)
request := httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
recorder := httptest.NewRecorder()
handler.FetchTitleFromURL(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusNotImplemented)
}
func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
repo.GetByIDFn = func(id uint) (*database.Post, error) {
return &database.Post{ID: id, AuthorID: func() *uint { u := uint(2); return &u }()}, nil
}
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`))
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.UpdatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
}
func TestPostHandlerDeletePostUnauthorized(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
repo.GetByIDFn = func(id uint) (*database.Post, error) {
return &database.Post{ID: id, AuthorID: func() *uint { u := uint(2); return &u }()}, nil
}
voteRepo := testutils.NewMockVoteRepository()
voteService := services.NewVoteService(voteRepo, repo, nil)
handler := NewPostHandler(repo, nil, voteService)
request := httptest.NewRequest(http.MethodDelete, "/api/posts/1", nil)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
recorder := httptest.NewRecorder()
handler.DeletePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
}
func TestPostHandlerGetPosts(t *testing.T) {
var receivedLimit, receivedOffset int
repo := testutils.NewPostRepositoryStub()
repo.GetAllFn = func(limit, offset int) ([]database.Post, error) {
receivedLimit = limit
receivedOffset = offset
return []database.Post{{ID: 1}}, nil
}
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodGet, "/api/posts?limit=5&offset=2", nil)
recorder := httptest.NewRecorder()
handler.GetPosts(recorder, request)
if receivedLimit != 5 || receivedOffset != 2 {
t.Fatalf("expected limit=5 offset=2, got %d %d", receivedLimit, receivedOffset)
}
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
payload := decodeHandlerResponse(t, recorder)
if !payload["success"].(bool) {
t.Fatalf("expected success response, got %v", payload)
}
}
func TestPostHandlerGetPostErrors(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodGet, "/api/posts", nil)
recorder := httptest.NewRecorder()
handler.GetPost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing id, got %d", recorder.Result().StatusCode)
}
request = httptest.NewRequest(http.MethodGet, "/api/posts/abc", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
recorder = httptest.NewRecorder()
handler.GetPost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid id, got %d", recorder.Result().StatusCode)
}
repo.GetByIDFn = func(uint) (*database.Post, error) {
return nil, gorm.ErrRecordNotFound
}
request = httptest.NewRequest(http.MethodGet, "/api/posts/1", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
recorder = httptest.NewRecorder()
handler.GetPost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
}
func TestPostHandlerCreatePostSuccess(t *testing.T) {
var storedPost *database.Post
repo := testutils.NewPostRepositoryStub()
repo.CreateFn = func(post *database.Post) error {
storedPost = &database.Post{
Title: post.Title,
URL: post.URL,
Content: post.Content,
AuthorID: post.AuthorID,
}
storedPost.ID = 1
return nil
}
fetcher := &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) {
return "Fetched Title", nil
}}
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx)
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
if storedPost == nil || storedPost.Title != "Fetched Title" || storedPost.AuthorID == nil || *storedPost.AuthorID != 42 {
t.Fatalf("unexpected stored post: %#v", storedPost)
}
}
func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`))
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing url, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`))
handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`))
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
}
func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
tests := []struct {
name string
err error
wantStatus int
wantMsg string
}{
{name: "Unsupported", err: services.ErrUnsupportedScheme, wantStatus: http.StatusBadRequest, wantMsg: "Only HTTP and HTTPS URLs are supported"},
{name: "TitleMissing", err: services.ErrTitleNotFound, wantStatus: http.StatusBadRequest, wantMsg: "Title could not be extracted"},
{name: "Generic", err: errors.New("timeout"), wantStatus: http.StatusBadGateway, wantMsg: "Failed to fetch title"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
fetcher := &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) {
return "", tc.err
}}
handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, tc.wantStatus)
if !strings.Contains(recorder.Body.String(), tc.wantMsg) {
t.Fatalf("expected message to contain %q, got %q", tc.wantMsg, recorder.Body.String())
}
})
}
}
func TestPostHandlerFetchTitleFromURLErrors(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), nil, nil)
request := httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
recorder := httptest.NewRecorder()
handler.FetchTitleFromURL(recorder, request)
if recorder.Result().StatusCode != http.StatusNotImplemented {
t.Fatalf("expected 501 when fetcher unavailable, got %d", recorder.Result().StatusCode)
}
handler = NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
request = httptest.NewRequest(http.MethodGet, "/api/posts/title", nil)
recorder = httptest.NewRecorder()
handler.FetchTitleFromURL(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing url query, got %d", recorder.Result().StatusCode)
}
handler = NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) {
return "", errors.New("failed")
}}, nil)
request = httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
recorder = httptest.NewRecorder()
handler.FetchTitleFromURL(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadGateway)
}
func TestTranslatePostCreateError(t *testing.T) {
conflictErr := &pgconn.PgError{Code: "23505"}
msg, status := translatePostCreateError(conflictErr)
if status != http.StatusConflict || !strings.Contains(msg, "already been submitted") {
t.Fatalf("unexpected conflict translation: status=%d msg=%q", status, msg)
}
fkErr := &pgconn.PgError{Code: "23503"}
msg, status = translatePostCreateError(fkErr)
if status != http.StatusUnauthorized || !strings.Contains(msg, "Author account not found") {
t.Fatalf("unexpected foreign key translation: status=%d msg=%q", status, msg)
}
msg, status = translatePostCreateError(errors.New("other"))
if status != 0 || msg != "" {
t.Fatalf("expected passthrough for unrelated errors, got status=%d msg=%q", status, msg)
}
}
func TestPostHandlerUpdatePost(t *testing.T) {
tests := []struct {
name string
postID string
requestBody string
userID uint
mockSetup func(*testutils.PostRepositoryStub)
expectedStatus int
expectedError string
}{
{
name: "valid post update",
postID: "1",
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
authorID := uint(1)
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
}
repo.UpdateFn = func(post *database.Post) error { return nil }
},
expectedStatus: http.StatusOK,
},
{
name: "missing user context",
postID: "1",
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
userID: 0,
mockSetup: func(repo *testutils.PostRepositoryStub) {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Authentication required",
},
{
name: "post not found",
postID: "999",
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
return nil, gorm.ErrRecordNotFound
}
},
expectedStatus: http.StatusNotFound,
expectedError: "Post not found",
},
{
name: "not author",
postID: "1",
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
userID: 2,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
authorID := uint(1)
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
}
},
expectedStatus: http.StatusForbidden,
expectedError: "You can only edit your own posts",
},
{
name: "empty title",
postID: "1",
requestBody: `{"title": "", "content": "Updated content"}`,
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
authorID := uint(1)
repo.GetByIDFn = func(id uint) (*database.Post, error) {
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
}
},
expectedStatus: http.StatusBadRequest,
expectedError: "Title is required",
},
{
name: "short title",
postID: "1",
requestBody: `{"title": "ab", "content": "Updated content"}`,
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
authorID := uint(1)
repo.GetByIDFn = func(id uint) (*database.Post, error) {
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
}
},
expectedStatus: http.StatusBadRequest,
expectedError: "Title must be at least 3 characters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
if tt.mockSetup != nil {
tt.mockSetup(repo)
}
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody))
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
}
ctx := chi.NewRouteContext()
ctx.URLParams.Add("id", tt.postID)
request = request.WithContext(context.WithValue(request.Context(), chi.RouteCtxKey, ctx))
recorder := httptest.NewRecorder()
handler.UpdatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
if tt.expectedError != "" {
if !strings.Contains(recorder.Body.String(), tt.expectedError) {
t.Fatalf("expected error to contain %q, got %q", tt.expectedError, recorder.Body.String())
}
}
})
}
}
func TestPostHandlerDeletePost(t *testing.T) {
tests := []struct {
name string
postID string
userID uint
mockSetup func(*testutils.PostRepositoryStub)
expectedStatus int
expectedError string
}{
{
name: "valid post deletion",
postID: "1",
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
authorID := uint(1)
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
}
repo.DeleteFn = func(id uint) error { return nil }
},
expectedStatus: http.StatusOK,
},
{
name: "missing user context",
postID: "1",
userID: 0,
mockSetup: func(repo *testutils.PostRepositoryStub) {},
expectedStatus: http.StatusUnauthorized,
expectedError: "Authentication required",
},
{
name: "post not found",
postID: "999",
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
return nil, gorm.ErrRecordNotFound
}
},
expectedStatus: http.StatusNotFound,
expectedError: "Post not found",
},
{
name: "not author",
postID: "1",
userID: 2,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
authorID := uint(1)
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
}
},
expectedStatus: http.StatusForbidden,
expectedError: "You can only delete your own posts",
},
{
name: "delete error",
postID: "1",
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
authorID := uint(1)
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
}
repo.DeleteFn = func(id uint) error { return errors.New("database error") }
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to delete post",
},
{
name: "delete votes error",
postID: "1",
userID: 1,
mockSetup: func(repo *testutils.PostRepositoryStub) {
repo.GetByIDFn = func(id uint) (*database.Post, error) {
authorID := uint(1)
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
}
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to delete post votes",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := testutils.NewPostRepositoryStub()
if tt.mockSetup != nil {
tt.mockSetup(repo)
}
var voteService *services.VoteService
if tt.name == "delete votes error" {
voteRepo := &errorVoteRepository{}
voteService = services.NewVoteService(voteRepo, repo, nil)
} else {
voteRepo := testutils.NewMockVoteRepository()
voteService = services.NewVoteService(voteRepo, repo, nil)
}
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, voteService)
request := httptest.NewRequest(http.MethodDelete, "/api/posts/"+tt.postID, nil)
if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx)
}
ctx := chi.NewRouteContext()
ctx.URLParams.Add("id", tt.postID)
request = request.WithContext(context.WithValue(request.Context(), chi.RouteCtxKey, ctx))
recorder := httptest.NewRecorder()
handler.DeletePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
if tt.expectedError != "" {
if !strings.Contains(recorder.Body.String(), tt.expectedError) {
t.Fatalf("expected error to contain %q, got %q", tt.expectedError, recorder.Body.String())
}
}
})
}
}
type errorVoteRepository struct{}
func (e *errorVoteRepository) Create(*database.Vote) error { return nil }
func (e *errorVoteRepository) CreateOrUpdate(*database.Vote) error { return nil }
func (e *errorVoteRepository) GetByID(uint) (*database.Vote, error) {
return nil, gorm.ErrRecordNotFound
}
func (e *errorVoteRepository) GetByUserAndPost(uint, uint) (*database.Vote, error) {
return nil, gorm.ErrRecordNotFound
}
func (e *errorVoteRepository) GetByVoteHash(string) (*database.Vote, error) {
return nil, gorm.ErrRecordNotFound
}
func (e *errorVoteRepository) GetByPostID(uint) ([]database.Vote, error) {
return nil, errors.New("database error")
}
func (e *errorVoteRepository) GetByUserID(uint) ([]database.Vote, error) { return nil, nil }
func (e *errorVoteRepository) Update(*database.Vote) error { return nil }
func (e *errorVoteRepository) Delete(uint) error { return nil }
func (e *errorVoteRepository) Count() (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) WithTx(*gorm.DB) repositories.VoteRepository { return e }
func TestPostHandler_EdgeCases(t *testing.T) {
postRepo := testutils.NewPostRepositoryStub()
titleFetcher := &testutils.TitleFetcherStub{}
handler := NewPostHandler(postRepo, titleFetcher, nil)
t.Run("GetPosts with zero limit", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts?limit=0", nil)
w := httptest.NewRecorder()
handler.GetPosts(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for zero limit, got %d", w.Code)
}
})
t.Run("GetPosts with negative limit", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts?limit=-1", nil)
w := httptest.NewRecorder()
handler.GetPosts(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for negative limit, got %d", w.Code)
}
})
t.Run("GetPosts with negative offset", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts?offset=-1", nil)
w := httptest.NewRecorder()
handler.GetPosts(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for negative offset, got %d", w.Code)
}
})
}

View File

@@ -0,0 +1,21 @@
package handlers
import (
"net/http"
"goyco/internal/middleware"
"github.com/go-chi/chi/v5"
)
type RouteModule interface {
MountRoutes(r chi.Router, config RouteModuleConfig)
}
type RouteModuleConfig struct {
AuthService middleware.TokenVerifier
GeneralRateLimit func(chi.Router) chi.Router
AuthRateLimit func(chi.Router) chi.Router
CSRFMiddleware func(http.Handler) http.Handler
AuthMiddleware func(http.Handler) http.Handler
}

View File

@@ -0,0 +1,412 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/security"
"goyco/internal/testutils"
"goyco/internal/validation"
)
func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
maliciousInputs := testutils.GetMaliciousInputs()
for _, payload := range maliciousInputs.XSSPayloads {
t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) {
repo := &testutils.PostRepositoryStub{
CreateFn: func(post *database.Post) error {
sanitizedTitle := security.SanitizeInput(payload)
if post.Title != sanitizedTitle {
t.Errorf("Expected sanitized title, got %q", post.Title)
}
return nil
},
}
handler := NewPostHandler(repo, nil, nil)
postData := map[string]string{
"title": payload,
"url": "https://example.com",
"content": "Test content",
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
})
}
}
func minLen(a, b int) int {
if a < b {
return a
}
return b
}
func TestPostHandler_InputValidation(t *testing.T) {
tests := []struct {
name string
title string
content string
url string
expectedStatus int
description string
}{
{
name: "title too long",
title: string(make([]byte, 201)),
content: "Normal content",
url: "https://example.com",
expectedStatus: http.StatusBadRequest,
description: "Title should be limited to 200 characters",
},
{
name: "content too long",
title: "Normal title",
content: string(make([]byte, 10001)),
url: "https://example.com",
expectedStatus: http.StatusBadRequest,
description: "Content should be limited to 10,000 characters",
},
{
name: "invalid URL protocol",
title: "Normal title",
content: "Normal content",
url: "ftp://example.com",
expectedStatus: http.StatusBadRequest,
description: "Only HTTP and HTTPS URLs should be allowed",
},
{
name: "localhost URL blocked",
title: "Normal title",
content: "Normal content",
url: "http://localhost:8080",
expectedStatus: http.StatusBadRequest,
description: "Localhost URLs should be blocked",
},
{
name: "private IP URL blocked",
title: "Normal title",
content: "Normal content",
url: "http://192.168.1.1",
expectedStatus: http.StatusBadRequest,
description: "Private IP URLs should be blocked",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &testutils.PostRepositoryStub{}
handler := NewPostHandler(repo, nil, nil)
postData := map[string]string{
"title": tt.title,
"url": tt.url,
"content": tt.content,
}
body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
})
}
}
func TestAuthHandler_PasswordValidation(t *testing.T) {
tests := []struct {
name string
password string
expectedStatus int
description string
}{
{
name: "weak password",
password: "123",
expectedStatus: http.StatusBadRequest,
description: "Weak passwords should be rejected",
},
{
name: "password without letters",
password: "12345678",
expectedStatus: http.StatusBadRequest,
description: "Passwords without letters should be rejected",
},
{
name: "password without numbers",
password: "password",
expectedStatus: http.StatusBadRequest,
description: "Passwords without numbers should be rejected",
},
{
name: "password without special chars",
password: "Password123",
expectedStatus: http.StatusBadRequest,
description: "Passwords without special characters should be rejected",
},
{
name: "password too short",
password: "Pass1!",
expectedStatus: http.StatusBadRequest,
description: "Passwords shorter than 8 characters should be rejected",
},
{
name: "password too long",
password: string(make([]byte, 129)),
expectedStatus: http.StatusBadRequest,
description: "Passwords that are too long should be rejected",
},
{
name: "empty password",
password: "",
expectedStatus: http.StatusBadRequest,
description: "Empty passwords should be rejected",
},
{
name: "valid password",
password: "Password123!",
expectedStatus: http.StatusCreated,
description: "Valid passwords should be accepted",
},
{
name: "valid password with underscore",
password: "Password123_",
expectedStatus: http.StatusCreated,
description: "Valid passwords with underscore should be accepted",
},
{
name: "valid password with hyphen",
password: "Password123-",
expectedStatus: http.StatusCreated,
description: "Valid passwords with hyphen should be accepted",
},
{
name: "valid password with unicode",
password: "Pássw0rd123!",
expectedStatus: http.StatusCreated,
description: "Valid passwords with unicode should be accepted",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &testutils.UserRepositoryStub{
GetByUsernameFn: func(string) (*database.User, error) {
return nil, gorm.ErrRecordNotFound
},
CreateFn: func(user *database.User) error {
return nil
},
}
handler := newAuthHandler(repo)
registerData := map[string]string{
"username": "testuser",
"email": "test@example.com",
"password": tt.password,
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
})
}
}
func TestAuthHandler_UsernameSanitization(t *testing.T) {
tests := []struct {
name string
username string
expectedStatus int
description string
}{
{
name: "username with special chars",
username: "test@user#123",
expectedStatus: http.StatusCreated,
description: "Special characters should be removed from username",
},
{
name: "username with script tags",
username: "test<script>alert('xss')</script>user",
expectedStatus: http.StatusCreated,
description: "Script tags should be removed from username",
},
{
name: "username starting with special char",
username: "@testuser",
expectedStatus: http.StatusCreated,
description: "Username starting with special char should be prefixed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedUsername string
repo := &testutils.UserRepositoryStub{
GetByUsernameFn: func(username string) (*database.User, error) {
capturedUsername = username
return nil, gorm.ErrRecordNotFound
},
CreateFn: func(user *database.User) error {
return nil
},
}
handler := newAuthHandler(repo)
registerData := map[string]string{
"username": tt.username,
"email": "test@example.com",
"password": "Password123!",
}
body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
expectedUsername := security.SanitizeUsername(tt.username)
if capturedUsername != expectedUsername {
t.Errorf("Expected sanitized username %q, got %q", expectedUsername, capturedUsername)
}
})
}
}
func TestPostHandler_AuthorizationBypass(t *testing.T) {
repo := &testutils.PostRepositoryStub{
GetByIDFn: func(id uint) (*database.Post, error) {
authorID := uint(2)
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
},
}
handler := NewPostHandler(repo, nil, nil)
updateData := map[string]string{
"title": "Updated Title",
"content": "Updated content",
}
body, _ := json.Marshal(updateData)
request := httptest.NewRequest("PUT", "/api/posts/1", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
routeCtx := chi.NewRouteContext()
routeCtx.URLParams.Add("id", "1")
request = request.WithContext(context.WithValue(request.Context(), chi.RouteCtxKey, routeCtx))
recorder := httptest.NewRecorder()
handler.UpdatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Users should not be able to edit other users' posts", recorder.Result().StatusCode)
}
}
func TestPageHandler_PasswordResetValidation(t *testing.T) {
tests := []struct {
name string
password string
expectedError bool
description string
}{
{
name: "valid password",
password: "Password123!",
expectedError: false,
description: "Valid passwords should pass validation",
},
{
name: "password without special chars",
password: "Password123",
expectedError: true,
description: "Passwords without special characters should be rejected",
},
{
name: "password too short",
password: "Pass1!",
expectedError: true,
description: "Passwords shorter than 8 characters should be rejected",
},
{
name: "password without letters",
password: "12345678!",
expectedError: true,
description: "Passwords without letters should be rejected",
},
{
name: "password without numbers",
password: "Password!",
expectedError: true,
description: "Passwords without numbers should be rejected",
},
{
name: "empty password",
password: "",
expectedError: true,
description: "Empty passwords should be rejected",
},
{
name: "password too long",
password: string(make([]byte, 129)),
expectedError: true,
description: "Passwords longer than 128 characters should be rejected",
},
{
name: "valid password with unicode",
password: "Pássw0rd123!",
expectedError: false,
description: "Valid passwords with unicode should pass validation",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validation.ValidatePassword(tt.password)
if tt.expectedError && err == nil {
t.Errorf("ValidatePassword(%q) expected error, got nil. %s", tt.password, tt.description)
}
if !tt.expectedError && err != nil {
t.Errorf("ValidatePassword(%q) unexpected error: %v. %s", tt.password, err, tt.description)
}
})
}
}

View File

@@ -0,0 +1,195 @@
package handlers
import (
"errors"
"net/http"
"goyco/internal/dto"
"goyco/internal/repositories"
"goyco/internal/validation"
"github.com/go-chi/chi/v5"
)
type UserHandler struct {
userRepo repositories.UserRepository
authService AuthServiceInterface
}
func NewUserHandler(userRepo repositories.UserRepository, authService AuthServiceInterface) *UserHandler {
return &UserHandler{
userRepo: userRepo,
authService: authService,
}
}
type UserResponse = CommonResponse
// @Summary List users
// @Description Retrieve a paginated list of users
// @Tags users
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Number of users to return" default(20)
// @Param offset query int false "Number of users to skip" default(0)
// @Success 200 {object} UserResponse "Users retrieved successfully"
// @Failure 401 {object} UserResponse "Authentication required"
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /users [get]
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
limit, offset := parsePagination(r)
users, err := h.userRepo.GetAll(limit, offset)
if err != nil {
SendErrorResponse(w, "Failed to fetch users", http.StatusInternalServerError)
return
}
userDTOs := dto.ToSanitizedUserDTOs(users)
SendSuccessResponse(w, "Users retrieved successfully", map[string]any{
"users": userDTOs,
"count": len(userDTOs),
"limit": limit,
"offset": offset,
})
}
// @Summary Get user
// @Description Retrieve a specific user by ID
// @Tags users
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "User ID"
// @Success 200 {object} UserResponse "User retrieved successfully"
// @Failure 400 {object} UserResponse "Invalid user ID"
// @Failure 401 {object} UserResponse "Authentication required"
// @Failure 404 {object} UserResponse "User not found"
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /users/{id} [get]
func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
userID, ok := ParseUintParam(w, r, "id", "User")
if !ok {
return
}
user, err := h.userRepo.GetByID(userID)
if !HandleRepoError(w, err, "User") {
return
}
userDTO := dto.ToSanitizedUserDTO(user)
SendSuccessResponse(w, "User retrieved successfully", userDTO)
}
// @Summary Create user
// @Description Create a new user account
// @Tags users
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body RegisterRequest true "User data"
// @Success 201 {object} UserResponse "User created successfully"
// @Failure 400 {object} UserResponse "Invalid request data or validation failed"
// @Failure 401 {object} UserResponse "Authentication required"
// @Failure 409 {object} UserResponse "Username or email already exists"
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /users [post]
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
if err := validation.ValidateUsername(req.Username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidateEmail(req.Email); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(req.Password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
result, err := h.authService.Register(req.Username, req.Email, req.Password)
if err != nil {
var validationErr *validation.ValidationError
if errors.As(err, &validationErr) {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if !HandleServiceError(w, err, "Failed to create user", http.StatusInternalServerError) {
return
}
}
SendCreatedResponse(w, "User created successfully. Verification email sent.", map[string]any{
"user": result.User,
"verification_sent": result.VerificationSent,
})
}
// @Summary Get user posts
// @Description Retrieve posts created by a specific user
// @Tags users
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "User ID"
// @Param limit query int false "Number of posts to return" default(20)
// @Param offset query int false "Number of posts to skip" default(0)
// @Success 200 {object} UserResponse "User posts retrieved successfully"
// @Failure 400 {object} UserResponse "Invalid user ID or pagination parameters"
// @Failure 401 {object} UserResponse "Authentication required"
// @Failure 500 {object} UserResponse "Internal server error"
// @Router /users/{id}/posts [get]
func (h *UserHandler) GetUserPosts(w http.ResponseWriter, r *http.Request) {
userID, ok := ParseUintParam(w, r, "id", "User")
if !ok {
return
}
limit, offset := parsePagination(r)
posts, err := h.userRepo.GetPosts(userID, limit, offset)
if err != nil {
SendErrorResponse(w, "Failed to fetch user posts", http.StatusInternalServerError)
return
}
postDTOs := dto.ToPostDTOs(posts)
SendSuccessResponse(w, "User posts retrieved successfully", map[string]any{
"posts": postDTOs,
"count": len(postDTOs),
"limit": limit,
"offset": offset,
})
}
func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected := r
if config.AuthMiddleware != nil {
protected = r.With(config.AuthMiddleware)
}
if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected)
}
protected.Get("/users", h.GetUsers)
protected.Post("/users", h.CreateUser)
protected.Get("/users/{id}", h.GetUser)
protected.Get("/users/{id}/posts", h.GetUserPosts)
}

View File

@@ -0,0 +1,362 @@
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"gorm.io/gorm"
"goyco/internal/config"
"goyco/internal/database"
"goyco/internal/repositories"
"goyco/internal/services"
"goyco/internal/testutils"
)
func newUserHandler(repo repositories.UserRepository) *UserHandler {
return newUserHandlerWithSender(repo, &testutils.EmailSenderStub{})
}
func newUserHandlerWithSender(repo repositories.UserRepository, sender services.EmailSender) *UserHandler {
cfg := &config.Config{
JWT: config.JWTConfig{Secret: "secret", Expiration: 1},
App: config.AppConfig{BaseURL: "https://test.example.com"},
}
mockRefreshRepo := &mockRefreshTokenRepository{}
authService, err := services.NewAuthFacadeForTest(cfg, repo, nil, nil, mockRefreshRepo, sender)
if err != nil {
panic(fmt.Sprintf("Failed to create auth service: %v", err))
}
return NewUserHandler(repo, authService)
}
func TestUserHandlerGetUsers(t *testing.T) {
var limit, offset int
repo := testutils.NewUserRepositoryStub()
repo.GetAllFn = func(l, o int) ([]database.User, error) {
limit, offset = l, o
return []database.User{{ID: 1}}, nil
}
handler := newUserHandler(repo)
request := httptest.NewRequest(http.MethodGet, "/api/users?limit=5&offset=2", nil)
recorder := httptest.NewRecorder()
handler.GetUsers(recorder, request)
if limit != 5 || offset != 2 {
t.Fatalf("expected limit=5 offset=2, got %d %d", limit, offset)
}
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
}
func TestUserHandlerGetUser(t *testing.T) {
repo := testutils.NewUserRepositoryStub()
handler := newUserHandler(repo)
request := httptest.NewRequest(http.MethodGet, "/api/users/1", nil)
recorder := httptest.NewRecorder()
handler.GetUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
request = httptest.NewRequest(http.MethodGet, "/api/users/abc", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
recorder = httptest.NewRecorder()
handler.GetUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
repo.GetByIDFn = func(uint) (*database.User, error) { return nil, gorm.ErrRecordNotFound }
request = httptest.NewRequest(http.MethodGet, "/api/users/1", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
recorder = httptest.NewRecorder()
handler.GetUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
repo.GetByIDFn = func(id uint) (*database.User, error) {
return &database.User{ID: id, Username: "user"}, nil
}
request = httptest.NewRequest(http.MethodGet, "/api/users/1", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
recorder = httptest.NewRecorder()
handler.GetUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
}
func TestUserHandlerCreateUser(t *testing.T) {
repo := testutils.NewUserRepositoryStub()
repo.CreateFn = func(u *database.User) error {
u.ID = 10
return nil
}
sent := false
handler := newUserHandlerWithSender(repo, &testutils.EmailSenderStub{SendFn: func(to, subject, body string) error {
sent = true
if to != "user@example.com" {
t.Fatalf("expected email to user@example.com, got %q", to)
}
return nil
}})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
var resp UserResponse
_ = json.NewDecoder(recorder.Body).Decode(&resp)
data := resp.Data.(map[string]any)
if !resp.Success {
t.Fatalf("expected success response")
}
if v, ok := data["verification_sent"].(bool); !ok || !v {
t.Fatalf("expected verification_sent true, got %+v", data["verification_sent"])
}
userData := data["user"].(map[string]any)
if _, ok := userData["password"]; ok {
t.Fatalf("expected password field to be omitted, got %+v", userData)
}
if !sent {
t.Fatalf("expected verification email to be sent")
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid"))
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
}
repo.GetByUsernameFn = func(string) (*database.User, error) {
return &database.User{ID: 1}, nil
}
handler = newUserHandler(repo)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
}
func TestUserHandlerGetUserPosts(t *testing.T) {
repo := testutils.NewUserRepositoryStub()
repo.GetPostsFn = func(userID uint, limit, offset int) ([]database.Post, error) {
return []database.Post{{ID: 1, AuthorID: &userID}}, nil
}
handler := newUserHandler(repo)
request := httptest.NewRequest(http.MethodGet, "/api/users/1/posts?limit=2&offset=1", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
recorder := httptest.NewRecorder()
handler.GetUserPosts(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
repo.GetPostsFn = func(uint, int, int) ([]database.Post, error) {
return nil, gorm.ErrInvalidValue
}
recorder = httptest.NewRecorder()
handler.GetUserPosts(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
}
func TestUserHandlerDataSanitization(t *testing.T) {
repo := testutils.NewUserRepositoryStub()
repo.GetAllFn = func(l, o int) ([]database.User, error) {
users := []database.User{
{
ID: 1,
Username: "user1",
Email: "user1@example.com",
Password: "hashedpassword",
EmailVerified: true,
EmailVerifiedAt: &[]time.Time{time.Now()}[0],
EmailVerificationToken: "secret-token",
PasswordResetToken: "reset-token",
Locked: false,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
{
ID: 2,
Username: "user2",
Email: "user2@example.com",
Password: "another-hashed-password",
EmailVerified: false,
EmailVerificationToken: "another-secret-token",
Locked: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
}
return users, nil
}
handler := newUserHandler(repo)
request := httptest.NewRequest(http.MethodGet, "/api/users", nil)
recorder := httptest.NewRecorder()
handler.GetUsers(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
data, ok := response["data"].(map[string]any)
if !ok {
t.Fatalf("expected data field in response")
}
users, ok := data["users"].([]any)
if !ok {
t.Fatalf("expected users field in data")
}
if len(users) != 2 {
t.Fatalf("expected 2 users, got %d", len(users))
}
for i, userInterface := range users {
user, ok := userInterface.(map[string]any)
if !ok {
t.Fatalf("expected user %d to be a map", i)
}
expectedFields := []string{"id", "username", "created_at", "updated_at"}
for _, field := range expectedFields {
if _, exists := user[field]; !exists {
t.Errorf("expected field %s to be present in user %d", field, i)
}
}
sensitiveFields := []string{"email", "password", "email_verified", "email_verified_at",
"email_verification_token", "password_reset_token", "locked", "deleted_at"}
for _, field := range sensitiveFields {
if _, exists := user[field]; exists {
t.Errorf("sensitive field %s should not be present in user %d", field, i)
}
}
}
}
func TestUserHandler_PasswordValidation(t *testing.T) {
tests := []struct {
name string
password string
expectedStatus int
description string
}{
{
name: "valid password",
password: "Password123!",
expectedStatus: http.StatusCreated,
description: "Valid passwords should be accepted",
},
{
name: "password without special chars",
password: "Password123",
expectedStatus: http.StatusBadRequest,
description: "Passwords without special characters should be rejected",
},
{
name: "password too short",
password: "Pass1!",
expectedStatus: http.StatusBadRequest,
description: "Passwords shorter than 8 characters should be rejected",
},
{
name: "password without letters",
password: "12345678!",
expectedStatus: http.StatusBadRequest,
description: "Passwords without letters should be rejected",
},
{
name: "password without numbers",
password: "Password!",
expectedStatus: http.StatusBadRequest,
description: "Passwords without numbers should be rejected",
},
{
name: "empty password",
password: "",
expectedStatus: http.StatusBadRequest,
description: "Empty passwords should be rejected",
},
{
name: "password too long",
password: string(make([]byte, 129)),
expectedStatus: http.StatusBadRequest,
description: "Passwords longer than 128 characters should be rejected",
},
{
name: "valid password with unicode",
password: "Pássw0rd123!",
expectedStatus: http.StatusCreated,
description: "Valid passwords with unicode should be accepted",
},
{
name: "valid password with underscore",
password: "Password123_",
expectedStatus: http.StatusCreated,
description: "Valid passwords with underscore should be accepted",
},
{
name: "valid password with hyphen",
password: "Password123-",
expectedStatus: http.StatusCreated,
description: "Valid passwords with hyphen should be accepted",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := testutils.NewUserRepositoryStub()
repo.CreateFn = func(user *database.User) error {
return nil
}
repo.GetByUsernameFn = func(username string) (*database.User, error) {
return nil, gorm.ErrRecordNotFound
}
repo.GetByEmailFn = func(email string) (*database.User, error) {
return nil, gorm.ErrRecordNotFound
}
cfg := &config.Config{
JWT: config.JWTConfig{Secret: "secret", Expiration: 1},
App: config.AppConfig{BaseURL: "https://test.example.com"},
}
emailSender := &testutils.MockEmailSender{}
mockRefreshRepo := &mockRefreshTokenRepository{}
authService, err := services.NewAuthFacadeForTest(cfg, repo, nil, nil, mockRefreshRepo, emailSender)
if err != nil {
t.Fatalf("Failed to create auth service: %v", err)
}
handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
})
}
}

View File

@@ -0,0 +1,293 @@
package handlers
import (
"net/http"
"goyco/internal/database"
"goyco/internal/services"
"github.com/go-chi/chi/v5"
)
// @securityDefinitions.apikey BearerAuth
// @in header
// @name Authorization
// @description Type "Bearer" followed by a space and JWT token.
// @tag.name votes
// @tag.description Voting system endpoints. All votes are handled through the same API with identical behavior.
// @tag.name posts
// @tag.description Post management endpoints with integrated vote statistics.
// @tag.name auth
// @tag.description Authentication and user management endpoints.
// @tag.name users
// @tag.description User management endpoints.
// @tag.name api
// @tag.description API information and system metrics.
type VoteHandler struct {
voteService *services.VoteService
}
func NewVoteHandler(voteService *services.VoteService) *VoteHandler {
return &VoteHandler{
voteService: voteService,
}
}
// @Description Vote request with type field. All votes are handled the same way.
type VoteRequest struct {
Type string `json:"type" example:"up" enums:"up,down,none" description:"Vote type: 'up' for upvote, 'down' for downvote, 'none' to remove vote"`
}
type VoteResponse = CommonResponse
// @Summary Cast a vote on a post
// @Description Vote on a post (upvote, downvote, or remove vote). Authentication is required; the vote is performed on behalf of the current user.
// @Description
// @Description **Vote Types:**
// @Description - `up`: Upvote the post
// @Description - `down`: Downvote the post
// @Description - `none`: Remove existing vote
// @Description
// @Description **Response includes:**
// @Description - Updated post vote counts (up_votes, down_votes, score)
// @Description - Success message
// @Tags votes
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Param request body VoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics"
// @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid request data or vote type"
// @Failure 404 {object} VoteResponse "Post not found"
// @Failure 500 {object} VoteResponse "Internal server error"
// @Example 200 {"success": true, "message": "Vote cast successfully", "data": {"post_id": 1, "type": "up", "up_votes": 5, "down_votes": 2, "score": 3, "is_anonymous": false}}
// @Example 400 {"success": false, "error": "Invalid vote type. Must be 'up', 'down', or 'none'"}
// @Router /posts/{id}/vote [post]
func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
postID, ok := ParseUintParam(w, r, "id", "Post")
if !ok {
return
}
var req VoteRequest
if !DecodeJSONRequest(w, r, &req) {
return
}
var voteType database.VoteType
switch req.Type {
case "up":
voteType = database.VoteUp
case "down":
voteType = database.VoteDown
case "none":
voteType = database.VoteNone
default:
SendErrorResponse(w, "Invalid vote type. Must be 'up', 'down', or 'none'", http.StatusBadRequest)
return
}
ipAddress := GetClientIP(r)
userAgent := r.UserAgent()
serviceReq := services.VoteRequest{
UserID: userID,
PostID: postID,
Type: voteType,
IPAddress: ipAddress,
UserAgent: userAgent,
}
response, err := h.voteService.CastVote(serviceReq)
if err != nil {
if err.Error() == "post not found" {
SendErrorResponse(w, err.Error(), http.StatusNotFound)
return
}
if err.Error() == "post ID is required" || err.Error() == "invalid vote type" {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
return
}
SendSuccessResponse(w, "Vote cast successfully", response)
}
// @Summary Remove a vote
// @Description Remove a vote from a post for the authenticated user. This is equivalent to casting a vote with type 'none'.
// @Tags votes
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Success 200 {object} VoteResponse "Vote removed successfully with updated post statistics"
// @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid post ID"
// @Failure 404 {object} VoteResponse "Post not found"
// @Failure 500 {object} VoteResponse "Internal server error"
// @Router /posts/{id}/vote [delete]
func (h *VoteHandler) RemoveVote(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
postID, ok := ParseUintParam(w, r, "id", "Post")
if !ok {
return
}
ipAddress := GetClientIP(r)
userAgent := r.UserAgent()
serviceReq := services.VoteRequest{
UserID: userID,
PostID: postID,
Type: database.VoteNone,
IPAddress: ipAddress,
UserAgent: userAgent,
}
response, err := h.voteService.CastVote(serviceReq)
if err != nil {
if err.Error() == "post not found" {
SendErrorResponse(w, err.Error(), http.StatusNotFound)
return
}
if err.Error() == "post ID is required" {
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
return
}
SendSuccessResponse(w, "Vote removed successfully", response)
}
// @Summary Get current user's vote
// @Description Retrieve the current user's vote for a specific post. Requires authentication and returns the vote type if it exists.
// @Description
// @Description **Response:**
// @Description - If vote exists: Returns vote details with contextual metadata (including `is_anonymous`)
// @Description - If no vote: Returns success with null vote data and metadata
// @Tags votes
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Success 200 {object} VoteResponse "Vote retrieved successfully"
// @Success 200 {object} VoteResponse "No vote found for this user/post combination"
// @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid post ID"
// @Failure 500 {object} VoteResponse "Internal server error"
// @Example 200 {"success": true, "message": "Vote retrieved successfully", "data": {"has_vote": true, "vote": {"type": "up", "user_id": 123}, "is_anonymous": false}}
// @Example 200 {"success": true, "message": "No vote found", "data": {"has_vote": false, "vote": null, "is_anonymous": false}}
// @Router /posts/{id}/vote [get]
func (h *VoteHandler) GetUserVote(w http.ResponseWriter, r *http.Request) {
userID, ok := RequireAuth(w, r)
if !ok {
return
}
postID, ok := ParseUintParam(w, r, "id", "Post")
if !ok {
return
}
ipAddress := GetClientIP(r)
userAgent := r.UserAgent()
vote, err := h.voteService.GetUserVote(userID, postID, ipAddress, userAgent)
if err != nil {
if err.Error() == "record not found" {
SendSuccessResponse(w, "No vote found", map[string]any{
"has_vote": false,
"vote": nil,
"is_anonymous": false,
})
return
}
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
return
}
SendSuccessResponse(w, "Vote retrieved successfully", map[string]any{
"has_vote": true,
"vote": vote,
"is_anonymous": false,
})
}
// @Summary Get post votes
// @Description Retrieve all votes for a specific post. Returns all votes in a single format.
// @Description
// @Description **Authentication Required:** Yes (Bearer token)
// @Description
// @Description **Response includes:**
// @Description - Array of all votes
// @Description - Total vote count
// @Description - Each vote includes type and unauthenticated status
// @Tags votes
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path int true "Post ID"
// @Success 200 {object} VoteResponse "Votes retrieved successfully with count"
// @Failure 400 {object} VoteResponse "Invalid post ID"
// @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 500 {object} VoteResponse "Internal server error"
// @Example 200 {"success": true, "message": "Votes retrieved successfully", "data": {"votes": [{"type": "up", "user_id": 123}, {"type": "down", "vote_hash": "abc123"}], "count": 2}}
// @Router /posts/{id}/votes [get]
func (h *VoteHandler) GetPostVotes(w http.ResponseWriter, r *http.Request) {
postID, ok := ParseUintParam(w, r, "id", "Post")
if !ok {
return
}
votes, err := h.voteService.GetPostVotes(postID)
if err != nil {
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
return
}
allVotes := make([]any, 0, len(votes))
for _, vote := range votes {
allVotes = append(allVotes, vote)
}
SendSuccessResponse(w, "Votes retrieved successfully", map[string]any{
"votes": allVotes,
"count": len(allVotes),
})
}
func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected := r
if config.AuthMiddleware != nil {
protected = r.With(config.AuthMiddleware)
}
if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected)
}
protected.Post("/posts/{id}/vote", h.CastVote)
protected.Delete("/posts/{id}/vote", h.RemoveVote)
protected.Get("/posts/{id}/vote", h.GetUserVote)
protected.Get("/posts/{id}/votes", h.GetPostVotes)
}

View File

@@ -0,0 +1,482 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"gorm.io/gorm"
"goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/services"
"goyco/internal/testutils"
)
func newVoteHandlerWithRepos() *VoteHandler {
handler, _, _ := newVoteHandlerWithReposRefs()
return handler
}
func newVoteHandlerWithReposRefs() (*VoteHandler, *testutils.MockVoteRepository, map[uint]*database.Post) {
voteRepo := testutils.NewMockVoteRepository()
posts := map[uint]*database.Post{
1: {ID: 1},
}
postRepo := testutils.NewPostRepositoryStub()
postRepo.GetByIDFn = func(id uint) (*database.Post, error) {
if post, ok := posts[id]; ok {
copy := *post
return &copy, nil
}
return nil, gorm.ErrRecordNotFound
}
postRepo.UpdateFn = func(post *database.Post) error {
copy := *post
posts[post.ID] = &copy
return nil
}
postRepo.DeleteFn = func(id uint) error {
if _, ok := posts[id]; !ok {
return gorm.ErrRecordNotFound
}
delete(posts, id)
return nil
}
postRepo.CreateFn = func(post *database.Post) error {
copy := *post
posts[post.ID] = &copy
return nil
}
service := services.NewVoteService(voteRepo, postRepo, nil)
return NewVoteHandler(service), voteRepo, posts
}
func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid vote type, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for successful down vote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for successful none vote, got %d", recorder.Result().StatusCode)
}
}
func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
recorder := httptest.NewRecorder()
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
}
func TestVoteHandlerRemoveVote(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.RemoveVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodDelete, "/api/posts/abc/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.RemoveVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.RemoveVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for removing non-existent vote (idempotent), got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.RemoveVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 when removing vote, got %d", recorder.Result().StatusCode)
}
}
func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1)
request := httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
recorder := httptest.NewRecorder()
handler.RemoveVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
}
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
recorder := httptest.NewRecorder()
handler.CastVote(recorder, request)
voteRepo.DeleteErr = fmt.Errorf("database unavailable")
request = httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
recorder = httptest.NewRecorder()
handler.RemoveVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
}
func TestVoteHandlerGetUserVote(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.GetUserVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/abc/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.GetUserVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.GetUserVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 when vote missing, got %d", recorder.Result().StatusCode)
}
var resp VoteResponse
_ = json.NewDecoder(recorder.Body).Decode(&resp)
data := resp.Data.(map[string]any)
if data["has_vote"].(bool) {
t.Fatalf("expected has_vote false, got true")
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.GetUserVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 when vote exists, got %d", recorder.Result().StatusCode)
}
_ = json.NewDecoder(recorder.Body).Decode(&resp)
data = resp.Data.(map[string]any)
if !data["has_vote"].(bool) {
t.Fatalf("expected has_vote true, got false")
}
}
func TestVoteHandlerGetPostVotes(t *testing.T) {
handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/api/posts/abc/votes", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
handler.GetPostVotes(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/votes", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.GetPostVotes(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for empty votes, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/votes", nil)
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.GetPostVotes(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Result().StatusCode)
}
var resp VoteResponse
_ = json.NewDecoder(recorder.Body).Decode(&resp)
data := resp.Data.(map[string]any)
votes := data["votes"].([]any)
if len(votes) != 2 {
t.Fatalf("expected 2 votes, got %d", len(votes))
}
}
func TestVoteFlowRegression(t *testing.T) {
handler := newVoteHandlerWithRepos()
t.Run("CompleteVoteLifecycle", func(t *testing.T) {
userID := uint(1)
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
handler.GetUserVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for getting vote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for changing to downvote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for removing vote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx)
handler.GetUserVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for getting removed vote, got %d", recorder.Result().StatusCode)
}
var resp VoteResponse
_ = json.NewDecoder(recorder.Body).Decode(&resp)
data := resp.Data.(map[string]any)
if data["has_vote"].(bool) {
t.Fatalf("expected has_vote false after removal, got true")
}
})
t.Run("MultipleUsersVoting", func(t *testing.T) {
postID := "1"
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for user 1 upvote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for user 2 downvote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for user 3 upvote, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/votes", nil)
request = testutils.WithURLParams(request, map[string]string{"id": postID})
handler.GetPostVotes(recorder, request)
if recorder.Result().StatusCode != http.StatusOK {
t.Fatalf("expected 200 for getting all votes, got %d", recorder.Result().StatusCode)
}
var resp VoteResponse
_ = json.NewDecoder(recorder.Body).Decode(&resp)
data := resp.Data.(map[string]any)
votes := data["votes"].([]any)
if len(votes) != 3 {
t.Fatalf("expected 3 votes, got %d", len(votes))
}
})
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing type field, got %d", recorder.Result().StatusCode)
}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`))
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx)
handler.CastVote(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid vote type, got %d", recorder.Result().StatusCode)
}
})
}