To gitea and beyond, let's go(-yco)
This commit is contained in:
514
internal/middleware/cors_test.go
Normal file
514
internal/middleware/cors_test.go
Normal file
@@ -0,0 +1,514 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCORSWithAuthHeader(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
origin string
|
||||
path string
|
||||
hasAuth bool
|
||||
expectedOrigin string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Allowed origin with auth on API path",
|
||||
origin: "http://example.com",
|
||||
path: "/api/test",
|
||||
hasAuth: true,
|
||||
expectedOrigin: "http://example.com",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Disallowed origin with auth on API path",
|
||||
origin: "http://malicious.com",
|
||||
path: "/api/test",
|
||||
hasAuth: true,
|
||||
expectedOrigin: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Allowed origin without auth on API path",
|
||||
origin: "http://example.com",
|
||||
path: "/api/test",
|
||||
hasAuth: false,
|
||||
expectedOrigin: "http://example.com",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Disallowed origin without auth on API path",
|
||||
origin: "http://malicious.com",
|
||||
path: "/api/test",
|
||||
hasAuth: false,
|
||||
expectedOrigin: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Allowed origin with auth on non-API path",
|
||||
origin: "http://example.com",
|
||||
path: "/public/page",
|
||||
hasAuth: true,
|
||||
expectedOrigin: "http://example.com",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Disallowed origin with auth on non-API path",
|
||||
origin: "http://malicious.com",
|
||||
path: "/public/page",
|
||||
hasAuth: true,
|
||||
expectedOrigin: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", tc.path, nil)
|
||||
req.Header.Set("Origin", tc.origin)
|
||||
if tc.hasAuth {
|
||||
req.Header.Set("Authorization", "Bearer fake-token")
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tc.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, w.Code)
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != tc.expectedOrigin {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be '%s', got '%s'",
|
||||
tc.expectedOrigin, w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_AllowedOrigin(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Credentials") != "true" {
|
||||
t.Errorf("Expected Access-Control-Allow-Credentials to be 'true', got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_DisallowedOrigin(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: false,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Origin", "http://malicious.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected status 403 for disallowed origin, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be empty for disallowed origin, got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_WildcardOrigin(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: false,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Origin", "http://any-origin.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be '*', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Credentials") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Credentials") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: false,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be empty, got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_NoOriginWithWildcard(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: false,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be empty (no origin in request), got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_PreflightRequest(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
||||
MaxAge: 86400,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("Next handler should not be called for OPTIONS request")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("OPTIONS", "/api/test", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Methods") != "GET, POST, PUT, DELETE" {
|
||||
t.Errorf("Expected Access-Control-Allow-Methods to be 'GET, POST, PUT, DELETE', got '%s'", w.Header().Get("Access-Control-Allow-Methods"))
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Headers") != "Content-Type, Authorization" {
|
||||
t.Errorf("Expected Access-Control-Allow-Headers to be 'Content-Type, Authorization', got '%s'", w.Header().Get("Access-Control-Allow-Headers"))
|
||||
}
|
||||
if w.Header().Get("Access-Control-Max-Age") != "86400" {
|
||||
t.Errorf("Expected Access-Control-Max-Age to be '86400', got '%s'", w.Header().Get("Access-Control-Max-Age"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_MultipleAllowedOrigins(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"http://example1.com", "http://example2.com", "http://example3.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
origin string
|
||||
expected string
|
||||
status int
|
||||
}{
|
||||
{"http://example1.com", "http://example1.com", http.StatusOK},
|
||||
{"http://example2.com", "http://example2.com", http.StatusOK},
|
||||
{"http://example3.com", "http://example3.com", http.StatusOK},
|
||||
{"http://notallowed.com", "", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.origin, func(t *testing.T) {
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Origin", tc.origin)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tc.status {
|
||||
t.Errorf("For origin '%s', expected status %d, got %d", tc.origin, tc.status, w.Code)
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != tc.expected {
|
||||
t.Errorf("For origin '%s', expected Access-Control-Allow-Origin to be '%s', got '%s'",
|
||||
tc.origin, tc.expected, w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_CORSHeaders(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization", "X-Custom-Header"},
|
||||
MaxAge: 7200,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Credentials") != "true" {
|
||||
t.Errorf("Expected Access-Control-Allow-Credentials to be 'true', got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCORSOPTIONSRequest(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "development")
|
||||
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("should not reach handler"))
|
||||
})
|
||||
|
||||
middleware := CORS(handler)
|
||||
request := httptest.NewRequest("OPTIONS", "/api/posts", nil)
|
||||
request.Header.Set("Origin", "http://localhost:3000")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
middleware.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
if recorder.Body.String() != "" {
|
||||
t.Error("OPTIONS request should not reach the handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSAllowedOrigins(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "development")
|
||||
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := CORS(handler)
|
||||
|
||||
allowedOrigins := []string{
|
||||
"http://localhost:3000",
|
||||
"https://yourdomain.com",
|
||||
}
|
||||
|
||||
unauthorizedOrigins := []string{
|
||||
"https://malicious.com",
|
||||
"http://evil.com",
|
||||
"https://attacker.net",
|
||||
}
|
||||
|
||||
for _, origin := range allowedOrigins {
|
||||
request := httptest.NewRequest("GET", "/api/auth/me", nil)
|
||||
request.Header.Set("Origin", origin)
|
||||
request.Header.Set("Authorization", "Bearer token123")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
middleware.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Origin %s should be allowed, got status %d", origin, recorder.Code)
|
||||
}
|
||||
actualOrigin := recorder.Header().Get("Access-Control-Allow-Origin")
|
||||
if actualOrigin != origin {
|
||||
t.Errorf("Origin %s should be allowed, got Access-Control-Allow-Origin %s", origin, actualOrigin)
|
||||
}
|
||||
}
|
||||
|
||||
for _, origin := range unauthorizedOrigins {
|
||||
request := httptest.NewRequest("GET", "/api/auth/me", nil)
|
||||
request.Header.Set("Origin", origin)
|
||||
request.Header.Set("Authorization", "Bearer token123")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
middleware.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Errorf("Origin %s should be blocked (403), got status %d", origin, recorder.Code)
|
||||
}
|
||||
actualOrigin := recorder.Header().Get("Access-Control-Allow-Origin")
|
||||
if actualOrigin != "" {
|
||||
t.Errorf("Origin %s should be blocked, got Access-Control-Allow-Origin %s", origin, actualOrigin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithoutOrigin(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowedOrigins []string
|
||||
expectedAllowOrigin string
|
||||
shouldSetHeader bool
|
||||
}{
|
||||
{
|
||||
name: "No origin header with wildcard config",
|
||||
allowedOrigins: []string{"*"},
|
||||
expectedAllowOrigin: "",
|
||||
shouldSetHeader: false,
|
||||
},
|
||||
{
|
||||
name: "No origin header without wildcard config",
|
||||
allowedOrigins: []string{"http://example.com"},
|
||||
expectedAllowOrigin: "",
|
||||
shouldSetHeader: false,
|
||||
},
|
||||
{
|
||||
name: "No origin header with multiple specific origins",
|
||||
allowedOrigins: []string{"http://example1.com", "http://example2.com"},
|
||||
expectedAllowOrigin: "",
|
||||
shouldSetHeader: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: tc.allowedOrigins,
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Content-Type"},
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: false,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
actualOrigin := w.Header().Get("Access-Control-Allow-Origin")
|
||||
|
||||
if tc.shouldSetHeader {
|
||||
if actualOrigin != tc.expectedAllowOrigin {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be '%s', got '%s'",
|
||||
tc.expectedAllowOrigin, actualOrigin)
|
||||
}
|
||||
} else {
|
||||
if actualOrigin != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be empty (not set), got '%s'",
|
||||
actualOrigin)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user