test: add security header preservation tests for cache
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -551,6 +553,149 @@ func TestCacheMiddlewareRequestResponseSerialization(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiddlewarePreservesSecurityHeaders(t *testing.T) {
|
||||
testCache := NewInMemoryCache()
|
||||
testConfig := &CacheConfig{
|
||||
TTL: 5 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
CacheablePaths: []string{"/test"},
|
||||
}
|
||||
|
||||
securityMiddleware := SecurityHeadersMiddleware()
|
||||
cacheMiddleware := CacheMiddleware(testCache, testConfig)
|
||||
|
||||
callCount := 0
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
handler := securityMiddleware(cacheMiddleware(testHandler))
|
||||
|
||||
request := httptest.NewRequest("GET", "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", recorder.Code)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected handler to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
securityHeaders := []string{
|
||||
"X-Content-Type-Options",
|
||||
"X-Frame-Options",
|
||||
"X-XSS-Protection",
|
||||
"Referrer-Policy",
|
||||
"Content-Security-Policy",
|
||||
"Permissions-Policy",
|
||||
}
|
||||
|
||||
for _, headerName := range securityHeaders {
|
||||
if recorder.Header().Get(headerName) == "" {
|
||||
t.Errorf("Expected security header %s to be present on first request", headerName)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
request2 := httptest.NewRequest("GET", "/test", nil)
|
||||
recorder2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder2, request2)
|
||||
|
||||
if recorder2.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", recorder2.Code)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected handler to still be called once (cached), got %d", callCount)
|
||||
}
|
||||
if recorder2.Header().Get("X-Cache") != "HIT" {
|
||||
t.Error("Expected X-Cache: HIT header")
|
||||
}
|
||||
|
||||
for _, headerName := range securityHeaders {
|
||||
if recorder2.Header().Get(headerName) == "" {
|
||||
t.Errorf("Expected security header %s to be present on cache hit", headerName)
|
||||
}
|
||||
if headerName == "Content-Security-Policy" {
|
||||
csp1 := recorder.Header().Get(headerName)
|
||||
csp2 := recorder2.Header().Get(headerName)
|
||||
if !strings.Contains(csp1, "script-src") || !strings.Contains(csp2, "script-src") {
|
||||
t.Errorf("Expected CSP to contain script-src directive on both requests")
|
||||
}
|
||||
} else {
|
||||
if recorder.Header().Get(headerName) != recorder2.Header().Get(headerName) {
|
||||
t.Errorf("Expected security header %s to match between first request and cache hit", headerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiddlewarePreservesHSTSHeader(t *testing.T) {
|
||||
testCache := NewInMemoryCache()
|
||||
testConfig := &CacheConfig{
|
||||
TTL: 5 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
CacheablePaths: []string{"/test"},
|
||||
}
|
||||
|
||||
hstsMiddleware := HSTSMiddleware()
|
||||
cacheMiddleware := CacheMiddleware(testCache, testConfig)
|
||||
|
||||
callCount := 0
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
handler := hstsMiddleware(cacheMiddleware(testHandler))
|
||||
|
||||
request := httptest.NewRequest("GET", "/test", nil)
|
||||
request.TLS = &tls.ConnectionState{}
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", recorder.Code)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected handler to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
hstsValue := recorder.Header().Get("Strict-Transport-Security")
|
||||
if hstsValue == "" {
|
||||
t.Error("Expected Strict-Transport-Security header to be present on first request")
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
request2 := httptest.NewRequest("GET", "/test", nil)
|
||||
request2.TLS = &tls.ConnectionState{}
|
||||
recorder2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder2, request2)
|
||||
|
||||
if recorder2.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", recorder2.Code)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected handler to still be called once (cached), got %d", callCount)
|
||||
}
|
||||
if recorder2.Header().Get("X-Cache") != "HIT" {
|
||||
t.Error("Expected X-Cache: HIT header")
|
||||
}
|
||||
|
||||
hstsValue2 := recorder2.Header().Get("Strict-Transport-Security")
|
||||
if hstsValue2 == "" {
|
||||
t.Error("Expected Strict-Transport-Security header to be present on cache hit")
|
||||
}
|
||||
if hstsValue != hstsValue2 {
|
||||
t.Error("Expected Strict-Transport-Security header to match between first request and cache hit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheInvalidationMiddleware(t *testing.T) {
|
||||
cache := NewInMemoryCache()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user