To gitea and beyond, let's go(-yco)
This commit is contained in:
238
internal/handlers/api_handler.go
Normal file
238
internal/handlers/api_handler.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/version"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type APIHandler struct {
|
||||
config *config.Config
|
||||
postRepo repositories.PostRepository
|
||||
userRepo repositories.UserRepository
|
||||
voteService *services.VoteService
|
||||
dbMonitor middleware.DBMonitor
|
||||
healthChecker *middleware.DatabaseHealthChecker
|
||||
metricsCollector *middleware.MetricsCollector
|
||||
}
|
||||
|
||||
func NewAPIHandler(config *config.Config, postRepo repositories.PostRepository, userRepo repositories.UserRepository, voteService *services.VoteService) *APIHandler {
|
||||
return &APIHandler{
|
||||
config: config,
|
||||
postRepo: postRepo,
|
||||
userRepo: userRepo,
|
||||
voteService: voteService,
|
||||
}
|
||||
}
|
||||
|
||||
func NewAPIHandlerWithMonitoring(config *config.Config, postRepo repositories.PostRepository, userRepo repositories.UserRepository, voteService *services.VoteService, db *gorm.DB, dbMonitor middleware.DBMonitor) *APIHandler {
|
||||
if db == nil {
|
||||
return NewAPIHandler(config, postRepo, userRepo, voteService)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return NewAPIHandler(config, postRepo, userRepo, voteService)
|
||||
}
|
||||
|
||||
healthChecker := middleware.NewDatabaseHealthChecker(sqlDB, dbMonitor)
|
||||
metricsCollector := middleware.NewMetricsCollector(dbMonitor)
|
||||
|
||||
return &APIHandler{
|
||||
config: config,
|
||||
postRepo: postRepo,
|
||||
userRepo: userRepo,
|
||||
voteService: voteService,
|
||||
dbMonitor: dbMonitor,
|
||||
healthChecker: healthChecker,
|
||||
metricsCollector: metricsCollector,
|
||||
}
|
||||
}
|
||||
|
||||
type APIInfo = CommonResponse
|
||||
|
||||
func (h *APIHandler) GetAPIInfo(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
apiInfo := map[string]any{
|
||||
"name": fmt.Sprintf("%s API", h.config.App.Title),
|
||||
"version": version.Version,
|
||||
"description": "Y Combinator-style news board API",
|
||||
"endpoints": map[string]any{
|
||||
"authentication": map[string]any{
|
||||
"POST /api/auth/register": "Register new user",
|
||||
"POST /api/auth/login": "Login user",
|
||||
"GET /api/auth/confirm": "Confirm email address",
|
||||
"POST /api/auth/resend-verification": "Resend verification email",
|
||||
"POST /api/auth/forgot-password": "Request password reset",
|
||||
"POST /api/auth/reset-password": "Reset password",
|
||||
"POST /api/auth/account/confirm": "Confirm account deletion",
|
||||
"GET /api/auth/me": "Get current user profile",
|
||||
"POST /api/auth/logout": "Logout user",
|
||||
"PUT /api/auth/email": "Update email address",
|
||||
"PUT /api/auth/username": "Update username",
|
||||
"PUT /api/auth/password": "Update password",
|
||||
"DELETE /api/auth/account": "Request account deletion",
|
||||
},
|
||||
"posts": map[string]any{
|
||||
"GET /api/posts": "List all posts",
|
||||
"GET /api/posts/search": "Search posts",
|
||||
"GET /api/posts/title": "Fetch title from URL",
|
||||
"GET /api/posts/{id}": "Get specific post",
|
||||
"POST /api/posts": "Create new post",
|
||||
"PUT /api/posts/{id}": "Update post",
|
||||
"DELETE /api/posts/{id}": "Delete post",
|
||||
},
|
||||
"votes": map[string]any{
|
||||
"POST /api/posts/{id}/vote": "Cast a vote",
|
||||
"DELETE /api/posts/{id}/vote": "Remove vote",
|
||||
"GET /api/posts/{id}/vote": "Get user's vote",
|
||||
"GET /api/posts/{id}/votes": "Get all votes for post",
|
||||
},
|
||||
"users": map[string]any{
|
||||
"GET /api/users": "List all users",
|
||||
"POST /api/users": "Create new user",
|
||||
"GET /api/users/{id}": "Get specific user",
|
||||
"GET /api/users/{id}/posts": "Get user's posts",
|
||||
},
|
||||
"system": map[string]any{
|
||||
"GET /health": "Health check",
|
||||
"GET /metrics": "Service metrics",
|
||||
},
|
||||
},
|
||||
"authentication": map[string]any{
|
||||
"type": "Bearer Token (JWT)",
|
||||
"note": "Include Authorization header with 'Bearer <token>' for protected endpoints",
|
||||
},
|
||||
"response_format": map[string]any{
|
||||
"success": "boolean",
|
||||
"message": "string",
|
||||
"data": "object or array",
|
||||
"error": "string (on error)",
|
||||
},
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "API information retrieved successfully", apiInfo)
|
||||
}
|
||||
|
||||
func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if h.healthChecker != nil {
|
||||
health := h.healthChecker.CheckHealth()
|
||||
health["version"] = version.Version
|
||||
SendSuccessResponse(w, "Health check successful", health)
|
||||
return
|
||||
}
|
||||
|
||||
currentTimestamp := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
health := map[string]any{
|
||||
"status": "healthy",
|
||||
"timestamp": currentTimestamp,
|
||||
"version": version.Version,
|
||||
"services": map[string]any{
|
||||
"database": "connected",
|
||||
"api": "running",
|
||||
},
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Health check successful", health)
|
||||
}
|
||||
|
||||
func (h *APIHandler) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
postCount, err := h.postRepo.Count()
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to get post count", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userCount, err := h.userRepo.Count()
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to get user count", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
totalVoteCount, _, err := h.voteService.GetVoteStatistics()
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to get vote statistics", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topPosts, err := h.postRepo.GetTopPosts(5)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to get top posts", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var avgVotesPerPost float64
|
||||
if postCount > 0 {
|
||||
avgVotesPerPost = float64(totalVoteCount) / float64(postCount)
|
||||
}
|
||||
|
||||
var totalScore int
|
||||
for _, post := range topPosts {
|
||||
totalScore += post.Score
|
||||
}
|
||||
|
||||
var avgScore float64
|
||||
if len(topPosts) > 0 {
|
||||
avgScore = float64(totalScore) / float64(len(topPosts))
|
||||
}
|
||||
|
||||
metrics := map[string]any{
|
||||
"posts": map[string]any{
|
||||
"total_count": postCount,
|
||||
"top_posts_count": len(topPosts),
|
||||
"total_score": totalScore,
|
||||
"average_score": avgScore,
|
||||
},
|
||||
"users": map[string]any{
|
||||
"total_count": userCount,
|
||||
},
|
||||
"votes": map[string]any{
|
||||
"total_count": totalVoteCount,
|
||||
"average_per_post": avgVotesPerPost,
|
||||
"note": "All votes are counted together",
|
||||
},
|
||||
"system": map[string]any{
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
"version": version.Version,
|
||||
},
|
||||
}
|
||||
|
||||
if h.metricsCollector != nil {
|
||||
performanceMetrics := h.metricsCollector.GetMetrics()
|
||||
metrics["database"] = map[string]any{
|
||||
"total_queries": performanceMetrics.DBStats.TotalQueries,
|
||||
"slow_queries": performanceMetrics.DBStats.SlowQueries,
|
||||
"average_duration": performanceMetrics.DBStats.AverageDuration.String(),
|
||||
"max_duration": performanceMetrics.DBStats.MaxDuration.String(),
|
||||
"error_count": performanceMetrics.DBStats.ErrorCount,
|
||||
"last_query_time": performanceMetrics.DBStats.LastQueryTime.Format(time.RFC3339),
|
||||
}
|
||||
metrics["performance"] = map[string]any{
|
||||
"request_count": performanceMetrics.RequestCount,
|
||||
"average_response": performanceMetrics.AverageResponse.String(),
|
||||
"max_response": performanceMetrics.MaxResponse.String(),
|
||||
"error_count": performanceMetrics.ErrorCount,
|
||||
}
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Metrics retrieved successfully", metrics)
|
||||
}
|
||||
|
||||
func (h *APIHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||
}
|
||||
280
internal/handlers/api_handler_test.go
Normal file
280
internal/handlers/api_handler_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func TestAPIHandlerGetAPIInfo(t *testing.T) {
|
||||
mockPostRepo := testutils.NewPostRepositoryStub()
|
||||
mockUserRepo := testutils.NewUserRepositoryStub()
|
||||
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||
|
||||
handler.GetAPIInfo(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
var resp APIInfo
|
||||
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if !resp.Success || resp.Message == "" {
|
||||
t.Fatalf("expected success response, got %+v", resp)
|
||||
}
|
||||
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
if !ok || data["name"] != fmt.Sprintf("%s API", testutils.AppTestConfig.App.Title) {
|
||||
t.Fatalf("unexpected data payload: %#v", resp.Data)
|
||||
}
|
||||
|
||||
endpoints, ok := data["endpoints"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected endpoints map, got %#v", data["endpoints"])
|
||||
}
|
||||
|
||||
authEndpoints := endpoints["authentication"].(map[string]any)
|
||||
for _, route := range []string{
|
||||
"POST /api/auth/resend-verification",
|
||||
"POST /api/auth/account/confirm",
|
||||
} {
|
||||
if _, found := authEndpoints[route]; !found {
|
||||
t.Fatalf("expected authentication catalogue to include %s", route)
|
||||
}
|
||||
}
|
||||
|
||||
systemEndpoints := endpoints["system"].(map[string]any)
|
||||
if _, found := systemEndpoints["GET /metrics"]; !found {
|
||||
t.Fatalf("expected system catalogue to include GET /metrics")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIHandlerGetHealth(t *testing.T) {
|
||||
mockPostRepo := testutils.NewPostRepositoryStub()
|
||||
mockUserRepo := testutils.NewUserRepositoryStub()
|
||||
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
|
||||
handler.GetHealth(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
var resp APIInfo
|
||||
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode error: %v", err)
|
||||
}
|
||||
|
||||
if !resp.Success || resp.Message == "" {
|
||||
t.Fatalf("expected success message, got %+v", resp)
|
||||
}
|
||||
|
||||
data := resp.Data.(map[string]any)
|
||||
if data["status"] != "healthy" {
|
||||
t.Fatalf("expected health status, got %+v", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIHandlerGetMetrics(t *testing.T) {
|
||||
mockPostRepo := testutils.NewPostRepositoryStub()
|
||||
mockPostRepo.CountFn = func() (int64, error) { return 10, nil }
|
||||
mockPostRepo.GetTopPostsFn = func(limit int) ([]database.Post, error) {
|
||||
return []database.Post{
|
||||
{ID: 1, Score: 100},
|
||||
{ID: 2, Score: 50},
|
||||
{ID: 3, Score: 25},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mockUserRepo := testutils.NewUserRepositoryStub()
|
||||
mockUserRepo.CountFn = func() (int64, error) { return 5, nil }
|
||||
|
||||
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
|
||||
handler.GetMetrics(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
var resp APIInfo
|
||||
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode error: %v", err)
|
||||
}
|
||||
|
||||
if !resp.Success || resp.Message == "" {
|
||||
t.Fatalf("expected success response, got %+v", resp)
|
||||
}
|
||||
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected metrics data map, got %T", resp.Data)
|
||||
}
|
||||
|
||||
if data["posts"] == nil {
|
||||
t.Fatalf("expected metrics payload to include posts")
|
||||
}
|
||||
if data["users"] == nil {
|
||||
t.Fatalf("expected metrics payload to include users")
|
||||
}
|
||||
if data["votes"] == nil {
|
||||
t.Fatalf("expected metrics payload to include votes")
|
||||
}
|
||||
if data["system"] == nil {
|
||||
t.Fatalf("expected metrics payload to include system")
|
||||
}
|
||||
|
||||
posts, ok := data["posts"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected posts to be a map, got %T", data["posts"])
|
||||
}
|
||||
if posts["total_count"] != float64(10) {
|
||||
t.Fatalf("expected posts total_count to be 10, got %v", posts["total_count"])
|
||||
}
|
||||
}
|
||||
|
||||
func newAPIHandlerForTest(postRepo repositories.PostRepository, userRepo repositories.UserRepository) *APIHandler {
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := services.NewVoteService(voteRepo, postRepo, nil)
|
||||
return NewAPIHandler(testutils.AppTestConfig, postRepo, userRepo, voteService)
|
||||
}
|
||||
|
||||
func TestAPIHandlerGetMetricsErrorHandling(t *testing.T) {
|
||||
mockPostRepo := testutils.NewPostRepositoryStub()
|
||||
mockPostRepo.CountFn = func() (int64, error) { return 0, errors.New("database error") }
|
||||
|
||||
mockUserRepo := testutils.NewUserRepositoryStub()
|
||||
handler := newAPIHandlerForTest(mockPostRepo, mockUserRepo)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
|
||||
handler.GetMetrics(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
|
||||
|
||||
var resp APIInfo
|
||||
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode error: %v", err)
|
||||
}
|
||||
|
||||
if resp.Success {
|
||||
t.Fatalf("expected error response, got %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIHandlerGetMetricsWithDatabaseMonitoring(t *testing.T) {
|
||||
mockPostRepo := testutils.NewPostRepositoryStub()
|
||||
mockPostRepo.CountFn = func() (int64, error) { return 10, nil }
|
||||
mockPostRepo.GetTopPostsFn = func(limit int) ([]database.Post, error) {
|
||||
return []database.Post{
|
||||
{ID: 1, Score: 100},
|
||||
{ID: 2, Score: 50},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mockUserRepo := testutils.NewUserRepositoryStub()
|
||||
mockUserRepo.CountFn = func() (int64, error) { return 5, nil }
|
||||
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := services.NewVoteService(voteRepo, mockPostRepo, nil)
|
||||
|
||||
handler := NewAPIHandler(testutils.AppTestConfig, mockPostRepo, mockUserRepo, voteService)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
|
||||
handler.GetMetrics(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
var resp APIInfo
|
||||
if err := json.NewDecoder(recorder.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode error: %v", err)
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
t.Fatalf("expected success response, got %+v", resp)
|
||||
}
|
||||
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected metrics data map, got %T", resp.Data)
|
||||
}
|
||||
|
||||
expectedSections := []string{"posts", "users", "votes", "system"}
|
||||
for _, section := range expectedSections {
|
||||
if data[section] == nil {
|
||||
t.Fatalf("expected metrics payload to include %s", section)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAPIHandlerWithMonitoring(t *testing.T) {
|
||||
mockPostRepo := testutils.NewPostRepositoryStub()
|
||||
mockUserRepo := testutils.NewUserRepositoryStub()
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := services.NewVoteService(voteRepo, mockPostRepo, nil)
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
|
||||
db := testutils.NewTestDB(t)
|
||||
defer func() {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
}()
|
||||
|
||||
handler := NewAPIHandlerWithMonitoring(testutils.AppTestConfig, mockPostRepo, mockUserRepo, voteService, db, monitor)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created")
|
||||
}
|
||||
|
||||
if handler.dbMonitor == nil {
|
||||
t.Error("Expected dbMonitor to be set")
|
||||
}
|
||||
|
||||
if handler.healthChecker == nil {
|
||||
t.Error("Expected healthChecker to be set")
|
||||
}
|
||||
|
||||
if handler.metricsCollector == nil {
|
||||
t.Error("Expected metricsCollector to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAPIHandlerWithMonitoring_NilDB(t *testing.T) {
|
||||
mockPostRepo := testutils.NewPostRepositoryStub()
|
||||
mockUserRepo := testutils.NewUserRepositoryStub()
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := services.NewVoteService(voteRepo, mockPostRepo, nil)
|
||||
|
||||
handler := NewAPIHandlerWithMonitoring(testutils.AppTestConfig, mockPostRepo, mockUserRepo, voteService, nil, nil)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created")
|
||||
}
|
||||
|
||||
if handler.dbMonitor != nil {
|
||||
t.Error("Expected dbMonitor to be nil when db is nil")
|
||||
}
|
||||
|
||||
if handler.healthChecker != nil {
|
||||
t.Error("Expected healthChecker to be nil when db is nil")
|
||||
}
|
||||
|
||||
if handler.metricsCollector != nil {
|
||||
t.Error("Expected metricsCollector to be nil when db is nil")
|
||||
}
|
||||
}
|
||||
825
internal/handlers/auth_handler.go
Normal file
825
internal/handlers/auth_handler.go
Normal file
@@ -0,0 +1,825 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/dto"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/security"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/validation"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type AuthServiceInterface interface {
|
||||
Login(username, password string) (*services.AuthResult, error)
|
||||
Register(username, email, password string) (*services.RegistrationResult, error)
|
||||
ConfirmEmail(token string) (*database.User, error)
|
||||
ResendVerificationEmail(email string) error
|
||||
RequestPasswordReset(usernameOrEmail string) error
|
||||
ResetPassword(token, newPassword string) error
|
||||
UpdateEmail(userID uint, email string) (*database.User, error)
|
||||
UpdateUsername(userID uint, username string) (*database.User, error)
|
||||
UpdatePassword(userID uint, currentPassword, newPassword string) (*database.User, error)
|
||||
RequestAccountDeletion(userID uint) error
|
||||
ConfirmAccountDeletionWithPosts(token string, deletePosts bool) error
|
||||
RefreshAccessToken(refreshToken string) (*services.AuthResult, error)
|
||||
RevokeRefreshToken(refreshToken string) error
|
||||
RevokeAllUserTokens(userID uint) error
|
||||
InvalidateAllSessions(userID uint) error
|
||||
GetAdminEmail() string
|
||||
VerifyToken(tokenString string) (uint, error)
|
||||
GetUserIDFromDeletionToken(token string) (uint, error)
|
||||
UserHasPosts(userID uint) (bool, int64, error)
|
||||
}
|
||||
|
||||
type AuthHandler struct {
|
||||
authService AuthServiceInterface
|
||||
userRepo repositories.UserRepository
|
||||
}
|
||||
|
||||
type AuthResponse = CommonResponse
|
||||
|
||||
type AuthTokensResponse struct {
|
||||
Success bool `json:"success" example:"true"`
|
||||
Message string `json:"message" example:"Authentication successful"`
|
||||
Data AuthTokensDetail `json:"data"`
|
||||
}
|
||||
|
||||
type AuthTokensDetail struct {
|
||||
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
|
||||
RefreshToken string `json:"refresh_token" example:"f94d4ddc7d9b4fcb9d3a2c44c400b780"`
|
||||
User AuthUserSummary `json:"user"`
|
||||
}
|
||||
|
||||
type AuthUserSummary struct {
|
||||
ID uint `json:"id" example:"42"`
|
||||
Username string `json:"username" example:"janedoe"`
|
||||
Email string `json:"email" example:"jane@example.com"`
|
||||
EmailVerified bool `json:"email_verified" example:"true"`
|
||||
Locked bool `json:"locked" example:"false"`
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type CreatePostRequest struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ResendVerificationRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type ForgotPasswordRequest struct {
|
||||
UsernameOrEmail string `json:"username_or_email"`
|
||||
}
|
||||
|
||||
type ResetPasswordRequest struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type UpdateEmailRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type UpdateUsernameRequest struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type UpdatePasswordRequest struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type ConfirmAccountDeletionRequest struct {
|
||||
Token string `json:"token"`
|
||||
DeletePosts bool `json:"delete_posts"`
|
||||
}
|
||||
|
||||
type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
|
||||
}
|
||||
|
||||
type RevokeTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
|
||||
}
|
||||
|
||||
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// @Summary Login user
|
||||
// @Description Authenticate user with username and password
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body LoginRequest true "Login credentials"
|
||||
// @Success 200 {object} AuthTokensResponse "Authentication successful"
|
||||
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
|
||||
// @Failure 401 {object} AuthResponse "Invalid credentials"
|
||||
// @Failure 403 {object} AuthResponse "Account is locked"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/login [post]
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
username := security.SanitizeUsername(req.Username)
|
||||
password := strings.TrimSpace(req.Password)
|
||||
|
||||
if username == "" || password == "" {
|
||||
SendErrorResponse(w, "Username and password are required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidatePassword(password); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.Login(username, password)
|
||||
if !HandleServiceError(w, err, "Authentication failed", http.StatusInternalServerError) {
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Authentication successful", result)
|
||||
}
|
||||
|
||||
// @Summary Register a new user
|
||||
// @Description Register a new user with username, email and password
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body RegisterRequest true "Registration data"
|
||||
// @Success 201 {object} AuthResponse "Registration successful"
|
||||
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
|
||||
// @Failure 409 {object} AuthResponse "Username or email already exists"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/register [post]
|
||||
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
username := strings.TrimSpace(req.Username)
|
||||
email := strings.TrimSpace(req.Email)
|
||||
password := strings.TrimSpace(req.Password)
|
||||
|
||||
if username == "" || email == "" || password == "" {
|
||||
SendErrorResponse(w, "Username, email, and password are required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
username = security.SanitizeUsername(username)
|
||||
if err := validation.ValidateUsername(username); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidateEmail(email); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidatePassword(password); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.Register(username, email, password)
|
||||
if err != nil {
|
||||
var validationErr *validation.ValidationError
|
||||
if errors.As(err, &validationErr) {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !HandleServiceError(w, err, "Registration failed", http.StatusInternalServerError) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userData := map[string]any{
|
||||
"id": result.User.ID,
|
||||
"username": result.User.Username,
|
||||
"email": result.User.Email,
|
||||
"email_verified": result.User.EmailVerified,
|
||||
"created_at": result.User.CreatedAt,
|
||||
"updated_at": result.User.UpdatedAt,
|
||||
"deleted_at": result.User.DeletedAt,
|
||||
}
|
||||
|
||||
responseData := map[string]any{
|
||||
"user": userData,
|
||||
"verification_sent": result.VerificationSent,
|
||||
}
|
||||
|
||||
SendCreatedResponse(w, "Registration successful. Check your email to confirm your account.", responseData)
|
||||
}
|
||||
|
||||
// @Summary Confirm email address
|
||||
// @Description Confirm user email with verification token
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param token query string true "Email verification token"
|
||||
// @Success 200 {object} AuthResponse "Email confirmed successfully"
|
||||
// @Failure 400 {object} AuthResponse "Invalid or missing token"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/confirm [get]
|
||||
func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
|
||||
token := strings.TrimSpace(r.URL.Query().Get("token"))
|
||||
if token == "" {
|
||||
SendErrorResponse(w, "Verification token is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.ConfirmEmail(token)
|
||||
if !HandleServiceError(w, err, "Unable to verify email", http.StatusInternalServerError) {
|
||||
return
|
||||
}
|
||||
|
||||
userDTO := dto.ToUserDTO(user)
|
||||
SendSuccessResponse(w, "Email confirmed successfully", map[string]any{
|
||||
"user": userDTO,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Resend verification email
|
||||
// @Description Send a new verification email to the provided address
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body ResendVerificationRequest true "Email address"
|
||||
// @Success 200 {object} AuthResponse
|
||||
// @Failure 400 {object} AuthResponse
|
||||
// @Failure 404 {object} AuthResponse
|
||||
// @Failure 409 {object} AuthResponse
|
||||
// @Failure 429 {object} AuthResponse
|
||||
// @Failure 503 {object} AuthResponse
|
||||
// @Failure 500 {object} AuthResponse
|
||||
// @Router /auth/resend-verification [post]
|
||||
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(req.Email)
|
||||
if email == "" {
|
||||
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.authService.ResendVerificationEmail(email)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, services.ErrInvalidCredentials):
|
||||
SendErrorResponse(w, "No account found with this email address", http.StatusNotFound)
|
||||
case errors.Is(err, services.ErrInvalidEmail):
|
||||
SendErrorResponse(w, "Invalid email address format", http.StatusBadRequest)
|
||||
case errors.Is(err, services.ErrEmailSenderUnavailable):
|
||||
SendErrorResponse(w, "We couldn't send the verification email. Try again later.", http.StatusServiceUnavailable)
|
||||
case err.Error() == "email already verified":
|
||||
SendErrorResponse(w, "This email address is already verified", http.StatusConflict)
|
||||
case err.Error() == "verification email sent recently, please wait before requesting another":
|
||||
SendErrorResponse(w, "Please wait 5 minutes before requesting another verification email", http.StatusTooManyRequests)
|
||||
default:
|
||||
SendErrorResponse(w, "Unable to resend verification email", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Verification email sent successfully", map[string]any{
|
||||
"message": "Check your inbox for the verification link",
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get current user profile
|
||||
// @Description Retrieve the authenticated user's profile information
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} AuthResponse "User profile retrieved successfully"
|
||||
// @Failure 401 {object} AuthResponse "Authentication required"
|
||||
// @Failure 404 {object} AuthResponse "User not found"
|
||||
// @Router /auth/me [get]
|
||||
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "User not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
userDTO := dto.ToUserDTO(user)
|
||||
SendSuccessResponse(w, "User profile fetched", userDTO)
|
||||
}
|
||||
|
||||
// @Summary Request a password reset
|
||||
// @Description Send a password reset email using a username or email
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body ForgotPasswordRequest true "Username or email"
|
||||
// @Success 200 {object} AuthResponse "Password reset email sent if account exists"
|
||||
// @Failure 400 {object} AuthResponse "Invalid request data"
|
||||
// @Router /auth/forgot-password [post]
|
||||
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UsernameOrEmail string `json:"username_or_email"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
|
||||
if usernameOrEmail == "" {
|
||||
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "If an account with that username or email exists, we've sent a password reset link.", nil)
|
||||
}
|
||||
|
||||
// @Summary Reset password
|
||||
// @Description Reset a user's password using a reset token
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body ResetPasswordRequest true "Password reset data"
|
||||
// @Success 200 {object} AuthResponse "Password reset successfully"
|
||||
// @Failure 400 {object} AuthResponse "Invalid or expired token, or validation failed"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/reset-password [post]
|
||||
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(req.Token)
|
||||
newPassword := strings.TrimSpace(req.NewPassword)
|
||||
|
||||
if token == "" {
|
||||
SendErrorResponse(w, "Reset token is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if newPassword == "" {
|
||||
SendErrorResponse(w, "New password is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newPassword) < 8 {
|
||||
SendErrorResponse(w, "Password must be at least 8 characters long", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.ResetPassword(token, newPassword); err != nil {
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "expired"):
|
||||
SendErrorResponse(w, "The reset link has expired. Please request a new one.", http.StatusBadRequest)
|
||||
case strings.Contains(err.Error(), "invalid"):
|
||||
SendErrorResponse(w, "The reset link is invalid. Please request a new one.", http.StatusBadRequest)
|
||||
default:
|
||||
SendErrorResponse(w, "Unable to reset password. Please try again later.", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Password reset successfully. You can now sign in with your new password.", nil)
|
||||
}
|
||||
|
||||
// @Summary Update email address
|
||||
// @Description Update the authenticated user's email address
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body UpdateEmailRequest true "New email address"
|
||||
// @Success 200 {object} AuthResponse
|
||||
// @Failure 400 {object} AuthResponse
|
||||
// @Failure 401 {object} AuthResponse
|
||||
// @Failure 409 {object} AuthResponse
|
||||
// @Failure 503 {object} AuthResponse
|
||||
// @Failure 500 {object} AuthResponse
|
||||
// @Router /auth/email [put]
|
||||
func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(req.Email)
|
||||
if err := validation.ValidateEmail(email); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.UpdateEmail(userID, email)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, services.ErrEmailTaken):
|
||||
SendErrorResponse(w, "That email is already in use. Choose another one.", http.StatusConflict)
|
||||
case errors.Is(err, services.ErrEmailSenderUnavailable):
|
||||
SendErrorResponse(w, "We couldn't send the confirmation email. Try again later.", http.StatusServiceUnavailable)
|
||||
case errors.Is(err, services.ErrInvalidEmail):
|
||||
SendErrorResponse(w, "Invalid email address", http.StatusBadRequest)
|
||||
default:
|
||||
SendErrorResponse(w, "We couldn't update your email right now.", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
userDTO := dto.ToUserDTO(user)
|
||||
SendSuccessResponse(w, "Email updated. Check your inbox to confirm the new address.", map[string]any{
|
||||
"user": userDTO,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Update username
|
||||
// @Description Update the authenticated user's username
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body UpdateUsernameRequest true "New username"
|
||||
// @Success 200 {object} AuthResponse
|
||||
// @Failure 400 {object} AuthResponse
|
||||
// @Failure 401 {object} AuthResponse
|
||||
// @Failure 409 {object} AuthResponse
|
||||
// @Failure 500 {object} AuthResponse
|
||||
// @Router /auth/username [put]
|
||||
func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
username := strings.TrimSpace(req.Username)
|
||||
if err := validation.ValidateUsername(username); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.UpdateUsername(userID, username)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, services.ErrUsernameTaken):
|
||||
SendErrorResponse(w, "That username is already taken. Try another one.", http.StatusConflict)
|
||||
default:
|
||||
SendErrorResponse(w, "We couldn't update your username right now.", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
userDTO := dto.ToUserDTO(user)
|
||||
SendSuccessResponse(w, "Username updated successfully.", map[string]any{
|
||||
"user": userDTO,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Update password
|
||||
// @Description Update the authenticated user's password
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body UpdatePasswordRequest true "Password update data"
|
||||
// @Success 200 {object} AuthResponse
|
||||
// @Failure 400 {object} AuthResponse
|
||||
// @Failure 401 {object} AuthResponse
|
||||
// @Failure 500 {object} AuthResponse
|
||||
// @Router /auth/password [put]
|
||||
func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
currentPassword := strings.TrimSpace(req.CurrentPassword)
|
||||
newPassword := strings.TrimSpace(req.NewPassword)
|
||||
|
||||
if currentPassword == "" {
|
||||
SendErrorResponse(w, "Current password is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidatePassword(newPassword); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.UpdatePassword(userID, currentPassword, newPassword)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "current password is incorrect") {
|
||||
SendErrorResponse(w, "Current password is incorrect", http.StatusBadRequest)
|
||||
} else {
|
||||
SendErrorResponse(w, "We couldn't update your password right now.", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
userDTO := dto.ToUserDTO(user)
|
||||
SendSuccessResponse(w, "Password updated successfully.", map[string]any{
|
||||
"user": userDTO,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Request account deletion
|
||||
// @Description Initiate the deletion process for the authenticated user's account
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} AuthResponse "Deletion email sent"
|
||||
// @Failure 401 {object} AuthResponse "Authentication required"
|
||||
// @Failure 503 {object} AuthResponse "Email delivery unavailable"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/account [delete]
|
||||
func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.authService.RequestAccountDeletion(userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, services.ErrEmailSenderUnavailable) {
|
||||
SendErrorResponse(w, "Account deletion isn't available right now because email delivery is disabled.", http.StatusServiceUnavailable)
|
||||
} else {
|
||||
SendErrorResponse(w, "We couldn't start the deletion process right now.", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Check your inbox for a confirmation link to finish deleting your account.", nil)
|
||||
}
|
||||
|
||||
// @Summary Confirm account deletion
|
||||
// @Description Confirm account deletion using the provided token
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body ConfirmAccountDeletionRequest true "Account deletion data"
|
||||
// @Success 200 {object} AuthResponse "Account deleted successfully"
|
||||
// @Failure 400 {object} AuthResponse "Invalid or expired token"
|
||||
// @Failure 503 {object} AuthResponse "Email delivery unavailable"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/account/confirm [post]
|
||||
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Token string `json:"token"`
|
||||
DeletePosts bool `json:"delete_posts"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(req.Token)
|
||||
if token == "" {
|
||||
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, services.ErrInvalidDeletionToken):
|
||||
SendErrorResponse(w, "This deletion link is invalid or has expired.", http.StatusBadRequest)
|
||||
case errors.Is(err, services.ErrEmailSenderUnavailable):
|
||||
SendErrorResponse(w, "Account deletion isn't available right now because email delivery is disabled.", http.StatusServiceUnavailable)
|
||||
case errors.Is(err, services.ErrDeletionEmailFailed):
|
||||
SendSuccessResponse(w, "Your account has been deleted, but we couldn't send the confirmation email.", map[string]any{
|
||||
"posts_deleted": req.DeletePosts,
|
||||
})
|
||||
default:
|
||||
SendErrorResponse(w, "We couldn't confirm the deletion right now.", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Your account has been deleted.", map[string]any{
|
||||
"posts_deleted": req.DeletePosts,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Logout user
|
||||
// @Description Logout the authenticated user and invalidate their session
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} AuthResponse "Logged out successfully"
|
||||
// @Failure 401 {object} AuthResponse "Authentication required"
|
||||
// @Router /auth/logout [post]
|
||||
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
SendSuccessResponse(w, "Logged out successfully", nil)
|
||||
}
|
||||
|
||||
// @Summary Refresh access token
|
||||
// @Description Use a refresh token to get a new access token. This endpoint allows clients to obtain a new access token using a valid refresh token without requiring user credentials.
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body RefreshTokenRequest true "Refresh token data"
|
||||
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
|
||||
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
|
||||
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
|
||||
// @Failure 403 {object} AuthResponse "Account is locked"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/refresh [post]
|
||||
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||
var req RefreshTokenRequest
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.RefreshToken) == "" {
|
||||
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.RefreshAccessToken(req.RefreshToken)
|
||||
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Token refreshed successfully", result)
|
||||
}
|
||||
|
||||
// @Summary Revoke refresh token
|
||||
// @Description Revoke a specific refresh token. This endpoint allows authenticated users to invalidate a specific refresh token, preventing its future use.
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body RevokeTokenRequest true "Token revocation data"
|
||||
// @Success 200 {object} AuthResponse "Token revoked successfully"
|
||||
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
|
||||
// @Failure 401 {object} AuthResponse "Invalid or expired access token"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/revoke [post]
|
||||
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
var req RevokeTokenRequest
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.RefreshToken) == "" {
|
||||
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.authService.RevokeRefreshToken(req.RefreshToken)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Token revoked successfully", nil)
|
||||
}
|
||||
|
||||
// @Summary Revoke all user tokens
|
||||
// @Description Revoke all refresh tokens for the authenticated user. This endpoint allows users to invalidate all their refresh tokens at once, effectively logging them out from all devices.
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} AuthResponse "All tokens revoked successfully"
|
||||
// @Failure 401 {object} AuthResponse "Invalid or expired access token"
|
||||
// @Failure 500 {object} AuthResponse "Internal server error"
|
||||
// @Router /auth/revoke-all [post]
|
||||
func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.authService.RevokeAllUserTokens(userID)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to revoke tokens", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "All tokens revoked successfully", nil)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||
if config.GeneralRateLimit != nil {
|
||||
rateLimited := config.GeneralRateLimit(r)
|
||||
rateLimited.Post("/auth/refresh", h.RefreshToken)
|
||||
rateLimited.Get("/auth/confirm", h.ConfirmEmail)
|
||||
rateLimited.Post("/auth/resend-verification", h.ResendVerificationEmail)
|
||||
} else {
|
||||
r.Post("/auth/refresh", h.RefreshToken)
|
||||
r.Get("/auth/confirm", h.ConfirmEmail)
|
||||
r.Post("/auth/resend-verification", h.ResendVerificationEmail)
|
||||
}
|
||||
|
||||
if config.AuthRateLimit != nil {
|
||||
rateLimited := config.AuthRateLimit(r)
|
||||
rateLimited.Post("/auth/register", h.Register)
|
||||
rateLimited.Post("/auth/login", h.Login)
|
||||
rateLimited.Post("/auth/forgot-password", h.RequestPasswordReset)
|
||||
rateLimited.Post("/auth/reset-password", h.ResetPassword)
|
||||
rateLimited.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
|
||||
} else {
|
||||
r.Post("/auth/register", h.Register)
|
||||
r.Post("/auth/login", h.Login)
|
||||
r.Post("/auth/forgot-password", h.RequestPasswordReset)
|
||||
r.Post("/auth/reset-password", h.ResetPassword)
|
||||
r.Post("/auth/account/confirm", h.ConfirmAccountDeletion)
|
||||
}
|
||||
|
||||
protected := r
|
||||
if config.AuthMiddleware != nil {
|
||||
protected = r.With(config.AuthMiddleware)
|
||||
}
|
||||
if config.GeneralRateLimit != nil {
|
||||
protected = config.GeneralRateLimit(protected)
|
||||
}
|
||||
|
||||
protected.Get("/auth/me", h.Me)
|
||||
protected.Post("/auth/logout", h.Logout)
|
||||
protected.Post("/auth/revoke", h.RevokeToken)
|
||||
protected.Post("/auth/revoke-all", h.RevokeAllTokens)
|
||||
protected.Put("/auth/email", h.UpdateEmail)
|
||||
protected.Put("/auth/username", h.UpdateUsername)
|
||||
protected.Put("/auth/password", h.UpdatePassword)
|
||||
protected.Delete("/auth/account", h.DeleteAccount)
|
||||
}
|
||||
1584
internal/handlers/auth_handler_test.go
Normal file
1584
internal/handlers/auth_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
292
internal/handlers/common.go
Normal file
292
internal/handlers/common.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/dto"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/services"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type CommonResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type PaginationData struct {
|
||||
Count int `json:"count"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type VoteCookieData struct {
|
||||
Type database.VoteType `json:"type"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
func sendResponse(w http.ResponseWriter, statusCode int, success bool, message string, data any, errMsg string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
response := CommonResponse{
|
||||
Success: success,
|
||||
Message: message,
|
||||
Data: data,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func SendSuccessResponse(w http.ResponseWriter, message string, data any) {
|
||||
sendResponse(w, http.StatusOK, true, message, data, "")
|
||||
}
|
||||
|
||||
func SendCreatedResponse(w http.ResponseWriter, message string, data any) {
|
||||
sendResponse(w, http.StatusCreated, true, message, data, "")
|
||||
}
|
||||
|
||||
func SendErrorResponse(w http.ResponseWriter, message string, statusCode int) {
|
||||
sendResponse(w, statusCode, false, "", nil, message)
|
||||
}
|
||||
|
||||
func DecodeJSONRequest(w http.ResponseWriter, r *http.Request, req any) bool {
|
||||
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
|
||||
SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func GetClientIP(r *http.Request) string {
|
||||
return middleware.GetSecureClientIP(r)
|
||||
}
|
||||
|
||||
const (
|
||||
CookieMaxAgeDays = 30
|
||||
SecondsPerDay = 86400
|
||||
DefaultPaginationLimit = 20
|
||||
DefaultPaginationOffset = 0
|
||||
)
|
||||
|
||||
func SetVoteCookie(w http.ResponseWriter, r *http.Request, postID uint, voteType database.VoteType) {
|
||||
cookieName := fmt.Sprintf("vote_%d", postID)
|
||||
cookieValue := fmt.Sprintf("%s:%d", voteType, time.Now().Unix())
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: cookieValue,
|
||||
Path: "/",
|
||||
MaxAge: SecondsPerDay * CookieMaxAgeDays,
|
||||
HttpOnly: true,
|
||||
Secure: IsHTTPS(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
func GetVoteCookie(r *http.Request, postID uint) string {
|
||||
cookieName := fmt.Sprintf("vote_%d", postID)
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
func ClearVoteCookie(w http.ResponseWriter, postID uint) {
|
||||
cookieName := fmt.Sprintf("vote_%d", postID)
|
||||
cookie := &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
func IsHTTPS(r *http.Request) bool {
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto == "https" {
|
||||
return true
|
||||
}
|
||||
|
||||
if proto := r.Header.Get("X-Forwarded-Ssl"); proto == "on" {
|
||||
return true
|
||||
}
|
||||
|
||||
if proto := r.Header.Get("X-Forwarded-Scheme"); proto == "https" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func SanitizeUser(user *database.User) dto.SanitizedUserDTO {
|
||||
if user == nil {
|
||||
return dto.SanitizedUserDTO{}
|
||||
}
|
||||
return dto.ToSanitizedUserDTO(user)
|
||||
}
|
||||
|
||||
func SanitizeUsers(users []database.User) []dto.SanitizedUserDTO {
|
||||
return dto.ToSanitizedUserDTOs(users)
|
||||
}
|
||||
|
||||
func parsePagination(r *http.Request) (limit, offset int) {
|
||||
limit = DefaultPaginationLimit
|
||||
offset = DefaultPaginationOffset
|
||||
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
if limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
offsetStr := r.URL.Query().Get("offset")
|
||||
if offsetStr != "" {
|
||||
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||
offset = o
|
||||
}
|
||||
}
|
||||
|
||||
return limit, offset
|
||||
}
|
||||
|
||||
func ValidateRedirectURL(redirectURL string) string {
|
||||
redirectURL = strings.TrimSpace(redirectURL)
|
||||
if redirectURL == "" || len(redirectURL) > 512 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(redirectURL, "/") || strings.HasPrefix(redirectURL, "//") {
|
||||
return ""
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(redirectURL)
|
||||
if err != nil || parsed.Scheme != "" || parsed.Host != "" || parsed.User != nil || parsed.Path == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
path := parsed.EscapedPath()
|
||||
if path == "" {
|
||||
path = parsed.Path
|
||||
}
|
||||
|
||||
validated := path
|
||||
if parsed.RawQuery != "" {
|
||||
validated += "?" + parsed.RawQuery
|
||||
}
|
||||
if parsed.Fragment != "" {
|
||||
validated += "#" + parsed.Fragment
|
||||
}
|
||||
|
||||
return validated
|
||||
}
|
||||
|
||||
func ParseUintParam(w http.ResponseWriter, r *http.Request, paramName, entityName string) (uint, bool) {
|
||||
str := chi.URLParam(r, paramName)
|
||||
if str == "" {
|
||||
SendErrorResponse(w, entityName+" ID is required", http.StatusBadRequest)
|
||||
return 0, false
|
||||
}
|
||||
id, err := strconv.ParseUint(str, 10, 32)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Invalid "+entityName+" ID", http.StatusBadRequest)
|
||||
return 0, false
|
||||
}
|
||||
return uint(id), true
|
||||
}
|
||||
|
||||
func RequireAuth(w http.ResponseWriter, r *http.Request) (uint, bool) {
|
||||
userID := middleware.GetUserIDFromContext(r.Context())
|
||||
if userID == 0 {
|
||||
SendErrorResponse(w, "Authentication required", http.StatusUnauthorized)
|
||||
return 0, false
|
||||
}
|
||||
return userID, true
|
||||
}
|
||||
|
||||
func NewVoteContext(r *http.Request) services.VoteContext {
|
||||
return services.VoteContext{
|
||||
UserID: middleware.GetUserIDFromContext(r.Context()),
|
||||
IPAddress: GetClientIP(r),
|
||||
UserAgent: r.UserAgent(),
|
||||
}
|
||||
}
|
||||
|
||||
func HandleRepoError(w http.ResponseWriter, err error, entityName string) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
SendErrorResponse(w, entityName+" not found", http.StatusNotFound)
|
||||
} else {
|
||||
SendErrorResponse(w, "Failed to retrieve "+entityName, http.StatusInternalServerError)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var AuthErrorMapping = []struct {
|
||||
err error
|
||||
msg string
|
||||
code int
|
||||
}{
|
||||
{services.ErrInvalidCredentials, "Invalid username or password", http.StatusUnauthorized},
|
||||
{services.ErrEmailNotVerified, "Please confirm your email before logging in", http.StatusForbidden},
|
||||
{services.ErrAccountLocked, "Your account has been locked. Please contact us for assistance.", http.StatusForbidden},
|
||||
{services.ErrUsernameTaken, "Username is already taken", http.StatusConflict},
|
||||
{services.ErrEmailTaken, "Email is already registered", http.StatusConflict},
|
||||
{services.ErrInvalidEmail, "Invalid email address", http.StatusBadRequest},
|
||||
{services.ErrPasswordTooShort, "Password must be at least 8 characters", http.StatusBadRequest},
|
||||
{services.ErrInvalidVerificationToken, "Invalid or expired verification token", http.StatusBadRequest},
|
||||
{services.ErrRefreshTokenExpired, "Refresh token has expired", http.StatusUnauthorized},
|
||||
{services.ErrRefreshTokenInvalid, "Invalid refresh token", http.StatusUnauthorized},
|
||||
{services.ErrInvalidDeletionToken, "This deletion link is invalid or has expired.", http.StatusBadRequest},
|
||||
{services.ErrDeletionRequestNotFound, "Deletion request not found", http.StatusBadRequest},
|
||||
{services.ErrUserNotFound, "User not found", http.StatusNotFound},
|
||||
{services.ErrEmailSenderUnavailable, "Email service is unavailable. Please try again later.", http.StatusServiceUnavailable},
|
||||
}
|
||||
|
||||
func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, defaultCode int) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, mapping := range AuthErrorMapping {
|
||||
if err == mapping.err || errors.Is(err, mapping.err) {
|
||||
SendErrorResponse(w, mapping.msg, mapping.code)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
for _, mapping := range AuthErrorMapping {
|
||||
if mapping.err.Error() == errMsg {
|
||||
SendErrorResponse(w, mapping.msg, mapping.code)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
SendErrorResponse(w, defaultMsg, defaultCode)
|
||||
return false
|
||||
}
|
||||
1158
internal/handlers/common_test.go
Normal file
1158
internal/handlers/common_test.go
Normal file
File diff suppressed because it is too large
Load Diff
146
internal/handlers/fuzz_test.go
Normal file
146
internal/handlers/fuzz_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"goyco/internal/fuzz"
|
||||
)
|
||||
|
||||
func FuzzJSONParsing(f *testing.F) {
|
||||
helper := fuzz.NewFuzzTestHelper()
|
||||
testCases := []map[string]any{
|
||||
{
|
||||
"name": "auth_login",
|
||||
"body": `{"username":"FUZZED_INPUT","password":"test"}`,
|
||||
},
|
||||
{
|
||||
"name": "auth_register",
|
||||
"body": `{"username":"FUZZED_INPUT","email":"test@example.com","password":"test123"}`,
|
||||
},
|
||||
{
|
||||
"name": "post_create",
|
||||
"body": `{"title":"FUZZED_INPUT","url":"https://example.com","content":"test"}`,
|
||||
},
|
||||
{
|
||||
"name": "vote_cast",
|
||||
"body": `{"type":"FUZZED_INPUT"}`,
|
||||
},
|
||||
}
|
||||
helper.RunJSONFuzzTest(f, testCases)
|
||||
}
|
||||
|
||||
func FuzzURLParsing(f *testing.F) {
|
||||
helper := fuzz.NewFuzzTestHelper()
|
||||
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
|
||||
sanitized := ""
|
||||
for _, char := range input {
|
||||
|
||||
if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
|
||||
(char >= '0' && char <= '9') || char == '-' || char == '_' {
|
||||
sanitized += string(char)
|
||||
}
|
||||
}
|
||||
|
||||
if len(sanitized) > 20 {
|
||||
sanitized = sanitized[:20]
|
||||
}
|
||||
|
||||
if len(sanitized) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
url := "/api/posts/" + sanitized
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
|
||||
pathParts := strings.Split(req.URL.Path, "/")
|
||||
if len(pathParts) >= 4 {
|
||||
idStr := pathParts[3]
|
||||
_ = idStr
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzQueryParameters(f *testing.F) {
|
||||
helper := fuzz.NewFuzzTestHelper()
|
||||
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
|
||||
if !utf8.ValidString(input) {
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := ""
|
||||
for _, char := range input {
|
||||
|
||||
if char >= 32 && char <= 126 {
|
||||
switch char {
|
||||
case ' ', '\n', '\r', '\t':
|
||||
|
||||
continue
|
||||
case '&':
|
||||
sanitized += "%26"
|
||||
case '=':
|
||||
sanitized += "%3D"
|
||||
case '?':
|
||||
sanitized += "%3F"
|
||||
case '#':
|
||||
sanitized += "%23"
|
||||
case '/':
|
||||
sanitized += "%2F"
|
||||
case '\\':
|
||||
sanitized += "%5C"
|
||||
default:
|
||||
sanitized += string(char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(sanitized) > 100 {
|
||||
sanitized = sanitized[:100]
|
||||
}
|
||||
|
||||
if len(sanitized) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
query := "?q=" + sanitized + "&limit=10&offset=0"
|
||||
req := httptest.NewRequest("GET", "/api/posts/search"+query, nil)
|
||||
|
||||
q := req.URL.Query().Get("q")
|
||||
limit := req.URL.Query().Get("limit")
|
||||
offset := req.URL.Query().Get("offset")
|
||||
|
||||
if !utf8.ValidString(q) {
|
||||
|
||||
return
|
||||
}
|
||||
_ = limit
|
||||
_ = offset
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzHTTPHeaders(f *testing.F) {
|
||||
helper := fuzz.NewFuzzTestHelper()
|
||||
helper.RunBasicFuzzTest(f, func(t *testing.T, input string) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+input)
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("User-Agent", input)
|
||||
req.Header.Set("X-Forwarded-For", input)
|
||||
|
||||
for name, values := range req.Header {
|
||||
if !utf8.ValidString(name) {
|
||||
t.Fatal("Header name contains invalid UTF-8")
|
||||
}
|
||||
for _, value := range values {
|
||||
if !utf8.ValidString(value) {
|
||||
t.Fatal("Header value contains invalid UTF-8")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
1626
internal/handlers/page_handler.go
Normal file
1626
internal/handlers/page_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
464
internal/handlers/post_handler.go
Normal file
464
internal/handlers/post_handler.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/dto"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/security"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/validation"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgconn"
|
||||
)
|
||||
|
||||
type PostHandler struct {
|
||||
postRepo repositories.PostRepository
|
||||
titleFetcher services.TitleFetcher
|
||||
voteService *services.VoteService
|
||||
postQueries *services.PostQueries
|
||||
}
|
||||
|
||||
func NewPostHandler(postRepo repositories.PostRepository, titleFetcher services.TitleFetcher, voteService *services.VoteService) *PostHandler {
|
||||
return &PostHandler{
|
||||
postRepo: postRepo,
|
||||
titleFetcher: titleFetcher,
|
||||
voteService: voteService,
|
||||
postQueries: services.NewPostQueries(postRepo, voteService),
|
||||
}
|
||||
}
|
||||
|
||||
type PostResponse = CommonResponse
|
||||
|
||||
type UpdatePostRequest struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// @Summary Get posts
|
||||
// @Description Get a list of posts with pagination. Posts include vote statistics (up_votes, down_votes, score) and current user's vote status.
|
||||
// @Tags posts
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param limit query int false "Number of posts to return" default(20)
|
||||
// @Param offset query int false "Number of posts to skip" default(0)
|
||||
// @Success 200 {object} PostResponse "Posts retrieved successfully with vote statistics"
|
||||
// @Failure 400 {object} PostResponse "Invalid pagination parameters"
|
||||
// @Failure 500 {object} PostResponse "Internal server error"
|
||||
// @Router /posts [get]
|
||||
func (h *PostHandler) GetPosts(w http.ResponseWriter, r *http.Request) {
|
||||
limit, offset := parsePagination(r)
|
||||
|
||||
opts := services.QueryOptions{
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
ctx := NewVoteContext(r)
|
||||
|
||||
posts, err := h.postQueries.GetAll(opts, ctx)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to fetch posts", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
postDTOs := dto.ToPostDTOs(posts)
|
||||
SendSuccessResponse(w, "Posts retrieved successfully", map[string]any{
|
||||
"posts": postDTOs,
|
||||
"count": len(postDTOs),
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get a single post
|
||||
// @Description Get a post by ID with vote statistics and current user's vote status
|
||||
// @Tags posts
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path int true "Post ID"
|
||||
// @Success 200 {object} PostResponse "Post retrieved successfully with vote statistics"
|
||||
// @Failure 400 {object} PostResponse "Invalid post ID"
|
||||
// @Failure 404 {object} PostResponse "Post not found"
|
||||
// @Failure 500 {object} PostResponse "Internal server error"
|
||||
// @Router /posts/{id} [get]
|
||||
func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
|
||||
postID, ok := ParseUintParam(w, r, "id", "Post")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := NewVoteContext(r)
|
||||
|
||||
post, err := h.postQueries.GetByID(postID, ctx)
|
||||
if !HandleRepoError(w, err, "Post") {
|
||||
return
|
||||
}
|
||||
|
||||
postDTO := dto.ToPostDTO(post)
|
||||
SendSuccessResponse(w, "Post retrieved successfully", postDTO)
|
||||
}
|
||||
|
||||
// @Summary Create a new post
|
||||
// @Description Create a new post with URL and optional title
|
||||
// @Tags posts
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body CreatePostRequest true "Post data"
|
||||
// @Success 201 {object} PostResponse
|
||||
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
|
||||
// @Failure 401 {object} PostResponse "Authentication required"
|
||||
// @Failure 409 {object} PostResponse "URL already submitted"
|
||||
// @Failure 502 {object} PostResponse "Failed to fetch title from URL"
|
||||
// @Failure 500 {object} PostResponse "Internal server error"
|
||||
// @Router /posts [post]
|
||||
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Title = security.SanitizeInput(req.Title)
|
||||
req.URL = security.SanitizeURL(req.URL)
|
||||
req.Content = security.SanitizePostContent(req.Content)
|
||||
|
||||
if req.URL == "" {
|
||||
SendErrorResponse(w, "URL is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Title) > 200 {
|
||||
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Content) > 10000 {
|
||||
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
title := req.Title
|
||||
|
||||
if title == "" && h.titleFetcher != nil {
|
||||
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, req.URL)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, services.ErrUnsupportedScheme):
|
||||
SendErrorResponse(w, "Only HTTP and HTTPS URLs are supported", http.StatusBadRequest)
|
||||
case errors.Is(err, services.ErrTitleNotFound):
|
||||
SendErrorResponse(w, "Title could not be extracted from the provided URL", http.StatusBadRequest)
|
||||
default:
|
||||
SendErrorResponse(w, "Failed to fetch title from URL", http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
title = fetchedTitle
|
||||
}
|
||||
|
||||
if title == "" {
|
||||
SendErrorResponse(w, "Title is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(title) < 3 {
|
||||
SendErrorResponse(w, "Title must be at least 3 characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
post := &database.Post{
|
||||
Title: title,
|
||||
URL: req.URL,
|
||||
Content: req.Content,
|
||||
AuthorID: &userID,
|
||||
}
|
||||
|
||||
if err := h.postRepo.Create(post); err != nil {
|
||||
if errMsg, status := translatePostCreateError(err); status != 0 {
|
||||
SendErrorResponse(w, errMsg, status)
|
||||
return
|
||||
}
|
||||
|
||||
SendErrorResponse(w, "Failed to create post", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
postDTO := dto.ToPostDTO(post)
|
||||
SendCreatedResponse(w, "Post created successfully", postDTO)
|
||||
}
|
||||
|
||||
// @Summary Search posts
|
||||
// @Description Search posts by title or content keywords. Results include vote statistics and current user's vote status.
|
||||
// @Tags posts
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param q query string false "Search term"
|
||||
// @Param limit query int false "Number of posts to return" default(20)
|
||||
// @Param offset query int false "Number of posts to skip" default(0)
|
||||
// @Success 200 {object} PostResponse "Search results with vote statistics"
|
||||
// @Failure 400 {object} PostResponse "Invalid search parameters"
|
||||
// @Failure 500 {object} PostResponse "Internal server error"
|
||||
// @Router /posts/search [get]
|
||||
func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
|
||||
query := strings.TrimSpace(r.URL.Query().Get("q"))
|
||||
limit, offset := parsePagination(r)
|
||||
|
||||
opts := services.QueryOptions{
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
ctx := NewVoteContext(r)
|
||||
|
||||
posts, err := h.postQueries.GetSearch(query, opts, ctx)
|
||||
if err != nil {
|
||||
if searchErr, ok := err.(*repositories.SearchError); ok {
|
||||
SendErrorResponse(w, "Invalid search query: "+searchErr.Message, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
SendErrorResponse(w, "Failed to search posts", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
postDTOs := dto.ToPostDTOs(posts)
|
||||
SendSuccessResponse(w, "Search results retrieved successfully", map[string]any{
|
||||
"posts": postDTOs,
|
||||
"count": len(postDTOs),
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Update a post
|
||||
// @Description Update the title and content of a post owned by the authenticated user
|
||||
// @Tags posts
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "Post ID"
|
||||
// @Param request body UpdatePostRequest true "Post update data"
|
||||
// @Success 200 {object} PostResponse "Post updated successfully"
|
||||
// @Failure 400 {object} PostResponse "Invalid request data or validation failed"
|
||||
// @Failure 401 {object} PostResponse "Authentication required"
|
||||
// @Failure 403 {object} PostResponse "Not authorized to update this post"
|
||||
// @Failure 404 {object} PostResponse "Post not found"
|
||||
// @Failure 500 {object} PostResponse "Internal server error"
|
||||
// @Router /posts/{id} [put]
|
||||
func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
postID, ok := ParseUintParam(w, r, "id", "Post")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
post, err := h.postRepo.GetByID(postID)
|
||||
if !HandleRepoError(w, err, "Post") {
|
||||
return
|
||||
}
|
||||
|
||||
if post.AuthorID == nil || *post.AuthorID != userID {
|
||||
SendErrorResponse(w, "You can only edit your own posts", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Title = security.SanitizeInput(req.Title)
|
||||
req.Content = security.SanitizePostContent(req.Content)
|
||||
|
||||
if len(req.Title) > 200 {
|
||||
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Content) > 10000 {
|
||||
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidateTitle(req.Title); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidateContent(req.Content); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
post.Title = req.Title
|
||||
post.Content = req.Content
|
||||
|
||||
if err := h.postRepo.Update(post); err != nil {
|
||||
SendErrorResponse(w, "Failed to update post", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
postDTO := dto.ToPostDTO(post)
|
||||
SendSuccessResponse(w, "Post updated successfully", postDTO)
|
||||
}
|
||||
|
||||
// @Summary Delete a post
|
||||
// @Description Delete a post owned by the authenticated user
|
||||
// @Tags posts
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "Post ID"
|
||||
// @Success 200 {object} PostResponse "Post deleted successfully"
|
||||
// @Failure 400 {object} PostResponse "Invalid post ID"
|
||||
// @Failure 401 {object} PostResponse "Authentication required"
|
||||
// @Failure 403 {object} PostResponse "Not authorized to delete this post"
|
||||
// @Failure 404 {object} PostResponse "Post not found"
|
||||
// @Failure 500 {object} PostResponse "Internal server error"
|
||||
// @Router /posts/{id} [delete]
|
||||
func (h *PostHandler) DeletePost(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
postID, ok := ParseUintParam(w, r, "id", "Post")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
post, err := h.postRepo.GetByID(postID)
|
||||
if !HandleRepoError(w, err, "Post") {
|
||||
return
|
||||
}
|
||||
|
||||
if post.AuthorID == nil || *post.AuthorID != userID {
|
||||
SendErrorResponse(w, "You can only delete your own posts", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voteService.DeleteVotesByPostID(postID); err != nil {
|
||||
SendErrorResponse(w, "Failed to delete post votes", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.postRepo.Delete(postID); err != nil {
|
||||
SendErrorResponse(w, "Failed to delete post", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Post deleted successfully", nil)
|
||||
}
|
||||
|
||||
// @Summary Fetch title from URL
|
||||
// @Description Fetch the HTML title for the provided URL
|
||||
// @Tags posts
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param url query string true "URL to inspect"
|
||||
// @Success 200 {object} PostResponse "Title fetched successfully"
|
||||
// @Failure 400 {object} PostResponse "Invalid URL or URL parameter missing"
|
||||
// @Failure 501 {object} PostResponse "Title fetching is not available"
|
||||
// @Failure 502 {object} PostResponse "Failed to fetch title from URL"
|
||||
// @Router /posts/title [get]
|
||||
func (h *PostHandler) FetchTitleFromURL(w http.ResponseWriter, r *http.Request) {
|
||||
if h.titleFetcher == nil {
|
||||
SendErrorResponse(w, "Title fetching is not available", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
requestedURL := strings.TrimSpace(r.URL.Query().Get("url"))
|
||||
if requestedURL == "" {
|
||||
SendErrorResponse(w, "URL query parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
|
||||
defer cancel()
|
||||
|
||||
title, err := h.titleFetcher.FetchTitle(titleCtx, requestedURL)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, services.ErrUnsupportedScheme):
|
||||
SendErrorResponse(w, "Only HTTP and HTTPS URLs are supported", http.StatusBadRequest)
|
||||
case errors.Is(err, services.ErrTitleNotFound):
|
||||
SendErrorResponse(w, "Title could not be extracted from the provided URL", http.StatusBadRequest)
|
||||
default:
|
||||
SendErrorResponse(w, "Failed to fetch title from URL", http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Title fetched successfully", map[string]string{
|
||||
"title": title,
|
||||
})
|
||||
}
|
||||
|
||||
func translatePostCreateError(err error) (string, int) {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
switch pgErr.Code {
|
||||
case "23505":
|
||||
return "This URL has already been submitted.", http.StatusConflict
|
||||
case "23503":
|
||||
return "Author account not found. Please sign in again.", http.StatusUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "UNIQUE constraint") || strings.Contains(errStr, "duplicate") {
|
||||
return "This URL has already been submitted.", http.StatusConflict
|
||||
}
|
||||
|
||||
return "", 0
|
||||
}
|
||||
|
||||
func (h *PostHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||
public := r
|
||||
if config.GeneralRateLimit != nil {
|
||||
public = config.GeneralRateLimit(r)
|
||||
}
|
||||
public.Get("/posts", h.GetPosts)
|
||||
public.Get("/posts/search", h.SearchPosts)
|
||||
public.Get("/posts/title", h.FetchTitleFromURL)
|
||||
public.Get("/posts/{id}", h.GetPost)
|
||||
|
||||
protected := r
|
||||
if config.AuthMiddleware != nil {
|
||||
protected = r.With(config.AuthMiddleware)
|
||||
}
|
||||
if config.GeneralRateLimit != nil {
|
||||
protected = config.GeneralRateLimit(protected)
|
||||
}
|
||||
protected.Post("/posts", h.CreatePost)
|
||||
protected.Put("/posts/{id}", h.UpdatePost)
|
||||
protected.Delete("/posts/{id}", h.DeletePost)
|
||||
}
|
||||
711
internal/handlers/post_handler_test.go
Normal file
711
internal/handlers/post_handler_test.go
Normal file
@@ -0,0 +1,711 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgconn"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func decodeHandlerResponse(t *testing.T, rr *httptest.ResponseRecorder) map[string]any {
|
||||
t.Helper()
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func TestPostHandlerGetPostsWithVoteService(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
repo.GetAllFn = func(limit, offset int) ([]database.Post, error) {
|
||||
return []database.Post{
|
||||
{ID: 1, Title: "Test Post 1"},
|
||||
{ID: 2, Title: "Test Post 2"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := services.NewVoteService(voteRepo, repo, nil)
|
||||
handler := NewPostHandler(repo, nil, voteService)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts", nil)
|
||||
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.GetPosts(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
payload := decodeHandlerResponse(t, recorder)
|
||||
if !payload["success"].(bool) {
|
||||
t.Fatalf("expected success response, got %v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
var storedPost *database.Post
|
||||
repo.CreateFn = func(post *database.Post) error {
|
||||
storedPost = post
|
||||
return nil
|
||||
}
|
||||
|
||||
titleFetcher := &testutils.MockTitleFetcher{}
|
||||
titleFetcher.SetTitle("Fetched Title")
|
||||
|
||||
handler := NewPostHandler(repo, titleFetcher, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`))
|
||||
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.CreatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
|
||||
|
||||
if storedPost == nil {
|
||||
t.Fatal("expected post to be created")
|
||||
}
|
||||
|
||||
if storedPost.Title != "Fetched Title" {
|
||||
t.Errorf("expected title 'Fetched Title', got %s", storedPost.Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerCreatePostTitleFetcherError(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
titleFetcher := &testutils.MockTitleFetcher{}
|
||||
titleFetcher.SetError(services.ErrUnsupportedScheme)
|
||||
|
||||
handler := NewPostHandler(repo, titleFetcher, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"ftp://example.com"}`))
|
||||
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.CreatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
payload := decodeHandlerResponse(t, recorder)
|
||||
if payload["success"].(bool) {
|
||||
t.Fatalf("expected error response, got %v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerSearchPosts(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
repo.SearchFn = func(query string, limit, offset int) ([]database.Post, error) {
|
||||
return []database.Post{
|
||||
{ID: 1, Title: "Search Result 1"},
|
||||
{ID: 2, Title: "Search Result 2"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
handler := NewPostHandler(repo, nil, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts/search?q=test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.SearchPosts(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
payload := decodeHandlerResponse(t, recorder)
|
||||
if !payload["success"].(bool) {
|
||||
t.Fatalf("expected success response, got %v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerFetchTitleFromURL(t *testing.T) {
|
||||
titleFetcher := &testutils.MockTitleFetcher{}
|
||||
titleFetcher.SetTitle("Test Title")
|
||||
|
||||
handler := NewPostHandler(nil, titleFetcher, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.FetchTitleFromURL(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
payload := decodeHandlerResponse(t, recorder)
|
||||
if !payload["success"].(bool) {
|
||||
t.Fatalf("expected success response, got %v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerFetchTitleFromURLNoFetcher(t *testing.T) {
|
||||
handler := NewPostHandler(nil, nil, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.FetchTitleFromURL(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
return &database.Post{ID: id, AuthorID: func() *uint { u := uint(2); return &u }()}, nil
|
||||
}
|
||||
|
||||
handler := NewPostHandler(repo, nil, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`))
|
||||
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.UpdatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
|
||||
}
|
||||
|
||||
func TestPostHandlerDeletePostUnauthorized(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
return &database.Post{ID: id, AuthorID: func() *uint { u := uint(2); return &u }()}, nil
|
||||
}
|
||||
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService := services.NewVoteService(voteRepo, repo, nil)
|
||||
handler := NewPostHandler(repo, nil, voteService)
|
||||
|
||||
request := httptest.NewRequest(http.MethodDelete, "/api/posts/1", nil)
|
||||
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.DeletePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
|
||||
}
|
||||
|
||||
func TestPostHandlerGetPosts(t *testing.T) {
|
||||
var receivedLimit, receivedOffset int
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
repo.GetAllFn = func(limit, offset int) ([]database.Post, error) {
|
||||
receivedLimit = limit
|
||||
receivedOffset = offset
|
||||
return []database.Post{{ID: 1}}, nil
|
||||
}
|
||||
|
||||
handler := NewPostHandler(repo, nil, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts?limit=5&offset=2", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.GetPosts(recorder, request)
|
||||
|
||||
if receivedLimit != 5 || receivedOffset != 2 {
|
||||
t.Fatalf("expected limit=5 offset=2, got %d %d", receivedLimit, receivedOffset)
|
||||
}
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
payload := decodeHandlerResponse(t, recorder)
|
||||
if !payload["success"].(bool) {
|
||||
t.Fatalf("expected success response, got %v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerGetPostErrors(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
handler := NewPostHandler(repo, nil, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.GetPost(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for missing id, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/abc", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.GetPost(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid id, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
repo.GetByIDFn = func(uint) (*database.Post, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.GetPost(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func TestPostHandlerCreatePostSuccess(t *testing.T) {
|
||||
var storedPost *database.Post
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
repo.CreateFn = func(post *database.Post) error {
|
||||
storedPost = &database.Post{
|
||||
Title: post.Title,
|
||||
URL: post.URL,
|
||||
Content: post.Content,
|
||||
AuthorID: post.AuthorID,
|
||||
}
|
||||
storedPost.ID = 1
|
||||
return nil
|
||||
}
|
||||
fetcher := &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) {
|
||||
return "Fetched Title", nil
|
||||
}}
|
||||
|
||||
handler := NewPostHandler(repo, fetcher, nil)
|
||||
|
||||
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`)
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
|
||||
request = request.WithContext(ctx)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.CreatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
|
||||
|
||||
if storedPost == nil || storedPost.Title != "Fetched Title" || storedPost.AuthorID == nil || *storedPost.AuthorID != 42 {
|
||||
t.Fatalf("unexpected stored post: %#v", storedPost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerCreatePostValidation(t *testing.T) {
|
||||
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`))
|
||||
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
|
||||
handler.CreatePost(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for missing url, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`))
|
||||
handler.CreatePost(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`))
|
||||
handler.CreatePost(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatus int
|
||||
wantMsg string
|
||||
}{
|
||||
{name: "Unsupported", err: services.ErrUnsupportedScheme, wantStatus: http.StatusBadRequest, wantMsg: "Only HTTP and HTTPS URLs are supported"},
|
||||
{name: "TitleMissing", err: services.ErrTitleNotFound, wantStatus: http.StatusBadRequest, wantMsg: "Title could not be extracted"},
|
||||
{name: "Generic", err: errors.New("timeout"), wantStatus: http.StatusBadGateway, wantMsg: "Failed to fetch title"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
fetcher := &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) {
|
||||
return "", tc.err
|
||||
}}
|
||||
handler := NewPostHandler(repo, fetcher, nil)
|
||||
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`)
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
|
||||
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.CreatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, tc.wantStatus)
|
||||
|
||||
if !strings.Contains(recorder.Body.String(), tc.wantMsg) {
|
||||
t.Fatalf("expected message to contain %q, got %q", tc.wantMsg, recorder.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerFetchTitleFromURLErrors(t *testing.T) {
|
||||
handler := NewPostHandler(testutils.NewPostRepositoryStub(), nil, nil)
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.FetchTitleFromURL(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusNotImplemented {
|
||||
t.Fatalf("expected 501 when fetcher unavailable, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
handler = NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/title", nil)
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.FetchTitleFromURL(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for missing url query, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
handler = NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{FetchTitleFn: func(ctx context.Context, rawURL string) (string, error) {
|
||||
return "", errors.New("failed")
|
||||
}}, nil)
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/title?url=https://example.com", nil)
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.FetchTitleFromURL(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
func TestTranslatePostCreateError(t *testing.T) {
|
||||
conflictErr := &pgconn.PgError{Code: "23505"}
|
||||
msg, status := translatePostCreateError(conflictErr)
|
||||
if status != http.StatusConflict || !strings.Contains(msg, "already been submitted") {
|
||||
t.Fatalf("unexpected conflict translation: status=%d msg=%q", status, msg)
|
||||
}
|
||||
|
||||
fkErr := &pgconn.PgError{Code: "23503"}
|
||||
msg, status = translatePostCreateError(fkErr)
|
||||
if status != http.StatusUnauthorized || !strings.Contains(msg, "Author account not found") {
|
||||
t.Fatalf("unexpected foreign key translation: status=%d msg=%q", status, msg)
|
||||
}
|
||||
|
||||
msg, status = translatePostCreateError(errors.New("other"))
|
||||
if status != 0 || msg != "" {
|
||||
t.Fatalf("expected passthrough for unrelated errors, got status=%d msg=%q", status, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerUpdatePost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
postID string
|
||||
requestBody string
|
||||
userID uint
|
||||
mockSetup func(*testutils.PostRepositoryStub)
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "valid post update",
|
||||
postID: "1",
|
||||
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
authorID := uint(1)
|
||||
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
|
||||
}
|
||||
repo.UpdateFn = func(post *database.Post) error { return nil }
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing user context",
|
||||
postID: "1",
|
||||
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
|
||||
userID: 0,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Authentication required",
|
||||
},
|
||||
{
|
||||
name: "post not found",
|
||||
postID: "999",
|
||||
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedError: "Post not found",
|
||||
},
|
||||
{
|
||||
name: "not author",
|
||||
postID: "1",
|
||||
requestBody: `{"title": "Updated Title", "content": "Updated content"}`,
|
||||
userID: 2,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
authorID := uint(1)
|
||||
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedError: "You can only edit your own posts",
|
||||
},
|
||||
{
|
||||
name: "empty title",
|
||||
postID: "1",
|
||||
requestBody: `{"title": "", "content": "Updated content"}`,
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
authorID := uint(1)
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Title is required",
|
||||
},
|
||||
{
|
||||
name: "short title",
|
||||
postID: "1",
|
||||
requestBody: `{"title": "ab", "content": "Updated content"}`,
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
authorID := uint(1)
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
return &database.Post{ID: id, Title: "Old Title", AuthorID: &authorID}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Title must be at least 3 characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
if tt.mockSetup != nil {
|
||||
tt.mockSetup(repo)
|
||||
}
|
||||
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
|
||||
|
||||
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody))
|
||||
if tt.userID > 0 {
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
|
||||
request = request.WithContext(ctx)
|
||||
}
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
ctx.URLParams.Add("id", tt.postID)
|
||||
request = request.WithContext(context.WithValue(request.Context(), chi.RouteCtxKey, ctx))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.UpdatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if !strings.Contains(recorder.Body.String(), tt.expectedError) {
|
||||
t.Fatalf("expected error to contain %q, got %q", tt.expectedError, recorder.Body.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandlerDeletePost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
postID string
|
||||
userID uint
|
||||
mockSetup func(*testutils.PostRepositoryStub)
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "valid post deletion",
|
||||
postID: "1",
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
authorID := uint(1)
|
||||
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
|
||||
}
|
||||
repo.DeleteFn = func(id uint) error { return nil }
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing user context",
|
||||
postID: "1",
|
||||
userID: 0,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Authentication required",
|
||||
},
|
||||
{
|
||||
name: "post not found",
|
||||
postID: "999",
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedError: "Post not found",
|
||||
},
|
||||
{
|
||||
name: "not author",
|
||||
postID: "1",
|
||||
userID: 2,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
authorID := uint(1)
|
||||
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedError: "You can only delete your own posts",
|
||||
},
|
||||
{
|
||||
name: "delete error",
|
||||
postID: "1",
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
authorID := uint(1)
|
||||
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
|
||||
}
|
||||
repo.DeleteFn = func(id uint) error { return errors.New("database error") }
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to delete post",
|
||||
},
|
||||
{
|
||||
name: "delete votes error",
|
||||
postID: "1",
|
||||
userID: 1,
|
||||
mockSetup: func(repo *testutils.PostRepositoryStub) {
|
||||
repo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
authorID := uint(1)
|
||||
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to delete post votes",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := testutils.NewPostRepositoryStub()
|
||||
if tt.mockSetup != nil {
|
||||
tt.mockSetup(repo)
|
||||
}
|
||||
|
||||
var voteService *services.VoteService
|
||||
if tt.name == "delete votes error" {
|
||||
voteRepo := &errorVoteRepository{}
|
||||
voteService = services.NewVoteService(voteRepo, repo, nil)
|
||||
} else {
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
voteService = services.NewVoteService(voteRepo, repo, nil)
|
||||
}
|
||||
|
||||
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, voteService)
|
||||
|
||||
request := httptest.NewRequest(http.MethodDelete, "/api/posts/"+tt.postID, nil)
|
||||
if tt.userID > 0 {
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
|
||||
request = request.WithContext(ctx)
|
||||
}
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
ctx.URLParams.Add("id", tt.postID)
|
||||
request = request.WithContext(context.WithValue(request.Context(), chi.RouteCtxKey, ctx))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.DeletePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
|
||||
|
||||
if tt.expectedError != "" {
|
||||
if !strings.Contains(recorder.Body.String(), tt.expectedError) {
|
||||
t.Fatalf("expected error to contain %q, got %q", tt.expectedError, recorder.Body.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errorVoteRepository struct{}
|
||||
|
||||
func (e *errorVoteRepository) Create(*database.Vote) error { return nil }
|
||||
func (e *errorVoteRepository) CreateOrUpdate(*database.Vote) error { return nil }
|
||||
func (e *errorVoteRepository) GetByID(uint) (*database.Vote, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
func (e *errorVoteRepository) GetByUserAndPost(uint, uint) (*database.Vote, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
func (e *errorVoteRepository) GetByVoteHash(string) (*database.Vote, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
func (e *errorVoteRepository) GetByPostID(uint) ([]database.Vote, error) {
|
||||
return nil, errors.New("database error")
|
||||
}
|
||||
func (e *errorVoteRepository) GetByUserID(uint) ([]database.Vote, error) { return nil, nil }
|
||||
func (e *errorVoteRepository) Update(*database.Vote) error { return nil }
|
||||
func (e *errorVoteRepository) Delete(uint) error { return nil }
|
||||
func (e *errorVoteRepository) Count() (int64, error) { return 0, nil }
|
||||
func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil }
|
||||
func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil }
|
||||
func (e *errorVoteRepository) WithTx(*gorm.DB) repositories.VoteRepository { return e }
|
||||
|
||||
func TestPostHandler_EdgeCases(t *testing.T) {
|
||||
postRepo := testutils.NewPostRepositoryStub()
|
||||
titleFetcher := &testutils.TitleFetcherStub{}
|
||||
handler := NewPostHandler(postRepo, titleFetcher, nil)
|
||||
|
||||
t.Run("GetPosts with zero limit", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/posts?limit=0", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetPosts(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for zero limit, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetPosts with negative limit", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/posts?limit=-1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetPosts(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for negative limit, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetPosts with negative offset", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/posts?offset=-1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.GetPosts(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for negative offset, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
21
internal/handlers/routes.go
Normal file
21
internal/handlers/routes.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"goyco/internal/middleware"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type RouteModule interface {
|
||||
MountRoutes(r chi.Router, config RouteModuleConfig)
|
||||
}
|
||||
|
||||
type RouteModuleConfig struct {
|
||||
AuthService middleware.TokenVerifier
|
||||
GeneralRateLimit func(chi.Router) chi.Router
|
||||
AuthRateLimit func(chi.Router) chi.Router
|
||||
CSRFMiddleware func(http.Handler) http.Handler
|
||||
AuthMiddleware func(http.Handler) http.Handler
|
||||
}
|
||||
412
internal/handlers/security_test.go
Normal file
412
internal/handlers/security_test.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/security"
|
||||
"goyco/internal/testutils"
|
||||
"goyco/internal/validation"
|
||||
)
|
||||
|
||||
func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
|
||||
maliciousInputs := testutils.GetMaliciousInputs()
|
||||
|
||||
for _, payload := range maliciousInputs.XSSPayloads {
|
||||
t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) {
|
||||
repo := &testutils.PostRepositoryStub{
|
||||
CreateFn: func(post *database.Post) error {
|
||||
sanitizedTitle := security.SanitizeInput(payload)
|
||||
if post.Title != sanitizedTitle {
|
||||
t.Errorf("Expected sanitized title, got %q", post.Title)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewPostHandler(repo, nil, nil)
|
||||
|
||||
postData := map[string]string{
|
||||
"title": payload,
|
||||
"url": "https://example.com",
|
||||
"content": "Test content",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(postData)
|
||||
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.CreatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func minLen(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestPostHandler_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
title string
|
||||
content string
|
||||
url string
|
||||
expectedStatus int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "title too long",
|
||||
title: string(make([]byte, 201)),
|
||||
content: "Normal content",
|
||||
url: "https://example.com",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Title should be limited to 200 characters",
|
||||
},
|
||||
{
|
||||
name: "content too long",
|
||||
title: "Normal title",
|
||||
content: string(make([]byte, 10001)),
|
||||
url: "https://example.com",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Content should be limited to 10,000 characters",
|
||||
},
|
||||
{
|
||||
name: "invalid URL protocol",
|
||||
title: "Normal title",
|
||||
content: "Normal content",
|
||||
url: "ftp://example.com",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Only HTTP and HTTPS URLs should be allowed",
|
||||
},
|
||||
{
|
||||
name: "localhost URL blocked",
|
||||
title: "Normal title",
|
||||
content: "Normal content",
|
||||
url: "http://localhost:8080",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Localhost URLs should be blocked",
|
||||
},
|
||||
{
|
||||
name: "private IP URL blocked",
|
||||
title: "Normal title",
|
||||
content: "Normal content",
|
||||
url: "http://192.168.1.1",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Private IP URLs should be blocked",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &testutils.PostRepositoryStub{}
|
||||
handler := NewPostHandler(repo, nil, nil)
|
||||
|
||||
postData := map[string]string{
|
||||
"title": tt.title,
|
||||
"url": tt.url,
|
||||
"content": tt.content,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(postData)
|
||||
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.CreatePost(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_PasswordValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
expectedStatus int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "weak password",
|
||||
password: "123",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Weak passwords should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password without letters",
|
||||
password: "12345678",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords without letters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password without numbers",
|
||||
password: "password",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords without numbers should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password without special chars",
|
||||
password: "Password123",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords without special characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password too short",
|
||||
password: "Pass1!",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords shorter than 8 characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password too long",
|
||||
password: string(make([]byte, 129)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords that are too long should be rejected",
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
password: "",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Empty passwords should be rejected",
|
||||
},
|
||||
{
|
||||
name: "valid password",
|
||||
password: "Password123!",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords should be accepted",
|
||||
},
|
||||
{
|
||||
name: "valid password with underscore",
|
||||
password: "Password123_",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords with underscore should be accepted",
|
||||
},
|
||||
{
|
||||
name: "valid password with hyphen",
|
||||
password: "Password123-",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords with hyphen should be accepted",
|
||||
},
|
||||
{
|
||||
name: "valid password with unicode",
|
||||
password: "Pássw0rd123!",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords with unicode should be accepted",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &testutils.UserRepositoryStub{
|
||||
GetByUsernameFn: func(string) (*database.User, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
},
|
||||
CreateFn: func(user *database.User) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := newAuthHandler(repo)
|
||||
|
||||
registerData := map[string]string{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": tt.password,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(registerData)
|
||||
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.Register(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_UsernameSanitization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
expectedStatus int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "username with special chars",
|
||||
username: "test@user#123",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Special characters should be removed from username",
|
||||
},
|
||||
{
|
||||
name: "username with script tags",
|
||||
username: "test<script>alert('xss')</script>user",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Script tags should be removed from username",
|
||||
},
|
||||
{
|
||||
name: "username starting with special char",
|
||||
username: "@testuser",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Username starting with special char should be prefixed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedUsername string
|
||||
repo := &testutils.UserRepositoryStub{
|
||||
GetByUsernameFn: func(username string) (*database.User, error) {
|
||||
capturedUsername = username
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
},
|
||||
CreateFn: func(user *database.User) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := newAuthHandler(repo)
|
||||
|
||||
registerData := map[string]string{
|
||||
"username": tt.username,
|
||||
"email": "test@example.com",
|
||||
"password": "Password123!",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(registerData)
|
||||
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.Register(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
|
||||
|
||||
expectedUsername := security.SanitizeUsername(tt.username)
|
||||
if capturedUsername != expectedUsername {
|
||||
t.Errorf("Expected sanitized username %q, got %q", expectedUsername, capturedUsername)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostHandler_AuthorizationBypass(t *testing.T) {
|
||||
repo := &testutils.PostRepositoryStub{
|
||||
GetByIDFn: func(id uint) (*database.Post, error) {
|
||||
authorID := uint(2)
|
||||
return &database.Post{ID: id, Title: "Test Post", AuthorID: &authorID}, nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewPostHandler(repo, nil, nil)
|
||||
|
||||
updateData := map[string]string{
|
||||
"title": "Updated Title",
|
||||
"content": "Updated content",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(updateData)
|
||||
request := httptest.NewRequest("PUT", "/api/posts/1", bytes.NewBuffer(body))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
|
||||
|
||||
routeCtx := chi.NewRouteContext()
|
||||
routeCtx.URLParams.Add("id", "1")
|
||||
request = request.WithContext(context.WithValue(request.Context(), chi.RouteCtxKey, routeCtx))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.UpdatePost(recorder, request)
|
||||
|
||||
if recorder.Result().StatusCode != http.StatusForbidden {
|
||||
t.Errorf("Expected status 403, got %d. Users should not be able to edit other users' posts", recorder.Result().StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPageHandler_PasswordResetValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
expectedError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid password",
|
||||
password: "Password123!",
|
||||
expectedError: false,
|
||||
description: "Valid passwords should pass validation",
|
||||
},
|
||||
{
|
||||
name: "password without special chars",
|
||||
password: "Password123",
|
||||
expectedError: true,
|
||||
description: "Passwords without special characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password too short",
|
||||
password: "Pass1!",
|
||||
expectedError: true,
|
||||
description: "Passwords shorter than 8 characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password without letters",
|
||||
password: "12345678!",
|
||||
expectedError: true,
|
||||
description: "Passwords without letters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password without numbers",
|
||||
password: "Password!",
|
||||
expectedError: true,
|
||||
description: "Passwords without numbers should be rejected",
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
password: "",
|
||||
expectedError: true,
|
||||
description: "Empty passwords should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password too long",
|
||||
password: string(make([]byte, 129)),
|
||||
expectedError: true,
|
||||
description: "Passwords longer than 128 characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "valid password with unicode",
|
||||
password: "Pássw0rd123!",
|
||||
expectedError: false,
|
||||
description: "Valid passwords with unicode should pass validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validation.ValidatePassword(tt.password)
|
||||
|
||||
if tt.expectedError && err == nil {
|
||||
t.Errorf("ValidatePassword(%q) expected error, got nil. %s", tt.password, tt.description)
|
||||
}
|
||||
if !tt.expectedError && err != nil {
|
||||
t.Errorf("ValidatePassword(%q) unexpected error: %v. %s", tt.password, err, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
195
internal/handlers/user_handler.go
Normal file
195
internal/handlers/user_handler.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"goyco/internal/dto"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/validation"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
userRepo repositories.UserRepository
|
||||
authService AuthServiceInterface
|
||||
}
|
||||
|
||||
func NewUserHandler(userRepo repositories.UserRepository, authService AuthServiceInterface) *UserHandler {
|
||||
return &UserHandler{
|
||||
userRepo: userRepo,
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
type UserResponse = CommonResponse
|
||||
|
||||
// @Summary List users
|
||||
// @Description Retrieve a paginated list of users
|
||||
// @Tags users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Number of users to return" default(20)
|
||||
// @Param offset query int false "Number of users to skip" default(0)
|
||||
// @Success 200 {object} UserResponse "Users retrieved successfully"
|
||||
// @Failure 401 {object} UserResponse "Authentication required"
|
||||
// @Failure 500 {object} UserResponse "Internal server error"
|
||||
// @Router /users [get]
|
||||
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
limit, offset := parsePagination(r)
|
||||
|
||||
users, err := h.userRepo.GetAll(limit, offset)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to fetch users", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userDTOs := dto.ToSanitizedUserDTOs(users)
|
||||
|
||||
SendSuccessResponse(w, "Users retrieved successfully", map[string]any{
|
||||
"users": userDTOs,
|
||||
"count": len(userDTOs),
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get user
|
||||
// @Description Retrieve a specific user by ID
|
||||
// @Tags users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "User ID"
|
||||
// @Success 200 {object} UserResponse "User retrieved successfully"
|
||||
// @Failure 400 {object} UserResponse "Invalid user ID"
|
||||
// @Failure 401 {object} UserResponse "Authentication required"
|
||||
// @Failure 404 {object} UserResponse "User not found"
|
||||
// @Failure 500 {object} UserResponse "Internal server error"
|
||||
// @Router /users/{id} [get]
|
||||
func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := ParseUintParam(w, r, "id", "User")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userRepo.GetByID(userID)
|
||||
if !HandleRepoError(w, err, "User") {
|
||||
return
|
||||
}
|
||||
|
||||
userDTO := dto.ToSanitizedUserDTO(user)
|
||||
|
||||
SendSuccessResponse(w, "User retrieved successfully", userDTO)
|
||||
}
|
||||
|
||||
// @Summary Create user
|
||||
// @Description Create a new user account
|
||||
// @Tags users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body RegisterRequest true "User data"
|
||||
// @Success 201 {object} UserResponse "User created successfully"
|
||||
// @Failure 400 {object} UserResponse "Invalid request data or validation failed"
|
||||
// @Failure 401 {object} UserResponse "Authentication required"
|
||||
// @Failure 409 {object} UserResponse "Username or email already exists"
|
||||
// @Failure 500 {object} UserResponse "Internal server error"
|
||||
// @Router /users [post]
|
||||
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidateUsername(req.Username); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidateEmail(req.Email); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validation.ValidatePassword(req.Password); err != nil {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.Register(req.Username, req.Email, req.Password)
|
||||
if err != nil {
|
||||
var validationErr *validation.ValidationError
|
||||
if errors.As(err, &validationErr) {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !HandleServiceError(w, err, "Failed to create user", http.StatusInternalServerError) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
SendCreatedResponse(w, "User created successfully. Verification email sent.", map[string]any{
|
||||
"user": result.User,
|
||||
"verification_sent": result.VerificationSent,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get user posts
|
||||
// @Description Retrieve posts created by a specific user
|
||||
// @Tags users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "User ID"
|
||||
// @Param limit query int false "Number of posts to return" default(20)
|
||||
// @Param offset query int false "Number of posts to skip" default(0)
|
||||
// @Success 200 {object} UserResponse "User posts retrieved successfully"
|
||||
// @Failure 400 {object} UserResponse "Invalid user ID or pagination parameters"
|
||||
// @Failure 401 {object} UserResponse "Authentication required"
|
||||
// @Failure 500 {object} UserResponse "Internal server error"
|
||||
// @Router /users/{id}/posts [get]
|
||||
func (h *UserHandler) GetUserPosts(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := ParseUintParam(w, r, "id", "User")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
limit, offset := parsePagination(r)
|
||||
|
||||
posts, err := h.userRepo.GetPosts(userID, limit, offset)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Failed to fetch user posts", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
postDTOs := dto.ToPostDTOs(posts)
|
||||
SendSuccessResponse(w, "User posts retrieved successfully", map[string]any{
|
||||
"posts": postDTOs,
|
||||
"count": len(postDTOs),
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||
protected := r
|
||||
if config.AuthMiddleware != nil {
|
||||
protected = r.With(config.AuthMiddleware)
|
||||
}
|
||||
if config.GeneralRateLimit != nil {
|
||||
protected = config.GeneralRateLimit(protected)
|
||||
}
|
||||
|
||||
protected.Get("/users", h.GetUsers)
|
||||
protected.Post("/users", h.CreateUser)
|
||||
protected.Get("/users/{id}", h.GetUser)
|
||||
protected.Get("/users/{id}/posts", h.GetUserPosts)
|
||||
}
|
||||
362
internal/handlers/user_handler_test.go
Normal file
362
internal/handlers/user_handler_test.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/repositories"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func newUserHandler(repo repositories.UserRepository) *UserHandler {
|
||||
return newUserHandlerWithSender(repo, &testutils.EmailSenderStub{})
|
||||
}
|
||||
|
||||
func newUserHandlerWithSender(repo repositories.UserRepository, sender services.EmailSender) *UserHandler {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "secret", Expiration: 1},
|
||||
App: config.AppConfig{BaseURL: "https://test.example.com"},
|
||||
}
|
||||
mockRefreshRepo := &mockRefreshTokenRepository{}
|
||||
authService, err := services.NewAuthFacadeForTest(cfg, repo, nil, nil, mockRefreshRepo, sender)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to create auth service: %v", err))
|
||||
}
|
||||
return NewUserHandler(repo, authService)
|
||||
}
|
||||
|
||||
func TestUserHandlerGetUsers(t *testing.T) {
|
||||
var limit, offset int
|
||||
repo := testutils.NewUserRepositoryStub()
|
||||
repo.GetAllFn = func(l, o int) ([]database.User, error) {
|
||||
limit, offset = l, o
|
||||
return []database.User{{ID: 1}}, nil
|
||||
}
|
||||
|
||||
handler := newUserHandler(repo)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/users?limit=5&offset=2", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.GetUsers(recorder, request)
|
||||
|
||||
if limit != 5 || offset != 2 {
|
||||
t.Fatalf("expected limit=5 offset=2, got %d %d", limit, offset)
|
||||
}
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestUserHandlerGetUser(t *testing.T) {
|
||||
repo := testutils.NewUserRepositoryStub()
|
||||
handler := newUserHandler(repo)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/users/1", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.GetUser(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/users/abc", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.GetUser(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
repo.GetByIDFn = func(uint) (*database.User, error) { return nil, gorm.ErrRecordNotFound }
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/users/1", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.GetUser(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
|
||||
|
||||
repo.GetByIDFn = func(id uint) (*database.User, error) {
|
||||
return &database.User{ID: id, Username: "user"}, nil
|
||||
}
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/users/1", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.GetUser(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestUserHandlerCreateUser(t *testing.T) {
|
||||
repo := testutils.NewUserRepositoryStub()
|
||||
repo.CreateFn = func(u *database.User) error {
|
||||
u.ID = 10
|
||||
return nil
|
||||
}
|
||||
sent := false
|
||||
handler := newUserHandlerWithSender(repo, &testutils.EmailSenderStub{SendFn: func(to, subject, body string) error {
|
||||
sent = true
|
||||
if to != "user@example.com" {
|
||||
t.Fatalf("expected email to user@example.com, got %q", to)
|
||||
}
|
||||
return nil
|
||||
}})
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.CreateUser(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
|
||||
|
||||
var resp UserResponse
|
||||
_ = json.NewDecoder(recorder.Body).Decode(&resp)
|
||||
data := resp.Data.(map[string]any)
|
||||
if !resp.Success {
|
||||
t.Fatalf("expected success response")
|
||||
}
|
||||
if v, ok := data["verification_sent"].(bool); !ok || !v {
|
||||
t.Fatalf("expected verification_sent true, got %+v", data["verification_sent"])
|
||||
}
|
||||
userData := data["user"].(map[string]any)
|
||||
if _, ok := userData["password"]; ok {
|
||||
t.Fatalf("expected password field to be omitted, got %+v", userData)
|
||||
}
|
||||
if !sent {
|
||||
t.Fatalf("expected verification email to be sent")
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid"))
|
||||
handler.CreateUser(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`))
|
||||
handler.CreateUser(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
repo.GetByUsernameFn = func(string) (*database.User, error) {
|
||||
return &database.User{ID: 1}, nil
|
||||
}
|
||||
handler = newUserHandler(repo)
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`))
|
||||
handler.CreateUser(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
|
||||
}
|
||||
|
||||
func TestUserHandlerGetUserPosts(t *testing.T) {
|
||||
repo := testutils.NewUserRepositoryStub()
|
||||
repo.GetPostsFn = func(userID uint, limit, offset int) ([]database.Post, error) {
|
||||
return []database.Post{{ID: 1, AuthorID: &userID}}, nil
|
||||
}
|
||||
handler := newUserHandler(repo)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/users/1/posts?limit=2&offset=1", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.GetUserPosts(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
repo.GetPostsFn = func(uint, int, int) ([]database.Post, error) {
|
||||
return nil, gorm.ErrInvalidValue
|
||||
}
|
||||
recorder = httptest.NewRecorder()
|
||||
handler.GetUserPosts(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func TestUserHandlerDataSanitization(t *testing.T) {
|
||||
repo := testutils.NewUserRepositoryStub()
|
||||
repo.GetAllFn = func(l, o int) ([]database.User, error) {
|
||||
users := []database.User{
|
||||
{
|
||||
ID: 1,
|
||||
Username: "user1",
|
||||
Email: "user1@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
EmailVerifiedAt: &[]time.Time{time.Now()}[0],
|
||||
EmailVerificationToken: "secret-token",
|
||||
PasswordResetToken: "reset-token",
|
||||
Locked: false,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Username: "user2",
|
||||
Email: "user2@example.com",
|
||||
Password: "another-hashed-password",
|
||||
EmailVerified: false,
|
||||
EmailVerificationToken: "another-secret-token",
|
||||
Locked: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
handler := newUserHandler(repo)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/users", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.GetUsers(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
var response map[string]any
|
||||
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
data, ok := response["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data field in response")
|
||||
}
|
||||
|
||||
users, ok := data["users"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected users field in data")
|
||||
}
|
||||
|
||||
if len(users) != 2 {
|
||||
t.Fatalf("expected 2 users, got %d", len(users))
|
||||
}
|
||||
|
||||
for i, userInterface := range users {
|
||||
user, ok := userInterface.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected user %d to be a map", i)
|
||||
}
|
||||
|
||||
expectedFields := []string{"id", "username", "created_at", "updated_at"}
|
||||
for _, field := range expectedFields {
|
||||
if _, exists := user[field]; !exists {
|
||||
t.Errorf("expected field %s to be present in user %d", field, i)
|
||||
}
|
||||
}
|
||||
|
||||
sensitiveFields := []string{"email", "password", "email_verified", "email_verified_at",
|
||||
"email_verification_token", "password_reset_token", "locked", "deleted_at"}
|
||||
for _, field := range sensitiveFields {
|
||||
if _, exists := user[field]; exists {
|
||||
t.Errorf("sensitive field %s should not be present in user %d", field, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_PasswordValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
expectedStatus int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid password",
|
||||
password: "Password123!",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords should be accepted",
|
||||
},
|
||||
{
|
||||
name: "password without special chars",
|
||||
password: "Password123",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords without special characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password too short",
|
||||
password: "Pass1!",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords shorter than 8 characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password without letters",
|
||||
password: "12345678!",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords without letters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password without numbers",
|
||||
password: "Password!",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords without numbers should be rejected",
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
password: "",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Empty passwords should be rejected",
|
||||
},
|
||||
{
|
||||
name: "password too long",
|
||||
password: string(make([]byte, 129)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
description: "Passwords longer than 128 characters should be rejected",
|
||||
},
|
||||
{
|
||||
name: "valid password with unicode",
|
||||
password: "Pássw0rd123!",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords with unicode should be accepted",
|
||||
},
|
||||
{
|
||||
name: "valid password with underscore",
|
||||
password: "Password123_",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords with underscore should be accepted",
|
||||
},
|
||||
{
|
||||
name: "valid password with hyphen",
|
||||
password: "Password123-",
|
||||
expectedStatus: http.StatusCreated,
|
||||
description: "Valid passwords with hyphen should be accepted",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := testutils.NewUserRepositoryStub()
|
||||
repo.CreateFn = func(user *database.User) error {
|
||||
return nil
|
||||
}
|
||||
repo.GetByUsernameFn = func(username string) (*database.User, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
repo.GetByEmailFn = func(email string) (*database.User, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "secret", Expiration: 1},
|
||||
App: config.AppConfig{BaseURL: "https://test.example.com"},
|
||||
}
|
||||
emailSender := &testutils.MockEmailSender{}
|
||||
mockRefreshRepo := &mockRefreshTokenRepository{}
|
||||
authService, err := services.NewAuthFacadeForTest(cfg, repo, nil, nil, mockRefreshRepo, emailSender)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auth service: %v", err)
|
||||
}
|
||||
handler := NewUserHandler(repo, authService)
|
||||
|
||||
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.CreateUser(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, tt.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
293
internal/handlers/vote_handler.go
Normal file
293
internal/handlers/vote_handler.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/services"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// @securityDefinitions.apikey BearerAuth
|
||||
// @in header
|
||||
// @name Authorization
|
||||
// @description Type "Bearer" followed by a space and JWT token.
|
||||
|
||||
// @tag.name votes
|
||||
// @tag.description Voting system endpoints. All votes are handled through the same API with identical behavior.
|
||||
|
||||
// @tag.name posts
|
||||
// @tag.description Post management endpoints with integrated vote statistics.
|
||||
|
||||
// @tag.name auth
|
||||
// @tag.description Authentication and user management endpoints.
|
||||
|
||||
// @tag.name users
|
||||
// @tag.description User management endpoints.
|
||||
|
||||
// @tag.name api
|
||||
// @tag.description API information and system metrics.
|
||||
|
||||
type VoteHandler struct {
|
||||
voteService *services.VoteService
|
||||
}
|
||||
|
||||
func NewVoteHandler(voteService *services.VoteService) *VoteHandler {
|
||||
return &VoteHandler{
|
||||
voteService: voteService,
|
||||
}
|
||||
}
|
||||
|
||||
// @Description Vote request with type field. All votes are handled the same way.
|
||||
type VoteRequest struct {
|
||||
Type string `json:"type" example:"up" enums:"up,down,none" description:"Vote type: 'up' for upvote, 'down' for downvote, 'none' to remove vote"`
|
||||
}
|
||||
|
||||
type VoteResponse = CommonResponse
|
||||
|
||||
// @Summary Cast a vote on a post
|
||||
// @Description Vote on a post (upvote, downvote, or remove vote). Authentication is required; the vote is performed on behalf of the current user.
|
||||
// @Description
|
||||
// @Description **Vote Types:**
|
||||
// @Description - `up`: Upvote the post
|
||||
// @Description - `down`: Downvote the post
|
||||
// @Description - `none`: Remove existing vote
|
||||
// @Description
|
||||
// @Description **Response includes:**
|
||||
// @Description - Updated post vote counts (up_votes, down_votes, score)
|
||||
// @Description - Success message
|
||||
// @Tags votes
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "Post ID"
|
||||
// @Param request body VoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
|
||||
// @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics"
|
||||
// @Failure 401 {object} VoteResponse "Authentication required"
|
||||
// @Failure 400 {object} VoteResponse "Invalid request data or vote type"
|
||||
// @Failure 404 {object} VoteResponse "Post not found"
|
||||
// @Failure 500 {object} VoteResponse "Internal server error"
|
||||
// @Example 200 {"success": true, "message": "Vote cast successfully", "data": {"post_id": 1, "type": "up", "up_votes": 5, "down_votes": 2, "score": 3, "is_anonymous": false}}
|
||||
// @Example 400 {"success": false, "error": "Invalid vote type. Must be 'up', 'down', or 'none'"}
|
||||
// @Router /posts/{id}/vote [post]
|
||||
func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
postID, ok := ParseUintParam(w, r, "id", "Post")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req VoteRequest
|
||||
if !DecodeJSONRequest(w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
var voteType database.VoteType
|
||||
switch req.Type {
|
||||
case "up":
|
||||
voteType = database.VoteUp
|
||||
case "down":
|
||||
voteType = database.VoteDown
|
||||
case "none":
|
||||
voteType = database.VoteNone
|
||||
default:
|
||||
SendErrorResponse(w, "Invalid vote type. Must be 'up', 'down', or 'none'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := GetClientIP(r)
|
||||
userAgent := r.UserAgent()
|
||||
|
||||
serviceReq := services.VoteRequest{
|
||||
UserID: userID,
|
||||
PostID: postID,
|
||||
Type: voteType,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
|
||||
response, err := h.voteService.CastVote(serviceReq)
|
||||
if err != nil {
|
||||
if err.Error() == "post not found" {
|
||||
SendErrorResponse(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if err.Error() == "post ID is required" || err.Error() == "invalid vote type" {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Vote cast successfully", response)
|
||||
}
|
||||
|
||||
// @Summary Remove a vote
|
||||
// @Description Remove a vote from a post for the authenticated user. This is equivalent to casting a vote with type 'none'.
|
||||
// @Tags votes
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "Post ID"
|
||||
// @Success 200 {object} VoteResponse "Vote removed successfully with updated post statistics"
|
||||
// @Failure 401 {object} VoteResponse "Authentication required"
|
||||
// @Failure 400 {object} VoteResponse "Invalid post ID"
|
||||
// @Failure 404 {object} VoteResponse "Post not found"
|
||||
// @Failure 500 {object} VoteResponse "Internal server error"
|
||||
// @Router /posts/{id}/vote [delete]
|
||||
func (h *VoteHandler) RemoveVote(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
postID, ok := ParseUintParam(w, r, "id", "Post")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := GetClientIP(r)
|
||||
userAgent := r.UserAgent()
|
||||
|
||||
serviceReq := services.VoteRequest{
|
||||
UserID: userID,
|
||||
PostID: postID,
|
||||
Type: database.VoteNone,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
|
||||
response, err := h.voteService.CastVote(serviceReq)
|
||||
if err != nil {
|
||||
if err.Error() == "post not found" {
|
||||
SendErrorResponse(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if err.Error() == "post ID is required" {
|
||||
SendErrorResponse(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Vote removed successfully", response)
|
||||
}
|
||||
|
||||
// @Summary Get current user's vote
|
||||
// @Description Retrieve the current user's vote for a specific post. Requires authentication and returns the vote type if it exists.
|
||||
// @Description
|
||||
// @Description **Response:**
|
||||
// @Description - If vote exists: Returns vote details with contextual metadata (including `is_anonymous`)
|
||||
// @Description - If no vote: Returns success with null vote data and metadata
|
||||
// @Tags votes
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "Post ID"
|
||||
// @Success 200 {object} VoteResponse "Vote retrieved successfully"
|
||||
// @Success 200 {object} VoteResponse "No vote found for this user/post combination"
|
||||
// @Failure 401 {object} VoteResponse "Authentication required"
|
||||
// @Failure 400 {object} VoteResponse "Invalid post ID"
|
||||
// @Failure 500 {object} VoteResponse "Internal server error"
|
||||
// @Example 200 {"success": true, "message": "Vote retrieved successfully", "data": {"has_vote": true, "vote": {"type": "up", "user_id": 123}, "is_anonymous": false}}
|
||||
// @Example 200 {"success": true, "message": "No vote found", "data": {"has_vote": false, "vote": null, "is_anonymous": false}}
|
||||
// @Router /posts/{id}/vote [get]
|
||||
func (h *VoteHandler) GetUserVote(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := RequireAuth(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
postID, ok := ParseUintParam(w, r, "id", "Post")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := GetClientIP(r)
|
||||
userAgent := r.UserAgent()
|
||||
|
||||
vote, err := h.voteService.GetUserVote(userID, postID, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
if err.Error() == "record not found" {
|
||||
SendSuccessResponse(w, "No vote found", map[string]any{
|
||||
"has_vote": false,
|
||||
"vote": nil,
|
||||
"is_anonymous": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Vote retrieved successfully", map[string]any{
|
||||
"has_vote": true,
|
||||
"vote": vote,
|
||||
"is_anonymous": false,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Get post votes
|
||||
// @Description Retrieve all votes for a specific post. Returns all votes in a single format.
|
||||
// @Description
|
||||
// @Description **Authentication Required:** Yes (Bearer token)
|
||||
// @Description
|
||||
// @Description **Response includes:**
|
||||
// @Description - Array of all votes
|
||||
// @Description - Total vote count
|
||||
// @Description - Each vote includes type and unauthenticated status
|
||||
// @Tags votes
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "Post ID"
|
||||
// @Success 200 {object} VoteResponse "Votes retrieved successfully with count"
|
||||
// @Failure 400 {object} VoteResponse "Invalid post ID"
|
||||
// @Failure 401 {object} VoteResponse "Authentication required"
|
||||
// @Failure 500 {object} VoteResponse "Internal server error"
|
||||
// @Example 200 {"success": true, "message": "Votes retrieved successfully", "data": {"votes": [{"type": "up", "user_id": 123}, {"type": "down", "vote_hash": "abc123"}], "count": 2}}
|
||||
// @Router /posts/{id}/votes [get]
|
||||
func (h *VoteHandler) GetPostVotes(w http.ResponseWriter, r *http.Request) {
|
||||
postID, ok := ParseUintParam(w, r, "id", "Post")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
votes, err := h.voteService.GetPostVotes(postID)
|
||||
if err != nil {
|
||||
SendErrorResponse(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
allVotes := make([]any, 0, len(votes))
|
||||
for _, vote := range votes {
|
||||
allVotes = append(allVotes, vote)
|
||||
}
|
||||
|
||||
SendSuccessResponse(w, "Votes retrieved successfully", map[string]any{
|
||||
"votes": allVotes,
|
||||
"count": len(allVotes),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
|
||||
protected := r
|
||||
if config.AuthMiddleware != nil {
|
||||
protected = r.With(config.AuthMiddleware)
|
||||
}
|
||||
if config.GeneralRateLimit != nil {
|
||||
protected = config.GeneralRateLimit(protected)
|
||||
}
|
||||
|
||||
protected.Post("/posts/{id}/vote", h.CastVote)
|
||||
protected.Delete("/posts/{id}/vote", h.RemoveVote)
|
||||
protected.Get("/posts/{id}/vote", h.GetUserVote)
|
||||
protected.Get("/posts/{id}/votes", h.GetPostVotes)
|
||||
}
|
||||
482
internal/handlers/vote_handler_test.go
Normal file
482
internal/handlers/vote_handler_test.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/database"
|
||||
"goyco/internal/middleware"
|
||||
"goyco/internal/services"
|
||||
"goyco/internal/testutils"
|
||||
)
|
||||
|
||||
func newVoteHandlerWithRepos() *VoteHandler {
|
||||
handler, _, _ := newVoteHandlerWithReposRefs()
|
||||
return handler
|
||||
}
|
||||
|
||||
func newVoteHandlerWithReposRefs() (*VoteHandler, *testutils.MockVoteRepository, map[uint]*database.Post) {
|
||||
voteRepo := testutils.NewMockVoteRepository()
|
||||
posts := map[uint]*database.Post{
|
||||
1: {ID: 1},
|
||||
}
|
||||
postRepo := testutils.NewPostRepositoryStub()
|
||||
postRepo.GetByIDFn = func(id uint) (*database.Post, error) {
|
||||
if post, ok := posts[id]; ok {
|
||||
copy := *post
|
||||
return ©, nil
|
||||
}
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
postRepo.UpdateFn = func(post *database.Post) error {
|
||||
copy := *post
|
||||
posts[post.ID] = ©
|
||||
return nil
|
||||
}
|
||||
postRepo.DeleteFn = func(id uint) error {
|
||||
if _, ok := posts[id]; !ok {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
delete(posts, id)
|
||||
return nil
|
||||
}
|
||||
postRepo.CreateFn = func(post *database.Post) error {
|
||||
copy := *post
|
||||
posts[post.ID] = ©
|
||||
return nil
|
||||
}
|
||||
service := services.NewVoteService(voteRepo, postRepo, nil)
|
||||
return NewVoteHandler(service), voteRepo, posts
|
||||
}
|
||||
|
||||
func TestVoteHandlerCastVote(t *testing.T) {
|
||||
handler := newVoteHandlerWithRepos()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
handler.CastVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid vote type, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for successful down vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for successful none vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
|
||||
handler, _, posts := newVoteHandlerWithReposRefs()
|
||||
delete(posts, 1)
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.CastVote(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func TestVoteHandlerRemoveVote(t *testing.T) {
|
||||
handler := newVoteHandlerWithRepos()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
handler.RemoveVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodDelete, "/api/posts/abc/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.RemoveVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.RemoveVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for removing non-existent vote (idempotent), got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.RemoveVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 when removing vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
|
||||
handler, _, posts := newVoteHandlerWithReposRefs()
|
||||
delete(posts, 1)
|
||||
|
||||
request := httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.RemoveVote(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
|
||||
handler, voteRepo, _ := newVoteHandlerWithReposRefs()
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.CastVote(recorder, request)
|
||||
|
||||
voteRepo.DeleteErr = fmt.Errorf("database unavailable")
|
||||
|
||||
request = httptest.NewRequest(http.MethodDelete, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
recorder = httptest.NewRecorder()
|
||||
|
||||
handler.RemoveVote(recorder, request)
|
||||
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func TestVoteHandlerGetUserVote(t *testing.T) {
|
||||
handler := newVoteHandlerWithRepos()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
handler.GetUserVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/abc/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.GetUserVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.GetUserVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 when vote missing, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
var resp VoteResponse
|
||||
_ = json.NewDecoder(recorder.Body).Decode(&resp)
|
||||
data := resp.Data.(map[string]any)
|
||||
if data["has_vote"].(bool) {
|
||||
t.Fatalf("expected has_vote false, got true")
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.GetUserVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 when vote exists, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
_ = json.NewDecoder(recorder.Body).Decode(&resp)
|
||||
data = resp.Data.(map[string]any)
|
||||
if !data["has_vote"].(bool) {
|
||||
t.Fatalf("expected has_vote true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteHandlerGetPostVotes(t *testing.T) {
|
||||
handler := newVoteHandlerWithRepos()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/api/posts/abc/votes", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
|
||||
handler.GetPostVotes(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/votes", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
handler.GetPostVotes(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for empty votes, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for creating vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/votes", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
handler.GetPostVotes(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
var resp VoteResponse
|
||||
_ = json.NewDecoder(recorder.Body).Decode(&resp)
|
||||
data := resp.Data.(map[string]any)
|
||||
votes := data["votes"].([]any)
|
||||
if len(votes) != 2 {
|
||||
t.Fatalf("expected 2 votes, got %d", len(votes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestVoteFlowRegression(t *testing.T) {
|
||||
handler := newVoteHandlerWithRepos()
|
||||
|
||||
t.Run("CompleteVoteLifecycle", func(t *testing.T) {
|
||||
userID := uint(1)
|
||||
postID := "1"
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
|
||||
request = request.WithContext(ctx)
|
||||
handler.GetUserVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for getting vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for changing to downvote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for removing vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/vote", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
|
||||
request = request.WithContext(ctx)
|
||||
handler.GetUserVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for getting removed vote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
var resp VoteResponse
|
||||
_ = json.NewDecoder(recorder.Body).Decode(&resp)
|
||||
data := resp.Data.(map[string]any)
|
||||
if data["has_vote"].(bool) {
|
||||
t.Fatalf("expected has_vote false after removal, got true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MultipleUsersVoting", func(t *testing.T) {
|
||||
postID := "1"
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for user 1 upvote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for user 2 downvote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for user 3 upvote, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodGet, "/api/posts/1/votes", nil)
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": postID})
|
||||
handler.GetPostVotes(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for getting all votes, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
var resp VoteResponse
|
||||
_ = json.NewDecoder(recorder.Body).Decode(&resp)
|
||||
data := resp.Data.(map[string]any)
|
||||
votes := data["votes"].([]any)
|
||||
if len(votes) != 3 {
|
||||
t.Fatalf("expected 3 votes, got %d", len(votes))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for missing type field, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`))
|
||||
request = testutils.WithURLParams(request, map[string]string{"id": "1"})
|
||||
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
|
||||
request = request.WithContext(ctx)
|
||||
handler.CastVote(recorder, request)
|
||||
if recorder.Result().StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid vote type, got %d", recorder.Result().StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user