refactor: add a helper to centralize CSRF token retrieval
This commit is contained in:
@@ -14,6 +14,26 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
|
|||||||
ctx := setupPageHandlerTestContext(t)
|
ctx := setupPageHandlerTestContext(t)
|
||||||
router := ctx.Router
|
router := ctx.Router
|
||||||
|
|
||||||
|
getCSRFToken := func(t *testing.T, path string, cookies ...*http.Cookie) *http.Cookie {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
request := httptest.NewRequest("GET", path, nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
request.AddCookie(c)
|
||||||
|
}
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
for _, cookie := range recorder.Result().Cookies() {
|
||||||
|
if cookie.Name == "csrf_token" {
|
||||||
|
return cookie
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatalf("Expected CSRF cookie to be set for %s", path)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
|
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
|
||||||
requestBody := url.Values{}
|
requestBody := url.Values{}
|
||||||
requestBody.Set("username", "testuser")
|
requestBody.Set("username", "testuser")
|
||||||
@@ -35,30 +55,13 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) {
|
t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) {
|
||||||
getRequest := httptest.NewRequest("GET", "/register", nil)
|
csrfCookie := getCSRFToken(t, "/register")
|
||||||
getRecorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(getRecorder, getRequest)
|
|
||||||
|
|
||||||
cookies := getRecorder.Result().Cookies()
|
|
||||||
var csrfCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "csrf_token" {
|
|
||||||
csrfCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if csrfCookie == nil {
|
|
||||||
t.Fatal("Expected CSRF cookie to be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
csrfToken := csrfCookie.Value
|
|
||||||
|
|
||||||
requestBody := url.Values{}
|
requestBody := url.Values{}
|
||||||
requestBody.Set("username", "csrf_user")
|
requestBody.Set("username", "csrf_user")
|
||||||
requestBody.Set("email", "csrf@example.com")
|
requestBody.Set("email", "csrf@example.com")
|
||||||
requestBody.Set("password", "SecurePass123!")
|
requestBody.Set("password", "SecurePass123!")
|
||||||
requestBody.Set("csrf_token", csrfToken)
|
requestBody.Set("csrf_token", csrfCookie.Value)
|
||||||
|
|
||||||
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
|
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
|
||||||
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -91,22 +94,7 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) {
|
t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) {
|
||||||
getRequest := httptest.NewRequest("GET", "/register", nil)
|
csrfCookie := getCSRFToken(t, "/register")
|
||||||
getRecorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(getRecorder, getRequest)
|
|
||||||
|
|
||||||
cookies := getRecorder.Result().Cookies()
|
|
||||||
var csrfCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "csrf_token" {
|
|
||||||
csrfCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if csrfCookie == nil {
|
|
||||||
t.Fatal("Expected CSRF cookie to be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
requestBody := url.Values{}
|
requestBody := url.Values{}
|
||||||
requestBody.Set("username", "mismatch_user")
|
requestBody.Set("username", "mismatch_user")
|
||||||
@@ -141,24 +129,7 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("CSRF_Token_In_Header", func(t *testing.T) {
|
t.Run("CSRF_Token_In_Header", func(t *testing.T) {
|
||||||
getRequest := httptest.NewRequest("GET", "/register", nil)
|
csrfCookie := getCSRFToken(t, "/register")
|
||||||
getRecorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(getRecorder, getRequest)
|
|
||||||
|
|
||||||
cookies := getRecorder.Result().Cookies()
|
|
||||||
var csrfCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "csrf_token" {
|
|
||||||
csrfCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if csrfCookie == nil {
|
|
||||||
t.Fatal("Expected CSRF cookie to be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
csrfToken := csrfCookie.Value
|
|
||||||
|
|
||||||
requestBody := url.Values{}
|
requestBody := url.Values{}
|
||||||
requestBody.Set("username", "header_user")
|
requestBody.Set("username", "header_user")
|
||||||
@@ -167,7 +138,7 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
|
|||||||
|
|
||||||
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
|
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
|
||||||
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
request.Header.Set("X-CSRF-Token", csrfToken)
|
request.Header.Set("X-CSRF-Token", csrfCookie.Value)
|
||||||
request.AddCookie(csrfCookie)
|
request.AddCookie(csrfCookie)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -182,35 +153,18 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
|
|||||||
ctx.Suite.EmailSender.Reset()
|
ctx.Suite.EmailSender.Reset()
|
||||||
user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com")
|
user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com")
|
||||||
|
|
||||||
getRequest := httptest.NewRequest("GET", "/posts/new", nil)
|
authCookie := &http.Cookie{Name: "auth_token", Value: user.Token}
|
||||||
getRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
|
csrfCookie := getCSRFToken(t, "/posts/new", authCookie)
|
||||||
getRecorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(getRecorder, getRequest)
|
|
||||||
|
|
||||||
cookies := getRecorder.Result().Cookies()
|
|
||||||
var csrfCookie *http.Cookie
|
|
||||||
for _, cookie := range cookies {
|
|
||||||
if cookie.Name == "csrf_token" {
|
|
||||||
csrfCookie = cookie
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if csrfCookie == nil {
|
|
||||||
t.Fatal("Expected CSRF cookie to be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
csrfToken := csrfCookie.Value
|
|
||||||
|
|
||||||
requestBody := url.Values{}
|
requestBody := url.Values{}
|
||||||
requestBody.Set("title", "CSRF Test Post")
|
requestBody.Set("title", "CSRF Test Post")
|
||||||
requestBody.Set("url", "https://example.com/csrf-test")
|
requestBody.Set("url", "https://example.com/csrf-test")
|
||||||
requestBody.Set("content", "Test content")
|
requestBody.Set("content", "Test content")
|
||||||
requestBody.Set("csrf_token", csrfToken)
|
requestBody.Set("csrf_token", csrfCookie.Value)
|
||||||
|
|
||||||
request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode()))
|
request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode()))
|
||||||
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
|
request.AddCookie(authCookie)
|
||||||
request.AddCookie(csrfCookie)
|
request.AddCookie(csrfCookie)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user