fix(middleware): validate CORS origins and reject wildcard with credentials

This commit is contained in:
2026-05-06 20:06:53 +02:00
parent 0baf7053fc
commit 89131331a6
+42 -7
View File
@@ -15,8 +15,39 @@ type CORSConfig struct {
AllowCredentials bool AllowCredentials bool
} }
func parseAllowedOriginsEnv() []string {
raw := os.Getenv("CORS_ALLOWED_ORIGINS")
if raw == "" {
return nil
}
parts := strings.Split(raw, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func validateCORSConfig(config *CORSConfig) {
if config == nil {
panic("middleware.CORS: config is nil")
}
if !config.AllowCredentials {
return
}
for _, o := range config.AllowedOrigins {
if o == "*" {
panic("middleware.CORS: AllowCredentials with wildcard AllowedOrigins (*) is not permitted")
}
}
}
func NewCORSConfig() *CORSConfig { func NewCORSConfig() *CORSConfig {
env := os.Getenv("GOYCO_ENV") env := os.Getenv("GOYCO_ENV")
originsEnv := parseAllowedOriginsEnv()
config := &CORSConfig{ config := &CORSConfig{
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
@@ -27,24 +58,27 @@ func NewCORSConfig() *CORSConfig {
switch env { switch env {
case "production", "staging": case "production", "staging":
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" { config.AllowCredentials = true
if len(originsEnv) > 0 {
config.AllowedOrigins = originsEnv
} else {
config.AllowedOrigins = []string{} config.AllowedOrigins = []string{}
} }
config.AllowCredentials = true
default: default:
config.AllowCredentials = true
if len(originsEnv) > 0 {
config.AllowedOrigins = originsEnv
} else {
config.AllowedOrigins = []string{ config.AllowedOrigins = []string{
"http://localhost:3000", "http://localhost:3000",
"http://localhost:8080", "http://localhost:8080",
"http://127.0.0.1:3000", "http://127.0.0.1:3000",
"http://127.0.0.1:8080", "http://127.0.0.1:8080",
} }
config.AllowCredentials = true }
}
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins != "" {
config.AllowedOrigins = strings.Split(origins, ",")
} }
validateCORSConfig(config)
return config return config
} }
@@ -76,6 +110,7 @@ func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, conf
} }
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler { func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
validateCORSConfig(config)
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")