refactor(cors): deduplicate origin validation and header logic without behavior change

This commit is contained in:
2026-03-06 15:37:44 +01:00
parent 19291b7f61
commit de9b544afb

View File

@@ -26,12 +26,7 @@ func NewCORSConfig() *CORSConfig {
} }
switch env { switch env {
case "production": case "production", "staging":
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
config.AllowedOrigins = []string{}
}
config.AllowCredentials = true
case "staging":
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" { if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
config.AllowedOrigins = []string{} config.AllowedOrigins = []string{}
} }
@@ -53,82 +48,66 @@ func NewCORSConfig() *CORSConfig {
return config return config
} }
func checkOrigin(origin string, allowedOrigins []string) (allowed bool, hasWildcard bool) {
for _, allowedOrigin := range allowedOrigins {
if allowedOrigin == "*" {
hasWildcard = true
allowed = true
break
}
if allowedOrigin == origin {
allowed = true
break
}
}
return
}
func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, config *CORSConfig) {
if hasWildcard && !config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
if config.AllowCredentials && !hasWildcard {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler { func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
if origin == "" {
if r.Method == "OPTIONS" { if r.Method == "OPTIONS" {
if origin != "" { w.WriteHeader(http.StatusOK)
allowed := false return
hasWildcard := false
for _, allowedOrigin := range config.AllowedOrigins {
if allowedOrigin == "*" {
hasWildcard = true
allowed = true
break
}
if allowedOrigin == origin {
allowed = true
break
} }
next.ServeHTTP(w, r)
return
} }
allowed, hasWildcard := checkOrigin(origin, config.AllowedOrigins)
if !allowed { if !allowed {
http.Error(w, "Origin not allowed", http.StatusForbidden) http.Error(w, "Origin not allowed", http.StatusForbidden)
return return
} }
if hasWildcard && !config.AllowCredentials { if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Allow-Origin", "*") setCORSHeaders(w, origin, hasWildcard, config)
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", ")) w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge)) w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
if config.AllowCredentials && !hasWildcard {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
} }
if origin != "" { setCORSHeaders(w, origin, hasWildcard, config)
allowed := false
hasWildcard := false
for _, allowedOrigin := range config.AllowedOrigins {
if allowedOrigin == "*" {
hasWildcard = true
allowed = true
break
}
if allowedOrigin == origin {
allowed = true
break
}
}
if !allowed {
http.Error(w, "Origin not allowed", http.StatusForbidden)
return
}
if hasWildcard && !config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
if config.AllowCredentials && !hasWildcard {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }