fix(middleware): validate CORS origins and reject wildcard with credentials
This commit is contained in:
@@ -15,8 +15,39 @@ type CORSConfig struct {
|
|||||||
AllowCredentials bool
|
AllowCredentials bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseAllowedOriginsEnv() []string {
|
||||||
|
raw := os.Getenv("CORS_ALLOWED_ORIGINS")
|
||||||
|
if raw == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parts := strings.Split(raw, ",")
|
||||||
|
out := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p != "" {
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCORSConfig(config *CORSConfig) {
|
||||||
|
if config == nil {
|
||||||
|
panic("middleware.CORS: config is nil")
|
||||||
|
}
|
||||||
|
if !config.AllowCredentials {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, o := range config.AllowedOrigins {
|
||||||
|
if o == "*" {
|
||||||
|
panic("middleware.CORS: AllowCredentials with wildcard AllowedOrigins (*) is not permitted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewCORSConfig() *CORSConfig {
|
func NewCORSConfig() *CORSConfig {
|
||||||
env := os.Getenv("GOYCO_ENV")
|
env := os.Getenv("GOYCO_ENV")
|
||||||
|
originsEnv := parseAllowedOriginsEnv()
|
||||||
|
|
||||||
config := &CORSConfig{
|
config := &CORSConfig{
|
||||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||||
@@ -27,24 +58,27 @@ func NewCORSConfig() *CORSConfig {
|
|||||||
|
|
||||||
switch env {
|
switch env {
|
||||||
case "production", "staging":
|
case "production", "staging":
|
||||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
|
config.AllowCredentials = true
|
||||||
|
if len(originsEnv) > 0 {
|
||||||
|
config.AllowedOrigins = originsEnv
|
||||||
|
} else {
|
||||||
config.AllowedOrigins = []string{}
|
config.AllowedOrigins = []string{}
|
||||||
}
|
}
|
||||||
config.AllowCredentials = true
|
|
||||||
default:
|
default:
|
||||||
|
config.AllowCredentials = true
|
||||||
|
if len(originsEnv) > 0 {
|
||||||
|
config.AllowedOrigins = originsEnv
|
||||||
|
} else {
|
||||||
config.AllowedOrigins = []string{
|
config.AllowedOrigins = []string{
|
||||||
"http://localhost:3000",
|
"http://localhost:3000",
|
||||||
"http://localhost:8080",
|
"http://localhost:8080",
|
||||||
"http://127.0.0.1:3000",
|
"http://127.0.0.1:3000",
|
||||||
"http://127.0.0.1:8080",
|
"http://127.0.0.1:8080",
|
||||||
}
|
}
|
||||||
config.AllowCredentials = true
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins != "" {
|
|
||||||
config.AllowedOrigins = strings.Split(origins, ",")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
validateCORSConfig(config)
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,6 +110,7 @@ func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, conf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
|
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
|
||||||
|
validateCORSConfig(config)
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
|
|||||||
Reference in New Issue
Block a user