diff --git a/internal/middleware/csrf_test.go b/internal/middleware/csrf_test.go index 1b58fec..2018152 100644 --- a/internal/middleware/csrf_test.go +++ b/internal/middleware/csrf_test.go @@ -186,8 +186,9 @@ func TestCSRFMiddlewareAllowsValidToken(t *testing.T) { } } -func TestCSRFMiddlewareSkipsAPI(t *testing.T) { +func TestCSRFMiddlewareSkipsAPIWithBearerToken(t *testing.T) { request := httptest.NewRequest("POST", "/api/test", nil) + request.Header.Set("Authorization", "Bearer valid-token") recorder := httptest.NewRecorder() handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -197,7 +198,22 @@ func TestCSRFMiddlewareSkipsAPI(t *testing.T) { handler.ServeHTTP(recorder, request) if recorder.Code != http.StatusOK { - t.Errorf("API requests should skip CSRF validation, got status %d", recorder.Code) + t.Errorf("API requests with Bearer token should skip CSRF validation, got status %d", recorder.Code) + } +} + +func TestCSRFMiddlewareBlocksAPIWithoutBearerToken(t *testing.T) { + request := httptest.NewRequest("POST", "/api/test", nil) + recorder := httptest.NewRecorder() + + handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(recorder, request) + + if recorder.Code != http.StatusForbidden { + t.Errorf("API requests without Bearer token should require CSRF validation, got status %d", recorder.Code) } }