diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go index f24ab6c..2f1617b 100644 --- a/internal/middleware/cors_test.go +++ b/internal/middleware/cors_test.go @@ -181,7 +181,12 @@ func TestCORSWithConfig_WildcardOrigin(t *testing.T) { } } -func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) { +func TestCORSWithConfig_WildcardWithCredentialsPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic for AllowCredentials with wildcard AllowedOrigins") + } + }() config := &CORSConfig{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST"}, @@ -189,24 +194,7 @@ func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) { MaxAge: 3600, AllowCredentials: true, } - - handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest("GET", "/", nil) - req.Header.Set("Origin", "http://example.com") - w := httptest.NewRecorder() - - handler.ServeHTTP(w, req) - - if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" { - t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin")) - } - - if w.Header().Get("Access-Control-Allow-Credentials") != "" { - t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials")) - } + _ = CORSWithConfig(config) } func TestCORSWithConfig_NoOriginHeader(t *testing.T) { @@ -393,6 +381,30 @@ func TestCORSOPTIONSRequest(t *testing.T) { } } +func TestCORSAllowedOriginsTrimmedFromEnv(t *testing.T) { + t.Setenv("GOYCO_ENV", "development") + t.Setenv("CORS_ALLOWED_ORIGINS", " http://localhost:3000 , https://yourdomain.com ") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := CORS(handler) + + for _, origin := range []string{"http://localhost:3000", "https://yourdomain.com"} { + request := httptest.NewRequest("GET", "/api/posts", nil) + request.Header.Set("Origin", origin) + recorder := httptest.NewRecorder() + middleware.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("origin %q: status %d", origin, recorder.Code) + } + if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != origin { + t.Fatalf("origin %q: Allow-Origin %q", origin, got) + } + } +} + func TestCORSAllowedOrigins(t *testing.T) { t.Setenv("GOYCO_ENV", "development") t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")