Files
goyco/internal/middleware/cors_test.go

515 lines
16 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestCORSWithAuthHeader checks that the origin allow-list is enforced the
// same way with and without an Authorization header, on both /api/ and
// non-/api/ paths. Allowed origins are echoed back with a 200; disallowed
// origins get a 403 and no Access-Control-Allow-Origin header.
func TestCORSWithAuthHeader(t *testing.T) {
	const (
		goodOrigin = "http://example.com"
		badOrigin  = "http://malicious.com"
		apiPath    = "/api/test"
		publicPath = "/public/page"
	)

	cfg := &CORSConfig{
		AllowedOrigins:   []string{"http://example.com"},
		AllowedMethods:   []string{"GET", "POST"},
		AllowedHeaders:   []string{"Content-Type", "Authorization"},
		MaxAge:           3600,
		AllowCredentials: true,
	}
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := CORSWithConfig(cfg)(okHandler)

	cases := []struct {
		name           string
		origin         string
		path           string
		hasAuth        bool
		expectedOrigin string
		expectedStatus int
	}{
		{"Allowed origin with auth on API path", goodOrigin, apiPath, true, goodOrigin, http.StatusOK},
		{"Disallowed origin with auth on API path", badOrigin, apiPath, true, "", http.StatusForbidden},
		{"Allowed origin without auth on API path", goodOrigin, apiPath, false, goodOrigin, http.StatusOK},
		{"Disallowed origin without auth on API path", badOrigin, apiPath, false, "", http.StatusForbidden},
		{"Allowed origin with auth on non-API path", goodOrigin, publicPath, true, goodOrigin, http.StatusOK},
		{"Disallowed origin with auth on non-API path", badOrigin, publicPath, true, "", http.StatusForbidden},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, c.path, nil)
			req.Header.Set("Origin", c.origin)
			if c.hasAuth {
				req.Header.Set("Authorization", "Bearer fake-token")
			}

			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)

			if rec.Code != c.expectedStatus {
				t.Errorf("Expected status %d, got %d", c.expectedStatus, rec.Code)
			}
			if got := rec.Header().Get("Access-Control-Allow-Origin"); got != c.expectedOrigin {
				t.Errorf("Expected Access-Control-Allow-Origin to be '%s', got '%s'",
					c.expectedOrigin, got)
			}
		})
	}
}
func TestCORSWithConfig_AllowedOrigin(t *testing.T) {
config := &CORSConfig{
AllowedOrigins: []string{"http://example.com"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"Content-Type"},
MaxAge: 3600,
AllowCredentials: true,
}
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Origin", "http://example.com")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
}
if w.Header().Get("Access-Control-Allow-Credentials") != "true" {
t.Errorf("Expected Access-Control-Allow-Credentials to be 'true', got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
}
}
func TestCORSWithConfig_DisallowedOrigin(t *testing.T) {
config := &CORSConfig{
AllowedOrigins: []string{"http://example.com"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"Content-Type"},
MaxAge: 3600,
AllowCredentials: false,
}
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Origin", "http://malicious.com")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Expected status 403 for disallowed origin, got %d", w.Code)
}
if w.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("Expected Access-Control-Allow-Origin to be empty for disallowed origin, got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
}
}
// TestCORSWithConfig_WildcardOrigin verifies that with AllowedOrigins set to
// "*" and credentials disabled, a request from an arbitrary origin is answered
// with a literal "*" and no Access-Control-Allow-Credentials header.
func TestCORSWithConfig_WildcardOrigin(t *testing.T) {
	config := &CORSConfig{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "POST"},
		AllowedHeaders:   []string{"Content-Type"},
		MaxAge:           3600,
		AllowCredentials: false,
	}
	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	handler := CORSWithConfig(config)(http.HandlerFunc(ok))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", "http://any-origin.com")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	header := rec.Header()
	if origin := header.Get("Access-Control-Allow-Origin"); origin != "*" {
		t.Errorf("Expected Access-Control-Allow-Origin to be '*', got '%s'", origin)
	}
	if creds := header.Get("Access-Control-Allow-Credentials"); creds != "" {
		t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", creds)
	}
}
// TestCORSWithConfig_WildcardWithCredentials pins the behavior when "*" is
// combined with AllowCredentials. The request's own origin is echoed back
// instead of "*", and Access-Control-Allow-Credentials is still not sent.
func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
	config := &CORSConfig{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "POST"},
		AllowedHeaders:   []string{"Content-Type"},
		MaxAge:           3600,
		AllowCredentials: true,
	}
	ok := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	handler := CORSWithConfig(config)(http.HandlerFunc(ok))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", "http://example.com")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	header := rec.Header()
	if origin := header.Get("Access-Control-Allow-Origin"); origin != "http://example.com" {
		t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", origin)
	}
	if creds := header.Get("Access-Control-Allow-Credentials"); creds != "" {
		t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", creds)
	}
}
func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
config := &CORSConfig{
AllowedOrigins: []string{"http://example.com"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"Content-Type"},
MaxAge: 3600,
AllowCredentials: false,
}
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("Expected Access-Control-Allow-Origin to be empty, got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
}
}
func TestCORSWithConfig_NoOriginWithWildcard(t *testing.T) {
config := &CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"Content-Type"},
MaxAge: 3600,
AllowCredentials: false,
}
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("Expected Access-Control-Allow-Origin to be empty (no origin in request), got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
}
}
// TestCORSWithConfig_PreflightRequest verifies that the middleware itself
// answers an OPTIONS request from an allowed origin; the wrapped handler must
// not run. The response must be a 200 carrying the configured Allow-Origin,
// Allow-Methods, Allow-Headers and Max-Age values.
func TestCORSWithConfig_PreflightRequest(t *testing.T) {
	config := &CORSConfig{
		AllowedOrigins:   []string{"http://example.com"},
		AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE"},
		AllowedHeaders:   []string{"Content-Type", "Authorization"},
		MaxAge:           86400,
		AllowCredentials: true,
	}
	mustNotRun := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("Next handler should not be called for OPTIONS request")
	})
	handler := CORSWithConfig(config)(mustNotRun)

	req := httptest.NewRequest(http.MethodOptions, "/api/test", nil)
	req.Header.Set("Origin", "http://example.com")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("Expected status code %d, got %d", http.StatusOK, rec.Code)
	}

	h := rec.Header()
	if v := h.Get("Access-Control-Allow-Origin"); v != "http://example.com" {
		t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", v)
	}
	if v := h.Get("Access-Control-Allow-Methods"); v != "GET, POST, PUT, DELETE" {
		t.Errorf("Expected Access-Control-Allow-Methods to be 'GET, POST, PUT, DELETE', got '%s'", v)
	}
	if v := h.Get("Access-Control-Allow-Headers"); v != "Content-Type, Authorization" {
		t.Errorf("Expected Access-Control-Allow-Headers to be 'Content-Type, Authorization', got '%s'", v)
	}
	if v := h.Get("Access-Control-Max-Age"); v != "86400" {
		t.Errorf("Expected Access-Control-Max-Age to be '86400', got '%s'", v)
	}
}
// TestCORSWithConfig_MultipleAllowedOrigins verifies that each entry of a
// multi-origin allow-list is echoed back with a 200, and an unlisted origin
// gets a 403 with no Access-Control-Allow-Origin header.
func TestCORSWithConfig_MultipleAllowedOrigins(t *testing.T) {
	config := &CORSConfig{
		AllowedOrigins:   []string{"http://example1.com", "http://example2.com", "http://example3.com"},
		AllowedMethods:   []string{"GET", "POST"},
		AllowedHeaders:   []string{"Content-Type"},
		MaxAge:           3600,
		AllowCredentials: true,
	}
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	for _, c := range []struct {
		origin     string
		wantOrigin string
		wantStatus int
	}{
		{origin: "http://example1.com", wantOrigin: "http://example1.com", wantStatus: http.StatusOK},
		{origin: "http://example2.com", wantOrigin: "http://example2.com", wantStatus: http.StatusOK},
		{origin: "http://example3.com", wantOrigin: "http://example3.com", wantStatus: http.StatusOK},
		{origin: "http://notallowed.com", wantOrigin: "", wantStatus: http.StatusForbidden},
	} {
		t.Run(c.origin, func(t *testing.T) {
			handler := CORSWithConfig(config)(next)

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.Header.Set("Origin", c.origin)
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)

			if rec.Code != c.wantStatus {
				t.Errorf("For origin '%s', expected status %d, got %d", c.origin, c.wantStatus, rec.Code)
			}
			if got := rec.Header().Get("Access-Control-Allow-Origin"); got != c.wantOrigin {
				t.Errorf("For origin '%s', expected Access-Control-Allow-Origin to be '%s', got '%s'",
					c.origin, c.wantOrigin, got)
			}
		})
	}
}
func TestCORSWithConfig_CORSHeaders(t *testing.T) {
config := &CORSConfig{
AllowedOrigins: []string{"http://example.com"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Content-Type", "Authorization", "X-Custom-Header"},
MaxAge: 7200,
AllowCredentials: true,
}
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Origin", "http://example.com")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
}
if w.Header().Get("Access-Control-Allow-Credentials") != "true" {
t.Errorf("Expected Access-Control-Allow-Credentials to be 'true', got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
}
}
// TestCORSOPTIONSRequest verifies that the environment-configured CORS
// middleware (GOYCO_ENV / CORS_ALLOWED_ORIGINS) answers an OPTIONS request
// from an allowed origin itself with a 200. The request must not be passed
// on to the wrapped handler.
func TestCORSOPTIONSRequest(t *testing.T) {
	t.Setenv("GOYCO_ENV", "development")
	t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")

	// Record the call to the wrapped handler directly, as
	// TestCORSWithConfig_PreflightRequest does. The previous version inferred
	// it from an empty response body instead. That check would also fail if
	// the middleware wrote its own preflight body, and it relied on an
	// unchecked Write in the handler.
	reached := false
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	mw := CORS(handler)

	request := httptest.NewRequest(http.MethodOptions, "/api/posts", nil)
	request.Header.Set("Origin", "http://localhost:3000")
	recorder := httptest.NewRecorder()
	mw.ServeHTTP(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", recorder.Code)
	}
	if reached {
		t.Error("OPTIONS request should not reach the handler")
	}
}
// TestCORSAllowedOrigins drives the environment-configured CORS middleware
// with authenticated GET requests to /api/auth/me. Origins listed in
// CORS_ALLOWED_ORIGINS must get a 200 with their origin echoed back. Any
// other origin must get a 403 with no Access-Control-Allow-Origin header.
func TestCORSAllowedOrigins(t *testing.T) {
	t.Setenv("GOYCO_ENV", "development")
	t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")

	mw := CORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// serve sends one authenticated request from origin through the
	// middleware and returns the recorded response.
	serve := func(origin string) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
		req.Header.Set("Origin", origin)
		req.Header.Set("Authorization", "Bearer token123")
		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, req)
		return rec
	}

	for _, origin := range []string{"http://localhost:3000", "https://yourdomain.com"} {
		rec := serve(origin)
		if rec.Code != http.StatusOK {
			t.Errorf("Origin %s should be allowed, got status %d", origin, rec.Code)
		}
		if got := rec.Header().Get("Access-Control-Allow-Origin"); got != origin {
			t.Errorf("Origin %s should be allowed, got Access-Control-Allow-Origin %s", origin, got)
		}
	}

	for _, origin := range []string{"https://malicious.com", "http://evil.com", "https://attacker.net"} {
		rec := serve(origin)
		if rec.Code != http.StatusForbidden {
			t.Errorf("Origin %s should be blocked (403), got status %d", origin, rec.Code)
		}
		if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "" {
			t.Errorf("Origin %s should be blocked, got Access-Control-Allow-Origin %s", origin, got)
		}
	}
}
// TestCORSWithoutOrigin verifies, across several allow-list shapes (wildcard,
// single origin, multiple origins), that a request carrying no Origin header
// gets no Access-Control-Allow-Origin header.
//
// The former shouldSetHeader field was false in every row. Its branch was
// therefore unreachable, and the remaining branch was equivalent to comparing
// against expectedAllowOrigin (always ""). The table now holds only the
// expected value and compares it directly.
func TestCORSWithoutOrigin(t *testing.T) {
	testCases := []struct {
		name                string
		allowedOrigins      []string
		expectedAllowOrigin string
	}{
		{
			name:                "No origin header with wildcard config",
			allowedOrigins:      []string{"*"},
			expectedAllowOrigin: "",
		},
		{
			name:                "No origin header without wildcard config",
			allowedOrigins:      []string{"http://example.com"},
			expectedAllowOrigin: "",
		},
		{
			name:                "No origin header with multiple specific origins",
			allowedOrigins:      []string{"http://example1.com", "http://example2.com"},
			expectedAllowOrigin: "",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			config := &CORSConfig{
				AllowedOrigins:   tc.allowedOrigins,
				AllowedMethods:   []string{"GET", "POST"},
				AllowedHeaders:   []string{"Content-Type"},
				MaxAge:           3600,
				AllowCredentials: false,
			}
			handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
			}))

			// Deliberately no Origin header on this request.
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			w := httptest.NewRecorder()
			handler.ServeHTTP(w, req)

			if actualOrigin := w.Header().Get("Access-Control-Allow-Origin"); actualOrigin != tc.expectedAllowOrigin {
				t.Errorf("Expected Access-Control-Allow-Origin to be '%s', got '%s'",
					tc.expectedAllowOrigin, actualOrigin)
			}
		})
	}
}