test(middleware): CORS wildcard+credentials panic and trimmed env origins
This commit is contained in:
@@ -181,7 +181,12 @@ func TestCORSWithConfig_WildcardOrigin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
|
||||
func TestCORSWithConfig_WildcardWithCredentialsPanics(t *testing.T) {
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatal("expected panic for AllowCredentials with wildcard AllowedOrigins")
|
||||
}
|
||||
}()
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
@@ -189,24 +194,7 @@ func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Credentials") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
|
||||
}
|
||||
_ = CORSWithConfig(config)
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
|
||||
@@ -393,6 +381,30 @@ func TestCORSOPTIONSRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSAllowedOriginsTrimmedFromEnv(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "development")
|
||||
t.Setenv("CORS_ALLOWED_ORIGINS", " http://localhost:3000 , https://yourdomain.com ")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := CORS(handler)
|
||||
|
||||
for _, origin := range []string{"http://localhost:3000", "https://yourdomain.com"} {
|
||||
request := httptest.NewRequest("GET", "/api/posts", nil)
|
||||
request.Header.Set("Origin", origin)
|
||||
recorder := httptest.NewRecorder()
|
||||
middleware.ServeHTTP(recorder, request)
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("origin %q: status %d", origin, recorder.Code)
|
||||
}
|
||||
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != origin {
|
||||
t.Fatalf("origin %q: Allow-Origin %q", origin, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSAllowedOrigins(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "development")
|
||||
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")
|
||||
|
||||
Reference in New Issue
Block a user