From 0cd68e847cec81b33fb7a87010b31171f92b97c2 Mon Sep 17 00:00:00 2001 From: Kharec Date: Tue, 9 Dec 2025 15:58:28 +0100 Subject: [PATCH] refactor: add a helper to centralize CSRF token retrieval --- internal/integration/csrf_integration_test.go | 104 +++++------------- 1 file changed, 29 insertions(+), 75 deletions(-) diff --git a/internal/integration/csrf_integration_test.go b/internal/integration/csrf_integration_test.go index 90deb5f..eab12f6 100644 --- a/internal/integration/csrf_integration_test.go +++ b/internal/integration/csrf_integration_test.go @@ -14,6 +14,26 @@ func TestIntegration_CSRF_Protection(t *testing.T) { ctx := setupPageHandlerTestContext(t) router := ctx.Router + getCSRFToken := func(t *testing.T, path string, cookies ...*http.Cookie) *http.Cookie { + t.Helper() + + request := httptest.NewRequest("GET", path, nil) + for _, c := range cookies { + request.AddCookie(c) + } + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, request) + + for _, cookie := range recorder.Result().Cookies() { + if cookie.Name == "csrf_token" { + return cookie + } + } + + t.Fatalf("Expected CSRF cookie to be set for %s", path) + return nil + } + t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) { requestBody := url.Values{} requestBody.Set("username", "testuser") @@ -35,30 +55,13 @@ func TestIntegration_CSRF_Protection(t *testing.T) { }) t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) { - getRequest := httptest.NewRequest("GET", "/register", nil) - getRecorder := httptest.NewRecorder() - router.ServeHTTP(getRecorder, getRequest) - - cookies := getRecorder.Result().Cookies() - var csrfCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == "csrf_token" { - csrfCookie = cookie - break - } - } - - if csrfCookie == nil { - t.Fatal("Expected CSRF cookie to be set") - } - - csrfToken := csrfCookie.Value + csrfCookie := getCSRFToken(t, "/register") requestBody := url.Values{} requestBody.Set("username", "csrf_user") requestBody.Set("email", "csrf@example.com") requestBody.Set("password", "SecurePass123!") - requestBody.Set("csrf_token", csrfToken) + requestBody.Set("csrf_token", csrfCookie.Value) request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode())) request.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -91,22 +94,7 @@ func TestIntegration_CSRF_Protection(t *testing.T) { }) t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) { - getRequest := httptest.NewRequest("GET", "/register", nil) - getRecorder := httptest.NewRecorder() - router.ServeHTTP(getRecorder, getRequest) - - cookies := getRecorder.Result().Cookies() - var csrfCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == "csrf_token" { - csrfCookie = cookie - break - } - } - - if csrfCookie == nil { - t.Fatal("Expected CSRF cookie to be set") - } + csrfCookie := getCSRFToken(t, "/register") requestBody := url.Values{} requestBody.Set("username", "mismatch_user") @@ -141,24 +129,7 @@ func TestIntegration_CSRF_Protection(t *testing.T) { }) t.Run("CSRF_Token_In_Header", func(t *testing.T) { - getRequest := httptest.NewRequest("GET", "/register", nil) - getRecorder := httptest.NewRecorder() - router.ServeHTTP(getRecorder, getRequest) - - cookies := getRecorder.Result().Cookies() - var csrfCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == "csrf_token" { - csrfCookie = cookie - break - } - } - - if csrfCookie == nil { - t.Fatal("Expected CSRF cookie to be set") - } - - csrfToken := csrfCookie.Value + csrfCookie := getCSRFToken(t, "/register") requestBody := url.Values{} requestBody.Set("username", "header_user") @@ -167,7 +138,7 @@ func TestIntegration_CSRF_Protection(t *testing.T) { request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode())) request.Header.Set("Content-Type", "application/x-www-form-urlencoded") - request.Header.Set("X-CSRF-Token", csrfToken) + request.Header.Set("X-CSRF-Token", csrfCookie.Value) request.AddCookie(csrfCookie) recorder := httptest.NewRecorder() @@ -182,35 +153,18 @@ func TestIntegration_CSRF_Protection(t *testing.T) { ctx.Suite.EmailSender.Reset() user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com") - getRequest := httptest.NewRequest("GET", "/posts/new", nil) - getRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) - getRecorder := httptest.NewRecorder() - router.ServeHTTP(getRecorder, getRequest) - - cookies := getRecorder.Result().Cookies() - var csrfCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == "csrf_token" { - csrfCookie = cookie - break - } - } - - if csrfCookie == nil { - t.Fatal("Expected CSRF cookie to be set") - } - - csrfToken := csrfCookie.Value + authCookie := &http.Cookie{Name: "auth_token", Value: user.Token} + csrfCookie := getCSRFToken(t, "/posts/new", authCookie) requestBody := url.Values{} requestBody.Set("title", "CSRF Test Post") requestBody.Set("url", "https://example.com/csrf-test") requestBody.Set("content", "Test content") - requestBody.Set("csrf_token", csrfToken) + requestBody.Set("csrf_token", csrfCookie.Value) request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode())) request.Header.Set("Content-Type", "application/x-www-form-urlencoded") - request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) + request.AddCookie(authCookie) request.AddCookie(csrfCookie) recorder := httptest.NewRecorder()