162 lines
4.7 KiB
Go
162 lines
4.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"testing"
|
|
)
|
|
|
|
// TestUser is the request-body fixture for the /users subtests.
//
// The json tags fix the wire names used when a test marshals a body. The
// validate tags hold the rules the middleware is expected to enforce: the
// tests expect 400 for an empty Username and for malformed URL values.
// NOTE(review): the tag syntax looks like go-playground/validator; confirm
// against the middleware implementation.
type TestUser struct {
	Username string `json:"username" validate:"required,min=3,max=20"`
	Email    string `json:"email" validate:"required,email"`
	Age      int    `json:"age" validate:"min=18,max=120"`
	URL      string `json:"url" validate:"url"`
	Status   string `json:"status" validate:"oneof=active inactive pending"`
}
|
|
|
|
// TestPost is a second DTO fixture carrying json and validate tags.
// It is not referenced by TestValidationMiddleware itself.
//
// NOTE(review): if these are go-playground/validator tags, "min=1" after
// "omitempty" on a string appears redundant: empty values are skipped and
// non-empty ones already have length >= 1. Confirm the intended rule
// (perhaps Tags was meant to be a slice).
type TestPost struct {
	Title   string `json:"title" validate:"required,min=1,max=200"`
	Content string `json:"content" validate:"required,min=10"`
	Tags    string `json:"tags" validate:"omitempty,min=1"`
}
|
|
|
|
func TestValidationMiddleware(t *testing.T) {
|
|
middleware := ValidationMiddleware()
|
|
|
|
t.Run("Valid POST request", func(t *testing.T) {
|
|
user := TestUser{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
Age: 25,
|
|
URL: "https://example.com",
|
|
Status: "active",
|
|
}
|
|
|
|
body, _ := json.Marshal(user)
|
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
|
request.Header.Set("Content-Type", "application/json")
|
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
|
request = request.WithContext(ctx)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("success"))
|
|
}))
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if recorder.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", recorder.Code)
|
|
}
|
|
|
|
if recorder.Body.String() != "success" {
|
|
t.Errorf("Expected 'success', got '%s'", recorder.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("Invalid POST request - missing required field", func(t *testing.T) {
|
|
user := TestUser{
|
|
Email: "test@example.com",
|
|
Age: 25,
|
|
URL: "https://example.com",
|
|
Status: "active",
|
|
}
|
|
|
|
body, _ := json.Marshal(user)
|
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
|
request.Header.Set("Content-Type", "application/json")
|
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
|
request = request.WithContext(ctx)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("Handler should not be called for invalid request")
|
|
}))
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if recorder.Code != http.StatusBadRequest {
|
|
t.Errorf("Expected status 400, got %d", recorder.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("GET request bypasses validation", func(t *testing.T) {
|
|
request := httptest.NewRequest("GET", "/users", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("success"))
|
|
}))
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if recorder.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", recorder.Code)
|
|
}
|
|
|
|
if recorder.Body.String() != "success" {
|
|
t.Errorf("Expected 'success', got '%s'", recorder.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("Invalid POST request - invalid URL format", func(t *testing.T) {
|
|
user := TestUser{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
Age: 25,
|
|
URL: "http://",
|
|
Status: "active",
|
|
}
|
|
|
|
body, _ := json.Marshal(user)
|
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
|
request.Header.Set("Content-Type", "application/json")
|
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
|
request = request.WithContext(ctx)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("Handler should not be called for invalid request")
|
|
}))
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if recorder.Code != http.StatusBadRequest {
|
|
t.Errorf("Expected status 400, got %d", recorder.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("Invalid POST request - URL without protocol", func(t *testing.T) {
|
|
user := TestUser{
|
|
Username: "testuser",
|
|
Email: "test@example.com",
|
|
Age: 25,
|
|
URL: "example.com",
|
|
Status: "active",
|
|
}
|
|
|
|
body, _ := json.Marshal(user)
|
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
|
request.Header.Set("Content-Type", "application/json")
|
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
|
request = request.WithContext(ctx)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("Handler should not be called for invalid request")
|
|
}))
|
|
|
|
handler.ServeHTTP(recorder, request)
|
|
|
|
if recorder.Code != http.StatusBadRequest {
|
|
t.Errorf("Expected status 400, got %d", recorder.Code)
|
|
}
|
|
})
|
|
}
|