test(middleware): CORS wildcard+credentials panic and trimmed env origins

This commit is contained in:
2026-05-06 20:06:55 +02:00
parent 89131331a6
commit add60ad3c2
+31 -19
View File
@@ -181,7 +181,12 @@ func TestCORSWithConfig_WildcardOrigin(t *testing.T) {
} }
} }
func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) { func TestCORSWithConfig_WildcardWithCredentialsPanics(t *testing.T) {
defer func() {
if recover() == nil {
t.Fatal("expected panic for AllowCredentials with wildcard AllowedOrigins")
}
}()
config := &CORSConfig{ config := &CORSConfig{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST"}, AllowedMethods: []string{"GET", "POST"},
@@ -189,24 +194,7 @@ func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
MaxAge: 3600, MaxAge: 3600,
AllowCredentials: true, AllowCredentials: true,
} }
_ = CORSWithConfig(config)
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Origin", "http://example.com")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
}
if w.Header().Get("Access-Control-Allow-Credentials") != "" {
t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
}
} }
func TestCORSWithConfig_NoOriginHeader(t *testing.T) { func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
@@ -393,6 +381,30 @@ func TestCORSOPTIONSRequest(t *testing.T) {
} }
} }
func TestCORSAllowedOriginsTrimmedFromEnv(t *testing.T) {
t.Setenv("GOYCO_ENV", "development")
t.Setenv("CORS_ALLOWED_ORIGINS", " http://localhost:3000 , https://yourdomain.com ")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := CORS(handler)
for _, origin := range []string{"http://localhost:3000", "https://yourdomain.com"} {
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Origin", origin)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Fatalf("origin %q: status %d", origin, recorder.Code)
}
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != origin {
t.Fatalf("origin %q: Allow-Origin %q", origin, got)
}
}
}
func TestCORSAllowedOrigins(t *testing.T) { func TestCORSAllowedOrigins(t *testing.T) {
t.Setenv("GOYCO_ENV", "development") t.Setenv("GOYCO_ENV", "development")
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com") t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")