test: add security header preservation tests for cache

This commit is contained in:
2025-12-26 17:33:25 +01:00
parent 77886ddef5
commit 027df4f60c

View File

@@ -1,9 +1,11 @@
package middleware package middleware
import ( import (
"crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
) )
@@ -551,6 +553,149 @@ func TestCacheMiddlewareRequestResponseSerialization(t *testing.T) {
} }
} }
func TestCacheMiddlewarePreservesSecurityHeaders(t *testing.T) {
testCache := NewInMemoryCache()
testConfig := &CacheConfig{
TTL: 5 * time.Minute,
MaxSize: 1000,
CacheablePaths: []string{"/test"},
}
securityMiddleware := SecurityHeadersMiddleware()
cacheMiddleware := CacheMiddleware(testCache, testConfig)
callCount := 0
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusOK)
w.Write([]byte("test response"))
})
handler := securityMiddleware(cacheMiddleware(testHandler))
request := httptest.NewRequest("GET", "/test", nil)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", recorder.Code)
}
if callCount != 1 {
t.Errorf("Expected handler to be called once, got %d", callCount)
}
securityHeaders := []string{
"X-Content-Type-Options",
"X-Frame-Options",
"X-XSS-Protection",
"Referrer-Policy",
"Content-Security-Policy",
"Permissions-Policy",
}
for _, headerName := range securityHeaders {
if recorder.Header().Get(headerName) == "" {
t.Errorf("Expected security header %s to be present on first request", headerName)
}
}
time.Sleep(50 * time.Millisecond)
request2 := httptest.NewRequest("GET", "/test", nil)
recorder2 := httptest.NewRecorder()
handler.ServeHTTP(recorder2, request2)
if recorder2.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", recorder2.Code)
}
if callCount != 1 {
t.Errorf("Expected handler to still be called once (cached), got %d", callCount)
}
if recorder2.Header().Get("X-Cache") != "HIT" {
t.Error("Expected X-Cache: HIT header")
}
for _, headerName := range securityHeaders {
if recorder2.Header().Get(headerName) == "" {
t.Errorf("Expected security header %s to be present on cache hit", headerName)
}
if headerName == "Content-Security-Policy" {
csp1 := recorder.Header().Get(headerName)
csp2 := recorder2.Header().Get(headerName)
if !strings.Contains(csp1, "script-src") || !strings.Contains(csp2, "script-src") {
t.Errorf("Expected CSP to contain script-src directive on both requests")
}
} else {
if recorder.Header().Get(headerName) != recorder2.Header().Get(headerName) {
t.Errorf("Expected security header %s to match between first request and cache hit", headerName)
}
}
}
}
func TestCacheMiddlewarePreservesHSTSHeader(t *testing.T) {
testCache := NewInMemoryCache()
testConfig := &CacheConfig{
TTL: 5 * time.Minute,
MaxSize: 1000,
CacheablePaths: []string{"/test"},
}
hstsMiddleware := HSTSMiddleware()
cacheMiddleware := CacheMiddleware(testCache, testConfig)
callCount := 0
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusOK)
w.Write([]byte("test response"))
})
handler := hstsMiddleware(cacheMiddleware(testHandler))
request := httptest.NewRequest("GET", "/test", nil)
request.TLS = &tls.ConnectionState{}
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", recorder.Code)
}
if callCount != 1 {
t.Errorf("Expected handler to be called once, got %d", callCount)
}
hstsValue := recorder.Header().Get("Strict-Transport-Security")
if hstsValue == "" {
t.Error("Expected Strict-Transport-Security header to be present on first request")
}
time.Sleep(50 * time.Millisecond)
request2 := httptest.NewRequest("GET", "/test", nil)
request2.TLS = &tls.ConnectionState{}
recorder2 := httptest.NewRecorder()
handler.ServeHTTP(recorder2, request2)
if recorder2.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", recorder2.Code)
}
if callCount != 1 {
t.Errorf("Expected handler to still be called once (cached), got %d", callCount)
}
if recorder2.Header().Get("X-Cache") != "HIT" {
t.Error("Expected X-Cache: HIT header")
}
hstsValue2 := recorder2.Header().Get("Strict-Transport-Security")
if hstsValue2 == "" {
t.Error("Expected Strict-Transport-Security header to be present on cache hit")
}
if hstsValue != hstsValue2 {
t.Error("Expected Strict-Transport-Security header to match between first request and cache hit")
}
}
func TestCacheInvalidationMiddleware(t *testing.T) { func TestCacheInvalidationMiddleware(t *testing.T) {
cache := NewInMemoryCache() cache := NewInMemoryCache()