315 lines
8.2 KiB
Go
315 lines
8.2 KiB
Go
package handlers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"goyco/internal/database"
|
|
"goyco/internal/dto"
|
|
"goyco/internal/middleware"
|
|
"goyco/internal/services"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// CommonResponse is the JSON envelope written by every response helper in
// this file.
type CommonResponse struct {
	// Success is true for the success/created helpers and false for errors.
	Success bool `json:"success"`
	// Message is always serialized, even when empty.
	Message string `json:"message"`
	// Data carries the payload; a nil interface value is omitted.
	Data any `json:"data,omitempty"`
	// Error holds the error text; omitted when empty.
	Error string `json:"error,omitempty"`
}
|
|
|
|
// PaginationData describes a paginated result window.
//
// NOTE(review): the callers that populate this are not visible here; whether
// Count is the page size or the overall total should be confirmed there.
// Limit/Offset presumably mirror the values returned by parsePagination.
type PaginationData struct {
	Count  int `json:"count"`
	Limit  int `json:"limit"`
	Offset int `json:"offset"`
}
|
|
|
|
type VoteCookieData struct {
|
|
Type database.VoteType `json:"type"`
|
|
Timestamp int64 `json:"timestamp"`
|
|
}
|
|
|
|
func sendResponse(w http.ResponseWriter, statusCode int, success bool, message string, data any, errMsg string) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(statusCode)
|
|
|
|
response := CommonResponse{
|
|
Success: success,
|
|
Message: message,
|
|
Data: data,
|
|
Error: errMsg,
|
|
}
|
|
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
func SendSuccessResponse(w http.ResponseWriter, message string, data any) {
|
|
sendResponse(w, http.StatusOK, true, message, data, "")
|
|
}
|
|
|
|
func SendCreatedResponse(w http.ResponseWriter, message string, data any) {
|
|
sendResponse(w, http.StatusCreated, true, message, data, "")
|
|
}
|
|
|
|
func SendErrorResponse(w http.ResponseWriter, message string, statusCode int) {
|
|
sendResponse(w, statusCode, false, "", nil, message)
|
|
}
|
|
|
|
func DecodeJSONRequest(w http.ResponseWriter, r *http.Request, req any) bool {
|
|
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
|
|
SendErrorResponse(w, "Invalid request body", http.StatusBadRequest)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func GetClientIP(r *http.Request) string {
|
|
return middleware.GetSecureClientIP(r)
|
|
}
|
|
|
|
const (
	// CookieMaxAgeDays is the lifetime of a vote cookie, in days.
	CookieMaxAgeDays = 30
	// SecondsPerDay converts a day count into seconds for http.Cookie.MaxAge.
	SecondsPerDay = 24 * 60 * 60

	// DefaultPaginationLimit is the page size used when the request gives
	// no valid positive "limit".
	DefaultPaginationLimit = 20
	// DefaultPaginationOffset is used when the request gives no valid
	// non-negative "offset".
	DefaultPaginationOffset = 0
)
|
|
|
|
func SetVoteCookie(w http.ResponseWriter, r *http.Request, postID uint, voteType database.VoteType) {
|
|
cookieName := fmt.Sprintf("vote_%d", postID)
|
|
cookieValue := fmt.Sprintf("%s:%d", voteType, time.Now().Unix())
|
|
|
|
cookie := &http.Cookie{
|
|
Name: cookieName,
|
|
Value: cookieValue,
|
|
Path: "/",
|
|
MaxAge: SecondsPerDay * CookieMaxAgeDays,
|
|
HttpOnly: true,
|
|
Secure: IsHTTPS(r),
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
|
|
http.SetCookie(w, cookie)
|
|
}
|
|
|
|
// GetVoteCookie returns the raw value of the "vote_<postID>" cookie sent with
// r, or "" when the request carries no such cookie.
func GetVoteCookie(r *http.Request, postID uint) string {
	if c, err := r.Cookie(fmt.Sprintf("vote_%d", postID)); err == nil {
		return c.Value
	}
	return ""
}
|
|
|
|
// ClearVoteCookie tells the browser to delete the "vote_<postID>" cookie by
// sending an empty, already-expired HttpOnly cookie with Path "/".
func ClearVoteCookie(w http.ResponseWriter, postID uint) {
	// A negative MaxAge makes net/http emit "Max-Age=0", which expires the
	// cookie immediately.
	expired := http.Cookie{
		Name:     fmt.Sprintf("vote_%d", postID),
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
	}
	http.SetCookie(w, &expired)
}
|
|
|
|
// IsHTTPS reports whether the client reached the server over HTTPS. This is
// true when the connection itself is TLS (r.TLS is set), or when a proxy
// advertises the original scheme through X-Forwarded-Proto: https,
// X-Forwarded-Ssl: on, or X-Forwarded-Scheme: https.
//
// Header values are matched case-insensitively (URI schemes are
// case-insensitive). When a chain of proxies has appended a comma-separated
// list (e.g. "https, http"), only the first, client-facing entry counts.
//
// NOTE(review): the forwarding headers are trusted unconditionally, so a
// client connecting directly can spoof them. In this file the result only
// drives the cookie Secure flag; confirm no other caller relies on it for
// security decisions.
func IsHTTPS(r *http.Request) bool {
	if r.TLS != nil {
		return true
	}
	return firstForwardedValueEquals(r.Header.Get("X-Forwarded-Proto"), "https") ||
		firstForwardedValueEquals(r.Header.Get("X-Forwarded-Ssl"), "on") ||
		firstForwardedValueEquals(r.Header.Get("X-Forwarded-Scheme"), "https")
}

// firstForwardedValueEquals reports whether the first comma-separated entry
// of a forwarding header value equals want, ignoring case and surrounding
// whitespace. An empty header never matches.
func firstForwardedValueEquals(header, want string) bool {
	first, _, _ := strings.Cut(header, ",")
	return strings.EqualFold(strings.TrimSpace(first), want)
}
|
|
|
|
func SanitizeUser(user *database.User) dto.SanitizedUserDTO {
|
|
if user == nil {
|
|
return dto.SanitizedUserDTO{}
|
|
}
|
|
return dto.ToSanitizedUserDTO(user)
|
|
}
|
|
|
|
func SanitizeUsers(users []database.User) []dto.SanitizedUserDTO {
|
|
return dto.ToSanitizedUserDTOs(users)
|
|
}
|
|
|
|
func parsePagination(r *http.Request) (limit, offset int) {
|
|
limit = DefaultPaginationLimit
|
|
offset = DefaultPaginationOffset
|
|
|
|
limitStr := r.URL.Query().Get("limit")
|
|
if limitStr != "" {
|
|
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
|
limit = l
|
|
}
|
|
}
|
|
|
|
offsetStr := r.URL.Query().Get("offset")
|
|
if offsetStr != "" {
|
|
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
|
offset = o
|
|
}
|
|
}
|
|
|
|
return limit, offset
|
|
}
|
|
|
|
// ValidateRedirectURL returns a same-origin, path-absolute redirect target
// derived from redirectURL, or "" if it is not safe to redirect to.
//
// After trimming surrounding whitespace, the input must be non-empty, at most
// 512 bytes, and start with exactly one "/". A leading "//" would be a
// protocol-relative URL pointing at another host. The input must also parse
// as a URL with no scheme, host or userinfo.
//
// The result is rebuilt from the parsed path, raw query and fragment, using
// their escaped forms. Characters such as a backslash therefore come back
// percent-encoded instead of reaching the browser verbatim, and existing
// percent-escapes are preserved.
func ValidateRedirectURL(redirectURL string) string {
	redirectURL = strings.TrimSpace(redirectURL)
	if redirectURL == "" || len(redirectURL) > 512 {
		return ""
	}

	if !strings.HasPrefix(redirectURL, "/") || strings.HasPrefix(redirectURL, "//") {
		return ""
	}

	parsed, err := url.Parse(redirectURL)
	if err != nil || parsed.Scheme != "" || parsed.Host != "" || parsed.User != nil || parsed.Path == "" {
		return ""
	}

	// Path is non-empty here, so EscapedPath is never empty either.
	validated := parsed.EscapedPath()
	if parsed.RawQuery != "" {
		validated += "?" + parsed.RawQuery
	}
	// Use the escaped fragment. The decoded Fragment would turn "#a%20b" into
	// "#a b" and "%23" into a bare '#', silently changing the target.
	if frag := parsed.EscapedFragment(); frag != "" {
		validated += "#" + frag
	}
	return validated
}
|
|
|
|
func ParseUintParam(w http.ResponseWriter, r *http.Request, paramName, entityName string) (uint, bool) {
|
|
str := chi.URLParam(r, paramName)
|
|
if str == "" {
|
|
SendErrorResponse(w, entityName+" ID is required", http.StatusBadRequest)
|
|
return 0, false
|
|
}
|
|
id, err := strconv.ParseUint(str, 10, 32)
|
|
if err != nil {
|
|
SendErrorResponse(w, "Invalid "+entityName+" ID", http.StatusBadRequest)
|
|
return 0, false
|
|
}
|
|
return uint(id), true
|
|
}
|
|
|
|
func RequireAuth(w http.ResponseWriter, r *http.Request) (uint, bool) {
|
|
userID := middleware.GetUserIDFromContext(r.Context())
|
|
if userID == 0 {
|
|
SendErrorResponse(w, "Authentication required", http.StatusUnauthorized)
|
|
return 0, false
|
|
}
|
|
return userID, true
|
|
}
|
|
|
|
func NewVoteContext(r *http.Request) services.VoteContext {
|
|
return services.VoteContext{
|
|
UserID: middleware.GetUserIDFromContext(r.Context()),
|
|
IPAddress: GetClientIP(r),
|
|
UserAgent: r.UserAgent(),
|
|
}
|
|
}
|
|
|
|
func HandleRepoError(w http.ResponseWriter, err error, entityName string) bool {
|
|
if err == nil {
|
|
return true
|
|
}
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
SendErrorResponse(w, entityName+" not found", http.StatusNotFound)
|
|
} else {
|
|
SendErrorResponse(w, "Failed to retrieve "+entityName, http.StatusInternalServerError)
|
|
}
|
|
return false
|
|
}
|
|
|
|
var AuthErrorMapping = []struct {
|
|
err error
|
|
msg string
|
|
code int
|
|
}{
|
|
{services.ErrInvalidCredentials, "Invalid username or password", http.StatusUnauthorized},
|
|
{services.ErrEmailNotVerified, "Please confirm your email before logging in", http.StatusForbidden},
|
|
{services.ErrAccountLocked, "Your account has been locked. Please contact us for assistance.", http.StatusForbidden},
|
|
{services.ErrUsernameTaken, "Username is already taken", http.StatusConflict},
|
|
{services.ErrEmailTaken, "Email is already registered", http.StatusConflict},
|
|
{services.ErrInvalidEmail, "Invalid email address", http.StatusBadRequest},
|
|
{services.ErrPasswordTooShort, "Password must be at least 8 characters", http.StatusBadRequest},
|
|
{services.ErrInvalidVerificationToken, "Invalid or expired verification token", http.StatusBadRequest},
|
|
{services.ErrRefreshTokenExpired, "Refresh token has expired", http.StatusUnauthorized},
|
|
{services.ErrRefreshTokenInvalid, "Invalid refresh token", http.StatusUnauthorized},
|
|
{services.ErrInvalidDeletionToken, "This deletion link is invalid or has expired.", http.StatusBadRequest},
|
|
{services.ErrDeletionRequestNotFound, "Deletion request not found", http.StatusBadRequest},
|
|
{services.ErrUserNotFound, "User not found", http.StatusNotFound},
|
|
{services.ErrEmailSenderUnavailable, "Email service is unavailable. Please try again later.", http.StatusServiceUnavailable},
|
|
}
|
|
|
|
func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, defaultCode int) bool {
|
|
if err == nil {
|
|
return true
|
|
}
|
|
|
|
for _, mapping := range AuthErrorMapping {
|
|
if err == mapping.err || errors.Is(err, mapping.err) {
|
|
SendErrorResponse(w, mapping.msg, mapping.code)
|
|
return false
|
|
}
|
|
}
|
|
|
|
errMsg := err.Error()
|
|
for _, mapping := range AuthErrorMapping {
|
|
if mapping.err.Error() == errMsg {
|
|
SendErrorResponse(w, mapping.msg, mapping.code)
|
|
return false
|
|
}
|
|
}
|
|
|
|
SendErrorResponse(w, defaultMsg, defaultCode)
|
|
return false
|
|
}
|
|
|
|
func GetValidatedDTO[T any](r *http.Request) (*T, bool) {
|
|
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
|
|
if dtoVal == nil {
|
|
return nil, false
|
|
}
|
|
dto, ok := dtoVal.(*T)
|
|
return dto, ok
|
|
}
|
|
|
|
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {
|
|
if validationMiddleware == nil {
|
|
return handler
|
|
}
|
|
var zero T
|
|
dtoType := reflect.TypeOf(zero)
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := middleware.SetDTOTypeInContext(r.Context(), dtoType)
|
|
validationMiddleware(handler).ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|