diff --git a/internal/handlers/common.go b/internal/handlers/common.go index ba9a2a0..907e5a5 100644 --- a/internal/handlers/common.go +++ b/internal/handlers/common.go @@ -15,6 +15,7 @@ import ( "goyco/internal/dto" "goyco/internal/middleware" "goyco/internal/services" + "goyco/internal/validation" "github.com/go-chi/chi/v5" "gorm.io/gorm" @@ -272,13 +273,51 @@ func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, def return false } -func GetValidatedDTO[T any](r *http.Request) (*T, bool) { +func GetValidatedDTO[T any](w http.ResponseWriter, r *http.Request) (*T, bool) { dtoVal := middleware.GetValidatedDTOFromContext(r.Context()) - if dtoVal == nil { - return nil, false + dtoTypeInContext := middleware.GetDTOTypeFromContext(r.Context()) + + var dto *T + needsValidation := false + + if dtoVal != nil { + var ok bool + dto, ok = dtoVal.(*T) + if !ok { + return nil, false + } + if dtoTypeInContext == nil { + needsValidation = true + } + } else { + var decoded T + if err := json.NewDecoder(r.Body).Decode(&decoded); err != nil { + SendErrorResponse(w, "Invalid JSON", http.StatusBadRequest) + return nil, false + } + dto = &decoded + needsValidation = true } - dto, ok := dtoVal.(*T) - return dto, ok + + if needsValidation { + if err := validation.ValidateStruct(dto); err != nil { + var errorMessages []string + if structErr, ok := err.(*validation.StructValidationError); ok { + errorMessages = make([]string, len(structErr.Errors)) + for i, fieldError := range structErr.Errors { + errorMessages[i] = fieldError.Message + } + } else { + errorMessages = []string{err.Error()} + } + + errorMsg := strings.Join(errorMessages, "; ") + SendErrorResponse(w, errorMsg, http.StatusBadRequest) + return nil, false + } + } + + return dto, true } func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {