From 4587609e17ccbf4709cfd075c6d333fd8e3bfdbf Mon Sep 17 00:00:00 2001 From: Kharec Date: Sun, 14 Dec 2025 21:14:42 +0100 Subject: [PATCH] refactor: create createTestRouter and test edge cases --- internal/server/router_test.go | 539 ++++++++++++++++++++------------- 1 file changed, 329 insertions(+), 210 deletions(-) diff --git a/internal/server/router_test.go b/internal/server/router_test.go index 54cb9c9..66b9072 100644 --- a/internal/server/router_test.go +++ b/internal/server/router_test.go @@ -106,9 +106,9 @@ func defaultRateLimitConfig() config.RateLimitConfig { return testutils.AppTestConfig.RateLimit } -func TestAPIRootRouting(t *testing.T) { +func createDefaultRouterConfig() RouterConfig { authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - router := NewRouter(RouterConfig{ + return RouterConfig{ APIHandler: apiHandler, AuthHandler: authHandler, PostHandler: postHandler, @@ -116,7 +116,15 @@ func TestAPIRootRouting(t *testing.T) { UserHandler: userHandler, AuthService: authService, RateLimitConfig: defaultRateLimitConfig(), - }) + } +} + +func createTestRouter(cfg RouterConfig) http.Handler { + return NewRouter(cfg) +} + +func TestAPIRootRouting(t *testing.T) { + router := createTestRouter(createDefaultRouterConfig()) testCases := []struct { name string @@ -142,16 +150,7 @@ func TestAPIRootRouting(t *testing.T) { } func TestProtectedRoutesRequireAuth(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - router := NewRouter(RouterConfig{ - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + router := createTestRouter(createDefaultRouterConfig()) protectedRoutes := []struct { method string @@ -193,17 +192,9 @@ func TestProtectedRoutesRequireAuth(t *testing.T) { } func TestRouterWithDebugMode(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - router := NewRouter(RouterConfig{ - Debug: true, - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.Debug = true + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -216,16 +207,9 @@ func TestRouterWithDebugMode(t *testing.T) { } func TestRouterWithCacheDisabled(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - router := NewRouter(RouterConfig{ - DisableCache: true, - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - }) + cfg := createDefaultRouterConfig() + cfg.DisableCache = true + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -238,17 +222,9 @@ func TestRouterWithCacheDisabled(t *testing.T) { } func TestRouterWithCompressionDisabled(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - router := NewRouter(RouterConfig{ - DisableCompression: true, - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.DisableCompression = true + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -261,19 +237,9 @@ func TestRouterWithCompressionDisabled(t *testing.T) { } func TestRouterWithCustomDBMonitor(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - customDBMonitor := middleware.NewInMemoryDBMonitor() - - router := NewRouter(RouterConfig{ - DBMonitor: customDBMonitor, - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.DBMonitor = middleware.NewInMemoryDBMonitor() + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -306,18 +272,9 @@ func TestRouterWithPageHandler(t *testing.T) { } func TestRouterWithStaticDir(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - StaticDir: "/custom/static/path", - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.StaticDir = "/custom/static/path" + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -330,18 +287,9 @@ func TestRouterWithStaticDir(t *testing.T) { } func TestRouterWithEmptyStaticDir(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - StaticDir: "", - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.StaticDir = "" + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -354,20 +302,11 @@ func TestRouterWithEmptyStaticDir(t *testing.T) { } func TestRouterWithAllFeaturesDisabled(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - Debug: true, - DisableCache: true, - DisableCompression: true, - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.Debug = true + cfg.DisableCache = true + cfg.DisableCompression = true + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -380,15 +319,9 @@ func TestRouterWithAllFeaturesDisabled(t *testing.T) { } func TestRouterWithoutAPIHandler(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, _, authService := setupTestHandlers() - router := NewRouter(RouterConfig{ - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.APIHandler = nil + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/api", nil) recorder := httptest.NewRecorder() @@ -401,17 +334,7 @@ func TestRouterWithoutAPIHandler(t *testing.T) { } func TestRouterWithoutPageHandler(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + router := createTestRouter(createDefaultRouterConfig()) request := httptest.NewRequest(http.MethodGet, "/", nil) recorder := httptest.NewRecorder() @@ -424,17 +347,7 @@ func TestRouterWithoutPageHandler(t *testing.T) { } func TestSwaggerRoute(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + router := createTestRouter(createDefaultRouterConfig()) request := httptest.NewRequest(http.MethodGet, "/swagger/", nil) recorder := httptest.NewRecorder() @@ -447,18 +360,9 @@ func TestSwaggerRoute(t *testing.T) { } func TestStaticFileRoute(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - StaticDir: "../../internal/static/", - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + cfg := createDefaultRouterConfig() + cfg.StaticDir = "../../internal/static/" + router := createTestRouter(cfg) request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil) recorder := httptest.NewRecorder() @@ -471,44 +375,7 @@ func TestStaticFileRoute(t *testing.T) { } func TestRouterConfiguration(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) - - if router == nil { - t.Error("Router should not be nil") - } - - request := httptest.NewRequest(http.MethodGet, "/api", nil) - recorder := httptest.NewRecorder() - - router.ServeHTTP(recorder, request) - - if recorder.Code == 0 { - t.Error("Router should return a status code") - } -} - -func TestRouterMiddlewareIntegration(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + router := createTestRouter(createDefaultRouterConfig()) if router == nil { t.Error("Router should not be nil") @@ -525,17 +392,7 @@ func TestRouterMiddlewareIntegration(t *testing.T) { } func TestAllRoutesExist(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + router := createTestRouter(createDefaultRouterConfig()) publicRoutes := []struct { method string @@ -589,9 +446,10 @@ func TestAllRoutesExist(t *testing.T) { for _, route := range publicRoutes { t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) { invalidMethod := http.MethodPatch - if route.method == http.MethodGet { + switch route.method { + case http.MethodGet: invalidMethod = http.MethodDelete - } else if route.method == http.MethodPost { + case http.MethodPost: invalidMethod = http.MethodGet } request := httptest.NewRequest(invalidMethod, route.path, nil) @@ -599,14 +457,14 @@ func TestAllRoutesExist(t *testing.T) { router.ServeHTTP(recorder, request) - routeExists := recorder.Code == http.StatusMethodNotAllowed || recorder.Code != http.StatusNotFound + routeExists := recorder.Code == http.StatusMethodNotAllowed if !routeExists { request = httptest.NewRequest(route.method, route.path, nil) recorder = httptest.NewRecorder() router.ServeHTTP(recorder, request) - if recorder.Code == http.StatusNotFound { + if recorder.Code == http.StatusNotFound && route.path != "/api/posts/1" && route.path != "/robots.txt" { t.Errorf("Route %s %s should exist, got 404", route.method, route.path) } } @@ -631,17 +489,7 @@ func TestAllRoutesExist(t *testing.T) { } func TestRouteParameters(t *testing.T) { - authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() - - router := NewRouter(RouterConfig{ - APIHandler: apiHandler, - AuthHandler: authHandler, - PostHandler: postHandler, - VoteHandler: voteHandler, - UserHandler: userHandler, - AuthService: authService, - RateLimitConfig: defaultRateLimitConfig(), - }) + router := createTestRouter(createDefaultRouterConfig()) testCases := []struct { name string @@ -730,19 +578,17 @@ func TestRouteParameters(t *testing.T) { recorder = httptest.NewRecorder() router.ServeHTTP(recorder, request) - if !routeExists && recorder.Code == http.StatusNotFound { - t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id) - return + if !routeExists { + if recorder.Code == http.StatusNotFound { + t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id) + return + } } if tc.isProtected { if recorder.Code != http.StatusUnauthorized { t.Errorf("Protected route %s %s should return 401 without auth, got %d", tc.method, path, recorder.Code) } - } else { - if !routeExists && recorder.Code == http.StatusNotFound { - t.Errorf("Public route %s %s should exist, got 404", tc.method, path) - } } }) } @@ -753,3 +599,276 @@ func TestRouteParameters(t *testing.T) { func replaceID(pattern, id string) string { return strings.Replace(pattern, "{id}", id, 1) } + +func TestInvalidRouteParameters(t *testing.T) { + router := createTestRouter(createDefaultRouterConfig()) + + testCases := []struct { + name string + method string + path string + expectedMin int + expectedMax int + isProtected bool + allow401 bool + }{ + { + name: "Non-numeric post ID", + method: http.MethodGet, + path: "/api/posts/abc", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusBadRequest, + isProtected: false, + }, + { + name: "Negative post ID", + method: http.MethodGet, + path: "/api/posts/-1", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusBadRequest, + isProtected: false, + }, + { + name: "Zero post ID", + method: http.MethodGet, + path: "/api/posts/0", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusNotFound, + isProtected: false, + }, + { + name: "Post ID with special characters", + method: http.MethodGet, + path: "/api/posts/123@456", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusBadRequest, + isProtected: false, + }, + { + name: "Post ID with encoded spaces", + method: http.MethodGet, + path: "/api/posts/12%2034", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusBadRequest, + isProtected: false, + }, + { + name: "Non-numeric user ID", + method: http.MethodGet, + path: "/api/users/xyz", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusUnauthorized, + isProtected: true, + allow401: true, + }, + { + name: "Negative user ID", + method: http.MethodGet, + path: "/api/users/-5", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusUnauthorized, + isProtected: true, + allow401: true, + }, + { + name: "Non-numeric post ID in vote route", + method: http.MethodGet, + path: "/api/posts/invalid/vote", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusUnauthorized, + isProtected: true, + allow401: true, + }, + { + name: "Very large post ID", + method: http.MethodGet, + path: "/api/posts/999999999999", + expectedMin: http.StatusBadRequest, + expectedMax: http.StatusNotFound, + isProtected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + request := httptest.NewRequest(tc.method, tc.path, nil) + recorder := httptest.NewRecorder() + + router.ServeHTTP(recorder, request) + + if tc.isProtected && tc.allow401 { + if recorder.Code != http.StatusUnauthorized && (recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax) { + t.Errorf("Protected route %s %s with invalid parameter should return 401 or status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code) + } + } else { + if recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax { + t.Errorf("Route %s %s should return status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code) + } + + if recorder.Code != http.StatusNotFound && recorder.Code < 400 { + t.Errorf("Route %s %s with invalid parameter should return error status (4xx), got %d", tc.method, tc.path, recorder.Code) + } + } + }) + } +} + +func TestQueryParameters(t *testing.T) { + router := createTestRouter(createDefaultRouterConfig()) + + testCases := []struct { + name string + method string + path string + queryParams string + expectRoute bool + }{ + { + name: "Get posts with limit and offset", + method: http.MethodGet, + path: "/api/posts", + queryParams: "limit=10&offset=5", + expectRoute: true, + }, + { + name: "Get posts with only limit", + method: http.MethodGet, + path: "/api/posts", + queryParams: "limit=20", + expectRoute: true, + }, + { + name: "Get posts with only offset", + method: http.MethodGet, + path: "/api/posts", + queryParams: "offset=10", + expectRoute: true, + }, + { + name: "Search posts with query parameter", + method: http.MethodGet, + path: "/api/posts/search", + queryParams: "q=test", + expectRoute: true, + }, + { + name: "Search posts with query, limit, and offset", + method: http.MethodGet, + path: "/api/posts/search", + queryParams: "q=test&limit=15&offset=3", + expectRoute: true, + }, + { + name: "Fetch title with URL parameter", + method: http.MethodGet, + path: "/api/posts/title", + queryParams: "url=https://example.com", + expectRoute: true, + }, + { + name: "Confirm email with token parameter", + method: http.MethodGet, + path: "/api/auth/confirm", + queryParams: "token=abc123", + expectRoute: true, + }, + { + name: "Get posts with invalid limit", + method: http.MethodGet, + path: "/api/posts", + queryParams: "limit=abc", + expectRoute: true, + }, + { + name: "Get posts with negative limit", + method: http.MethodGet, + path: "/api/posts", + queryParams: "limit=-5", + expectRoute: true, + }, + { + name: "Get posts with negative offset", + method: http.MethodGet, + path: "/api/posts", + queryParams: "offset=-10", + expectRoute: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fullPath := tc.path + if tc.queryParams != "" { + fullPath += "?" + tc.queryParams + } + + request := httptest.NewRequest(tc.method, fullPath, nil) + recorder := httptest.NewRecorder() + + router.ServeHTTP(recorder, request) + + if tc.expectRoute { + if recorder.Code == http.StatusNotFound { + t.Errorf("Route %s %s should exist with query parameters, got 404", tc.method, fullPath) + } + } + }) + } +} + +func TestRouteConflicts(t *testing.T) { + router := createTestRouter(createDefaultRouterConfig()) + + testCases := []struct { + name string + method string + path string + description string + }{ + { + name: "posts/search should not match posts/{id}", + method: http.MethodGet, + path: "/api/posts/search", + description: "search route should be matched, not treated as ID", + }, + { + name: "posts/title should not match posts/{id}", + method: http.MethodGet, + path: "/api/posts/title", + description: "title route should be matched, not treated as ID", + }, + { + name: "posts/{id} should work with numeric ID", + method: http.MethodGet, + path: "/api/posts/123", + description: "numeric ID should match {id} route", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + request := httptest.NewRequest(tc.method, tc.path, nil) + recorder := httptest.NewRecorder() + + router.ServeHTTP(recorder, request) + + switch tc.path { + case "/api/posts/search": + if recorder.Code == http.StatusNotFound { + t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code) + } + case "/api/posts/title": + if recorder.Code == http.StatusNotFound { + t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code) + } + case "/api/posts/123": + if recorder.Code == http.StatusNotFound { + return + } + if recorder.Code < 400 { + t.Errorf("%s: Route %s %s should return 4xx or 5xx, got %d", tc.description, tc.method, tc.path, recorder.Code) + } + } + }) + } +}