diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 5420eb0..229e2e3 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -26,12 +26,7 @@ func NewCORSConfig() *CORSConfig { } switch env { - case "production": - if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" { - config.AllowedOrigins = []string{} - } - config.AllowCredentials = true - case "staging": + case "production", "staging": if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" { config.AllowedOrigins = []string{} } @@ -53,82 +48,66 @@ func NewCORSConfig() *CORSConfig { return config } +func checkOrigin(origin string, allowedOrigins []string) (allowed bool, hasWildcard bool) { + for _, allowedOrigin := range allowedOrigins { + if allowedOrigin == "*" { + hasWildcard = true + allowed = true + break + } + if allowedOrigin == origin { + allowed = true + break + } + } + return +} + +func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, config *CORSConfig) { + if hasWildcard && !config.AllowCredentials { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + if config.AllowCredentials && !hasWildcard { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } +} + func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") - if r.Method == "OPTIONS" { - if origin != "" { - allowed := false - hasWildcard := false - for _, allowedOrigin := range config.AllowedOrigins { - if allowedOrigin == "*" { - hasWildcard = true - allowed = true - break - } - if allowedOrigin == origin { - allowed = true - break - } - } - - if !allowed { - http.Error(w, "Origin not allowed", http.StatusForbidden) - return - } - - if hasWildcard && !config.AllowCredentials { - w.Header().Set("Access-Control-Allow-Origin", "*") - } else { - w.Header().Set("Access-Control-Allow-Origin", origin) - } - - w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", ")) - w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) - w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge)) - - if config.AllowCredentials && !hasWildcard { - w.Header().Set("Access-Control-Allow-Credentials", "true") - } + if origin == "" { + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return } + next.ServeHTTP(w, r) + return + } + + allowed, hasWildcard := checkOrigin(origin, config.AllowedOrigins) + + if !allowed { + http.Error(w, "Origin not allowed", http.StatusForbidden) + return + } + + if r.Method == "OPTIONS" { + setCORSHeaders(w, origin, hasWildcard, config) + + w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) + w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge)) w.WriteHeader(http.StatusOK) return } - if origin != "" { - allowed := false - hasWildcard := false - for _, allowedOrigin := range config.AllowedOrigins { - if allowedOrigin == "*" { - hasWildcard = true - allowed = true - break - } - if allowedOrigin == origin { - allowed = true - break - } - } - - if !allowed { - http.Error(w, "Origin not allowed", http.StatusForbidden) - return - } - - if hasWildcard && !config.AllowCredentials { - w.Header().Set("Access-Control-Allow-Origin", "*") - } else { - w.Header().Set("Access-Control-Allow-Origin", origin) - } - - if config.AllowCredentials && !hasWildcard { - w.Header().Set("Access-Control-Allow-Credentials", "true") - } - } - + setCORSHeaders(w, origin, hasWildcard, config) next.ServeHTTP(w, r) }) }