diff --git a/internal/middleware/cache_test.go b/internal/middleware/cache_test.go index b49e096..6acd79e 100644 --- a/internal/middleware/cache_test.go +++ b/internal/middleware/cache_test.go @@ -1,9 +1,11 @@ package middleware import ( + "crypto/tls" "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" ) @@ -551,6 +553,149 @@ func TestCacheMiddlewareRequestResponseSerialization(t *testing.T) { } } +func TestCacheMiddlewarePreservesSecurityHeaders(t *testing.T) { + testCache := NewInMemoryCache() + testConfig := &CacheConfig{ + TTL: 5 * time.Minute, + MaxSize: 1000, + CacheablePaths: []string{"/test"}, + } + + securityMiddleware := SecurityHeadersMiddleware() + cacheMiddleware := CacheMiddleware(testCache, testConfig) + + callCount := 0 + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + w.Write([]byte("test response")) + }) + + handler := securityMiddleware(cacheMiddleware(testHandler)) + + request := httptest.NewRequest("GET", "/test", nil) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + + if recorder.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", recorder.Code) + } + if callCount != 1 { + t.Errorf("Expected handler to be called once, got %d", callCount) + } + + securityHeaders := []string{ + "X-Content-Type-Options", + "X-Frame-Options", + "X-XSS-Protection", + "Referrer-Policy", + "Content-Security-Policy", + "Permissions-Policy", + } + + for _, headerName := range securityHeaders { + if recorder.Header().Get(headerName) == "" { + t.Errorf("Expected security header %s to be present on first request", headerName) + } + } + + time.Sleep(50 * time.Millisecond) + + request2 := httptest.NewRequest("GET", "/test", nil) + recorder2 := httptest.NewRecorder() + handler.ServeHTTP(recorder2, request2) + + if recorder2.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", recorder2.Code) + } + if callCount != 1 { + t.Errorf("Expected handler to still be called once (cached), got %d", callCount) + } + if recorder2.Header().Get("X-Cache") != "HIT" { + t.Error("Expected X-Cache: HIT header") + } + + for _, headerName := range securityHeaders { + if recorder2.Header().Get(headerName) == "" { + t.Errorf("Expected security header %s to be present on cache hit", headerName) + } + if headerName == "Content-Security-Policy" { + csp1 := recorder.Header().Get(headerName) + csp2 := recorder2.Header().Get(headerName) + if !strings.Contains(csp1, "script-src") || !strings.Contains(csp2, "script-src") { + t.Errorf("Expected CSP to contain script-src directive on both requests") + } + } else { + if recorder.Header().Get(headerName) != recorder2.Header().Get(headerName) { + t.Errorf("Expected security header %s to match between first request and cache hit", headerName) + } + } + } +} + +func TestCacheMiddlewarePreservesHSTSHeader(t *testing.T) { + testCache := NewInMemoryCache() + testConfig := &CacheConfig{ + TTL: 5 * time.Minute, + MaxSize: 1000, + CacheablePaths: []string{"/test"}, + } + + hstsMiddleware := HSTSMiddleware() + cacheMiddleware := CacheMiddleware(testCache, testConfig) + + callCount := 0 + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + w.Write([]byte("test response")) + }) + + handler := hstsMiddleware(cacheMiddleware(testHandler)) + + request := httptest.NewRequest("GET", "/test", nil) + request.TLS = &tls.ConnectionState{} + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + + if recorder.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", recorder.Code) + } + if callCount != 1 { + t.Errorf("Expected handler to be called once, got %d", callCount) + } + + hstsValue := recorder.Header().Get("Strict-Transport-Security") + if hstsValue == "" { + t.Error("Expected Strict-Transport-Security header to be present on first request") + } + + time.Sleep(50 * time.Millisecond) + + request2 := httptest.NewRequest("GET", "/test", nil) + request2.TLS = &tls.ConnectionState{} + recorder2 := httptest.NewRecorder() + handler.ServeHTTP(recorder2, request2) + + if recorder2.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", recorder2.Code) + } + if callCount != 1 { + t.Errorf("Expected handler to still be called once (cached), got %d", callCount) + } + if recorder2.Header().Get("X-Cache") != "HIT" { + t.Error("Expected X-Cache: HIT header") + } + + hstsValue2 := recorder2.Header().Get("Strict-Transport-Security") + if hstsValue2 == "" { + t.Error("Expected Strict-Transport-Security header to be present on cache hit") + } + if hstsValue != hstsValue2 { + t.Error("Expected Strict-Transport-Security header to match between first request and cache hit") + } +} + func TestCacheInvalidationMiddleware(t *testing.T) { cache := NewInMemoryCache()