diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 229e2e3..3571efe 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -15,8 +15,39 @@ type CORSConfig struct { AllowCredentials bool } +func parseAllowedOriginsEnv() []string { + raw := os.Getenv("CORS_ALLOWED_ORIGINS") + if raw == "" { + return nil + } + parts := strings.Split(raw, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out +} + +func validateCORSConfig(config *CORSConfig) { + if config == nil { + panic("middleware.CORS: config is nil") + } + if !config.AllowCredentials { + return + } + for _, o := range config.AllowedOrigins { + if o == "*" { + panic("middleware.CORS: AllowCredentials with wildcard AllowedOrigins (*) is not permitted") + } + } +} + func NewCORSConfig() *CORSConfig { env := os.Getenv("GOYCO_ENV") + originsEnv := parseAllowedOriginsEnv() config := &CORSConfig{ AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, @@ -27,24 +58,27 @@ func NewCORSConfig() *CORSConfig { switch env { case "production", "staging": - if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" { + config.AllowCredentials = true + if len(originsEnv) > 0 { + config.AllowedOrigins = originsEnv + } else { config.AllowedOrigins = []string{} } - config.AllowCredentials = true default: - config.AllowedOrigins = []string{ - "http://localhost:3000", - "http://localhost:8080", - "http://127.0.0.1:3000", - "http://127.0.0.1:8080", - } config.AllowCredentials = true + if len(originsEnv) > 0 { + config.AllowedOrigins = originsEnv + } else { + config.AllowedOrigins = []string{ + "http://localhost:3000", + "http://localhost:8080", + "http://127.0.0.1:3000", + "http://127.0.0.1:8080", + } + } } - if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins != "" { - config.AllowedOrigins = strings.Split(origins, ",") - } - + validateCORSConfig(config) return config } @@ -76,6 +110,7 @@ func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, conf } func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler { + validateCORSConfig(config) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin")