Compare commits

..

29 Commits

Author SHA1 Message Date
Kharec f7d43def1c chore: ignore SECURITY_AUDIT.md and stop tracking it 2026-05-06 20:14:33 +02:00
Kharec d891b33b57 docs(SECURITY_AUDIT): mark Phase 2–4 remediation complete 2026-05-06 20:13:56 +02:00
Kharec 60daeddbe4 docs: proxy HSTS trust, middleware order, and Swagger gating 2026-05-06 20:13:56 +02:00
Kharec 537a7e3759 docs(.env.example): document SWAGGER_ENABLED for production Swagger 2026-05-06 20:13:56 +02:00
Kharec 194884293f test(e2e): align security header checks with CSP-only XSS defense 2026-05-06 20:13:56 +02:00
Kharec 0fbb6f4a88 test(integration): drop deprecated X-XSS-Protection expectation 2026-05-06 20:13:56 +02:00
Kharec b3f6f5b15e test(handlers): RequireAuth distinguishes missing context from user id zero 2026-05-06 20:13:56 +02:00
Kharec 2ede636bd6 test(server): Swagger hidden in production unless SWAGGER_ENABLED 2026-05-06 20:13:56 +02:00
Kharec 7c525e71cb test(middleware): encoded SQL query triggers suspicious activity log 2026-05-06 20:13:56 +02:00
Kharec 620798577e test(middleware): cache LRU, SHA-256 keys, prefix invalidation 2026-05-06 20:13:56 +02:00
Kharec b41d3bb20c fix(server): gate Swagger by env and pass cache invalidation prefixes 2026-05-06 20:13:56 +02:00
Kharec abaf46e624 test(middleware): CSP config and removed XSS auditor header 2026-05-06 20:13:56 +02:00
Kharec 61875201f9 fix(middleware): configurable Swagger CSP, log CSP nonce errors, drop X-XSS-Protection 2026-05-06 20:13:56 +02:00
Kharec d668567dc5 test(middleware): GetUserIDFromContext returns nil or pointer 2026-05-06 20:13:56 +02:00
Kharec 102f1d8400 fix(middleware): decode URL before suspicious SQL/XSS probes 2026-05-06 20:13:56 +02:00
Kharec 98985db537 fix(middleware): rate-limit key uses optional user ID pointer 2026-05-06 20:13:56 +02:00
Kharec be64e7c8d2 fix(middleware): SHA-256 keys, LRU cache, and prefix-scoped invalidation 2026-05-06 20:13:56 +02:00
Kharec 1aa256c6a8 fix(handlers): RequireAuth and VoteContext use optional user ID pointer 2026-05-06 20:07:47 +02:00
Kharec dccf85e038 fix(middleware): return *uint from GetUserIDFromContext for nil when unauthenticated 2026-05-06 20:07:41 +02:00
Kharec 4e188eb8d5 test(middleware): expect CSRF cookie readable by script for header submit 2026-05-06 20:07:35 +02:00
Kharec 2adf72c138 fix(middleware): set CSRF cookie HttpOnly false for double-submit from JS 2026-05-06 20:07:00 +02:00
Kharec add60ad3c2 test(middleware): CORS wildcard+credentials panic and trimmed env origins 2026-05-06 20:06:55 +02:00
Kharec 89131331a6 fix(middleware): validate CORS origins and reject wildcard with credentials 2026-05-06 20:06:53 +02:00
Kharec 0baf7053fc test(middleware): lock rapid-request tracker reset in TestIsRapidRequest 2026-05-06 16:47:46 +02:00
Kharec 5d145613d2 fix(middleware): add mutex for rapid-request counter 2026-05-06 16:47:35 +02:00
Kharec 12db6409ce test: cover CSRF skip behavior for Bearer vs cookie auth 2026-04-23 13:34:51 +02:00
Kharec 5fc208c9da fix: only skip CSRF for /api/ routes with Bearer tokens 2026-04-23 13:34:43 +02:00
Kharec ab17ff8b79 test: verify DecompressionMiddleware enforces size limit 2026-04-23 13:26:15 +02:00
Kharec 8990f5afb7 fix: cap decompressed request body side to prevent DoS 2026-04-23 13:26:03 +02:00
24 changed files with 668 additions and 207 deletions
+3
View File
@@ -48,6 +48,9 @@ RATE_LIMIT_TRUST_PROXY=false
# Set to: development, staging, or production
GOYCO_ENV=development
# When GOYCO_ENV=production, set to true only if you intentionally want Swagger UI mounted (default: omitted/false hides it).
SWAGGER_ENABLED=false
# CORS Configuration (optional, comma-separated)
# Example: CORS_ALLOWED_ORIGINS=https://example.com,https://www.example.com
CORS_ALLOWED_ORIGINS=
+3
View File
@@ -18,6 +18,9 @@ go.work.sum
# env file
.env
# local security audit notes (not tracked)
SECURITY_AUDIT.md
# binaries
bin/goyco
+10 -2
View File
@@ -171,13 +171,21 @@ server {
}
```
### Security headers and middleware ordering
When `RATE_LIMIT_TRUST_PROXY` is `true`, the application trusts `X-Forwarded-Proto` (among other forwarded headers) for HSTS and client IP derivation. Configure your reverse proxy to set trustworthy values and strip or overwrite any client-supplied forwarding headers before they reach the app.
Response caching uses `DecompressionMiddleware` before `DefaultRequestSizeLimitMiddleware`. Decompression is additionally capped internally so gzipped bodies cannot expand without limit before other limits apply.
In production (`GOYCO_ENV=production`), Swagger UI is not registered unless `SWAGGER_ENABLED=true`.
## API Documentation
The API is fully documented with Swagger.
The API is fully documented with Swagger (enabled for non-production environments, or when `SWAGGER_ENABLED=true`).
Once running, visit:
- **Swagger UI**: `https://goyco.example.com/swagger/index.html`
- **Swagger UI**: `https://goyco.example.com/swagger/index.html` (not served in production unless explicitly enabled via `SWAGGER_ENABLED`)
You can also use `curl` to get the API info, health check and even metrics:
-2
View File
@@ -547,7 +547,6 @@ func TestE2E_SecurityHeadersEnhanced(t *testing.T) {
expectedHeaders := map[string]string{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Referrer-Policy": "strict-origin-when-cross-origin",
}
@@ -707,7 +706,6 @@ func TestE2E_SecurityHeaderCombinations(t *testing.T) {
requiredHeaders := []string{
"X-Content-Type-Options",
"X-Frame-Options",
"X-XSS-Protection",
"Referrer-Policy",
"Content-Security-Policy",
}
+8 -4
View File
@@ -200,17 +200,21 @@ func ParseUintParam(w http.ResponseWriter, r *http.Request, paramName, entityNam
}
func RequireAuth(w http.ResponseWriter, r *http.Request) (uint, bool) {
userID := middleware.GetUserIDFromContext(r.Context())
if userID == 0 {
userPtr := middleware.GetUserIDFromContext(r.Context())
if userPtr == nil {
SendErrorResponse(w, "Authentication required", http.StatusUnauthorized)
return 0, false
}
return userID, true
return *userPtr, true
}
func NewVoteContext(r *http.Request) services.VoteContext {
var uid uint
if userPtr := middleware.GetUserIDFromContext(r.Context()); userPtr != nil {
uid = *userPtr
}
return services.VoteContext{
UserID: middleware.GetUserIDFromContext(r.Context()),
UserID: uid,
IPAddress: GetClientIP(r),
UserAgent: r.UserAgent(),
}
+19 -8
View File
@@ -569,7 +569,8 @@ func TestParseUintParam(t *testing.T) {
func TestRequireAuth(t *testing.T) {
tests := []struct {
name string
userID uint
setUserKey bool
userIDValue uint
expectedID uint
expectedOK bool
expectedStatus int
@@ -577,25 +578,32 @@ func TestRequireAuth(t *testing.T) {
}{
{
name: "authenticated user",
userID: 123,
setUserKey: true,
userIDValue: 123,
expectedID: 123,
expectedOK: true,
expectedStatus: 0,
},
{
name: "unauthenticated user (no userID)",
userID: 0,
name: "unauthenticated user (missing context)",
setUserKey: false,
expectedID: 0,
expectedOK: false,
expectedStatus: http.StatusUnauthorized,
expectedError: "Authentication required",
},
{
name: "authenticated user with id zero",
setUserKey: true,
userIDValue: 0,
expectedID: 0,
expectedOK: true,
},
{
name: "authenticated user with large ID",
userID: 4294967295,
setUserKey: true,
userIDValue: 4294967295,
expectedID: 4294967295,
expectedOK: true,
expectedStatus: 0,
},
}
@@ -604,7 +612,10 @@ func TestRequireAuth(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(r.Context(), middleware.UserIDKey, tt.userID)
ctx := context.Background()
if tt.setUserKey {
ctx = context.WithValue(ctx, middleware.UserIDKey, tt.userIDValue)
}
r = r.WithContext(ctx)
userID, ok := RequireAuth(w, r)
@@ -22,7 +22,6 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, request, "X-Frame-Options")
assertHeader(t, request, "X-XSS-Protection")
})
t.Run("CORS_Headers_Present", func(t *testing.T) {
+4 -3
View File
@@ -73,9 +73,10 @@ func NewAuth(verifier TokenVerifier) func(http.Handler) http.Handler {
}
}
func GetUserIDFromContext(ctx context.Context) uint {
func GetUserIDFromContext(ctx context.Context) *uint {
if userID, ok := ctx.Value(UserIDKey).(uint); ok {
return userID
u := userID
return &u
}
return 0
return nil
}
+14 -8
View File
@@ -28,8 +28,8 @@ func TestNewAuthWithoutAuthorization(t *testing.T) {
middleware := NewAuth(verifier)
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
if id := GetUserIDFromContext(r.Context()); id != 0 {
t.Fatalf("unexpected user id %d", id)
if id := GetUserIDFromContext(r.Context()); id != nil {
t.Fatalf("unexpected user id %v", id)
}
}))
@@ -54,8 +54,13 @@ func TestNewAuthValidToken(t *testing.T) {
handlerCalled := false
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
if id := GetUserIDFromContext(r.Context()); id != 99 {
t.Fatalf("expected user id 99, got %d", id)
id := GetUserIDFromContext(r.Context())
if id == nil || *id != 99 {
v := uint(0)
if id != nil {
v = *id
}
t.Fatalf("expected user id 99, got %d", v)
}
}))
@@ -131,11 +136,12 @@ func TestNewAuthVerifierError(t *testing.T) {
func TestGetUserIDFromContext(t *testing.T) {
ctx := context.WithValue(context.Background(), UserIDKey, uint(55))
if id := GetUserIDFromContext(ctx); id != 55 {
t.Fatalf("expected id 55, got %d", id)
id := GetUserIDFromContext(ctx)
if id == nil || *id != 55 {
t.Fatalf("expected id 55, got %v", id)
}
if id := GetUserIDFromContext(context.Background()); id != 0 {
t.Fatalf("expected zero when id missing, got %d", id)
if ptr := GetUserIDFromContext(context.Background()); ptr != nil {
t.Fatalf("expected nil when id missing, got %v", ptr)
}
}
+177 -16
View File
@@ -2,8 +2,11 @@ package middleware
import (
"bytes"
"crypto/md5"
"container/list"
"crypto/sha256"
"encoding/hex"
"fmt"
"log"
"net/http"
"strings"
"sync"
@@ -24,47 +27,192 @@ type Cache interface {
Clear() error
}
func applyCacheMaxSize(cache Cache, max int) {
if im, ok := cache.(*InMemoryCache); ok {
im.SetMaxEntries(max)
}
}
func registerIndexedCacheKey(cache Cache, cacheKey string, path string, cacheablePrefixes []string) {
if im, ok := cache.(*InMemoryCache); ok {
im.RegisterKeyForPath(cacheKey, path, cacheablePrefixes)
}
}
func invalidateCacheForMutation(cache Cache, mutationPath string, cacheablePrefixes []string) {
if im, ok := cache.(*InMemoryCache); ok {
im.InvalidateForMutationPath(mutationPath, cacheablePrefixes)
return
}
_ = cache.Clear()
}
type InMemoryCache struct {
mu sync.RWMutex
mu sync.Mutex
data map[string]*CacheEntry
maxSize int
ll *list.List
lruEl map[string]*list.Element
prefixKeys map[string]map[string]struct{}
keyPrefixes map[string]map[string]struct{}
}
func NewInMemoryCache() *InMemoryCache {
return &InMemoryCache{
data: make(map[string]*CacheEntry),
maxSize: 1000,
ll: list.New(),
lruEl: make(map[string]*list.Element),
prefixKeys: make(map[string]map[string]struct{}),
keyPrefixes: make(map[string]map[string]struct{}),
}
}
func (cache *InMemoryCache) SetMaxEntries(n int) {
cache.mu.Lock()
defer cache.mu.Unlock()
cache.maxSize = n
for n > 0 && len(cache.data) > n {
cache.evictOldestLocked()
}
}
func (cache *InMemoryCache) RegisterKeyForPath(cacheKey string, path string, cacheablePrefixes []string) {
cache.mu.Lock()
defer cache.mu.Unlock()
for _, p := range matchingCachePrefixes(path, cacheablePrefixes) {
if cache.prefixKeys[p] == nil {
cache.prefixKeys[p] = make(map[string]struct{})
}
cache.prefixKeys[p][cacheKey] = struct{}{}
if cache.keyPrefixes[cacheKey] == nil {
cache.keyPrefixes[cacheKey] = make(map[string]struct{})
}
cache.keyPrefixes[cacheKey][p] = struct{}{}
}
}
func (cache *InMemoryCache) InvalidateForMutationPath(mutationPath string, cacheablePrefixes []string) {
cache.mu.Lock()
defer cache.mu.Unlock()
seen := make(map[string]struct{})
var stale []string
for _, prefix := range cacheablePrefixes {
if !strings.HasPrefix(mutationPath, prefix) {
continue
}
for key := range cache.prefixKeys[prefix] {
if _, dup := seen[key]; dup {
continue
}
seen[key] = struct{}{}
stale = append(stale, key)
}
}
for _, key := range stale {
cache.removeKeyLocked(key)
}
}
func matchingCachePrefixes(path string, cacheablePrefixes []string) []string {
var out []string
for _, p := range cacheablePrefixes {
if strings.HasPrefix(path, p) {
out = append(out, p)
}
}
return out
}
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
cache.mu.RLock()
entry, exists := cache.data[key]
cache.mu.RUnlock()
cache.mu.Lock()
defer cache.mu.Unlock()
entry, exists := cache.data[key]
if !exists {
return nil, fmt.Errorf("key not found")
}
if time.Since(entry.Timestamp) > entry.TTL {
cache.mu.Lock()
delete(cache.data, key)
cache.mu.Unlock()
cache.removeKeyLocked(key)
return nil, fmt.Errorf("entry expired")
}
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
cache.ll.MoveToFront(el)
}
return entry, nil
}
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
cache.mu.Lock()
defer cache.mu.Unlock()
if _, exists := cache.data[key]; exists {
cache.data[key] = entry
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
cache.ll.MoveToFront(el)
}
return nil
}
cache.data[key] = entry
if cache.ll != nil {
el := cache.ll.PushFront(key)
cache.lruEl[key] = el
}
for cache.maxSize > 0 && len(cache.data) > cache.maxSize {
cache.evictOldestLocked()
}
return nil
}
func (cache *InMemoryCache) evictOldestLocked() {
if cache.ll == nil || cache.ll.Len() == 0 {
return
}
el := cache.ll.Back()
if el == nil {
return
}
key, _ := el.Value.(string)
cache.removeKeyLocked(key)
}
func (cache *InMemoryCache) removeKeyLocked(key string) {
delete(cache.data, key)
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
cache.ll.Remove(el)
}
delete(cache.lruEl, key)
for p := range cache.keyPrefixes[key] {
if m, ok := cache.prefixKeys[p]; ok {
delete(m, key)
if len(m) == 0 {
delete(cache.prefixKeys, p)
}
}
}
delete(cache.keyPrefixes, key)
}
func (cache *InMemoryCache) Delete(key string) error {
cache.mu.Lock()
defer cache.mu.Unlock()
delete(cache.data, key)
if _, ok := cache.data[key]; !ok {
return fmt.Errorf("key not found")
}
cache.removeKeyLocked(key)
return nil
}
@@ -72,6 +220,10 @@ func (cache *InMemoryCache) Clear() error {
cache.mu.Lock()
defer cache.mu.Unlock()
cache.data = make(map[string]*CacheEntry)
cache.prefixKeys = make(map[string]map[string]struct{})
cache.keyPrefixes = make(map[string]map[string]struct{})
cache.lruEl = make(map[string]*list.Element)
cache.ll = list.New()
return nil
}
@@ -95,6 +247,7 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
if config == nil {
config = DefaultCacheConfig()
}
applyCacheMaxSize(cache, config.MaxSize)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -141,8 +294,15 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
TTL: config.TTL,
}
path := r.URL.Path
prefixes := config.CacheablePaths
go func() {
cache.Set(cacheKey, entry)
if err := cache.Set(cacheKey, entry); err != nil {
log.Printf("middleware cache Set: %v", err)
return
}
registerIndexedCacheKey(cache, cacheKey, path, prefixes)
}()
}
})
@@ -190,20 +350,21 @@ func generateCacheKey(r *http.Request) string {
key += "?" + r.URL.RawQuery
}
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
key += fmt.Sprintf(":user:%d", userID)
if userID := GetUserIDFromContext(r.Context()); userID != nil {
key += fmt.Sprintf(":user:%d", *userID)
}
hash := md5.Sum([]byte(key))
return fmt.Sprintf("cache:%x", hash)
sum := sha256.Sum256([]byte(key))
return "cache:" + hex.EncodeToString(sum[:])
}
func CacheInvalidationMiddleware(cache Cache) func(http.Handler) http.Handler {
func CacheInvalidationMiddleware(cache Cache, cacheablePrefixes []string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" {
mPath := r.URL.Path
go func() {
cache.Clear()
invalidateCacheForMutation(cache, mPath, cacheablePrefixes)
}()
}
+84 -69
View File
@@ -97,6 +97,29 @@ func TestInMemoryCache(t *testing.T) {
t.Error("Expected error for expired entry")
}
})
t.Run("LRU evicts oldest at max size", func(t *testing.T) {
c := NewInMemoryCache()
c.SetMaxEntries(2)
entry := func(b byte) *CacheEntry {
return &CacheEntry{Data: []byte{b}, Headers: make(http.Header), Timestamp: time.Now(), TTL: time.Hour}
}
_ = c.Set("k1", entry('a'))
_ = c.Set("k2", entry('b'))
if _, err := c.Get("k1"); err != nil {
t.Fatal(err)
}
_ = c.Set("k3", entry('c'))
if _, err := c.Get("k1"); err != nil {
t.Fatal(err)
}
if _, err := c.Get("k3"); err != nil {
t.Fatal(err)
}
if _, err := c.Get("k2"); err == nil {
t.Fatal("expected k2 evicted")
}
})
}
func TestCacheMiddleware(t *testing.T) {
@@ -255,10 +278,10 @@ func TestCacheKeyGeneration(t *testing.T) {
query string
expected string
}{
{"GET", "/test", "", "cache:e2b43a77e8b6707afcc1571382ca7c73"},
{"GET", "/test", "param=value", "cache:067b4b550d6cee93dfb106d6912ef91b"},
{"POST", "/test", "", "cache:fb3126bb69b4d21769b5fa4d78318b0e"},
{"PUT", "/users/123", "", "cache:40b0b7a2306bfd4998d6219c1ef29783"},
{"GET", "/test", "", "cache:dbbdf14ce9e8333532d3760e4e1254e9a4f9b4bd7e98446754bfc23420d5e7c9"},
{"GET", "/test", "param=value", "cache:da0e5eaf04e82e40b49ebb8f0a1c85954a207119d7e2423a9c24a94ddb189f71"},
{"POST", "/test", "", "cache:719d94211ce99e5e0d039a4a7dfa57409eadf2573544454005c1fd4f3fce988f"},
{"PUT", "/users/123", "", "cache:168e0c53c01e3f92badb40db057805a786749b1fd9be4d1562f34ba6cfac77fe"},
}
for _, tt := range tests {
@@ -587,7 +610,6 @@ func TestCacheMiddlewarePreservesSecurityHeaders(t *testing.T) {
securityHeaders := []string{
"X-Content-Type-Options",
"X-Frame-Options",
"X-XSS-Protection",
"Referrer-Policy",
"Content-Security-Policy",
"Permissions-Policy",
@@ -698,31 +720,24 @@ func TestCacheMiddlewarePreservesHSTSHeader(t *testing.T) {
func TestCacheInvalidationMiddleware(t *testing.T) {
cache := NewInMemoryCache()
prefixes := []string{"/api/posts", "/api/other"}
entries := []struct {
key string
entry *CacheEntry
}{
{"cache:abc123", &CacheEntry{Data: []byte("data1"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
{"cache:def456", &CacheEntry{Data: []byte("data2"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
{"cache:ghi789", &CacheEntry{Data: []byte("data3"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
}
for _, e := range entries {
if err := cache.Set(e.key, e.entry); err != nil {
setIndexed := func(key string, entry *CacheEntry, path string) {
if err := cache.Set(key, entry); err != nil {
t.Fatalf("Failed to set cache entry: %v", err)
}
cache.RegisterKeyForPath(key, path, prefixes)
}
for _, e := range entries {
if _, err := cache.Get(e.key); err != nil {
t.Fatalf("Expected entry %s to exist, got error: %v", e.key, err)
}
}
postsEntry := &CacheEntry{Data: []byte("posts"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
otherEntry := &CacheEntry{Data: []byte("other"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
middleware := CacheInvalidationMiddleware(cache)
setIndexed("postsKey", postsEntry, "/api/posts/top")
setIndexed("otherKey", otherEntry, "/api/other/x")
t.Run("POST clears cache", func(t *testing.T) {
middleware := CacheInvalidationMiddleware(cache, prefixes)
t.Run("POST under posts prefix invalidates posts keys only", func(t *testing.T) {
request := httptest.NewRequest("POST", "/api/posts", nil)
recorder := httptest.NewRecorder()
@@ -732,80 +747,80 @@ func TestCacheInvalidationMiddleware(t *testing.T) {
time.Sleep(100 * time.Millisecond)
for _, e := range entries {
if _, err := cache.Get(e.key); err == nil {
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
if _, err := cache.Get("postsKey"); err == nil {
t.Error("expected postsKey cleared")
}
if _, err := cache.Get("otherKey"); err != nil {
t.Errorf("expected otherKey to remain: %v", err)
}
})
for _, e := range entries {
if err := cache.Set(e.key, e.entry); err != nil {
t.Fatalf("Failed to repopulate cache: %v", err)
}
}
setIndexed("postsKey", postsEntry, "/api/posts/top")
wildEntry := &CacheEntry{Data: []byte("wild"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
_ = cache.Set("untracked", wildEntry)
t.Run("PUT clears cache", func(t *testing.T) {
t.Run("mutation does not wipe untracked keys", func(t *testing.T) {
request := httptest.NewRequest("PUT", "/api/posts/1", nil)
recorder := httptest.NewRecorder()
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(recorder, request)
time.Sleep(100 * time.Millisecond)
for _, e := range entries {
if _, err := cache.Get(e.key); err == nil {
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
}
if _, err := cache.Get("untracked"); err != nil {
t.Fatal("untracked key should remain")
}
})
for _, e := range entries {
if err := cache.Set(e.key, e.entry); err != nil {
t.Fatalf("Failed to repopulate cache: %v", err)
t.Run("GET does not invalidate", func(t *testing.T) {
cache2 := NewInMemoryCache()
setIndexed := func(key string, path string) {
e := &CacheEntry{Data: []byte("d"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
if err := cache2.Set(key, e); err != nil {
t.Fatal(err)
}
cache2.RegisterKeyForPath(key, path, prefixes)
}
setIndexed("gk", "/api/posts/1")
t.Run("DELETE clears cache", func(t *testing.T) {
request := httptest.NewRequest("DELETE", "/api/posts/1", nil)
recorder := httptest.NewRecorder()
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mw := CacheInvalidationMiddleware(cache2, prefixes)
req := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder()
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(recorder, request)
time.Sleep(100 * time.Millisecond)
for _, e := range entries {
if _, err := cache.Get(e.key); err == nil {
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
}
})).ServeHTTP(rec, req)
time.Sleep(50 * time.Millisecond)
if _, err := cache2.Get("gk"); err != nil {
t.Fatal(err)
}
})
t.Run("GET does not clear cache", func(t *testing.T) {
for _, e := range entries {
if err := cache.Set(e.key, e.entry); err != nil {
t.Fatalf("Failed to repopulate cache: %v", err)
t.Run("DELETE under other prefix", func(t *testing.T) {
cache3 := NewInMemoryCache()
ep := &CacheEntry{Data: []byte("p"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
eo := &CacheEntry{Data: []byte("o"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
if err := cache3.Set("pk", ep); err != nil {
t.Fatal(err)
}
cache3.RegisterKeyForPath("pk", "/api/posts/1", prefixes)
if err := cache3.Set("ok", eo); err != nil {
t.Fatal(err)
}
cache3.RegisterKeyForPath("ok", "/api/other/y", prefixes)
request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder()
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mw := CacheInvalidationMiddleware(cache3, prefixes)
delReq := httptest.NewRequest("DELETE", "/api/other/y", nil)
rec := httptest.NewRecorder()
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(recorder, request)
time.Sleep(100 * time.Millisecond)
for _, e := range entries {
if _, err := cache.Get(e.key); err != nil {
t.Errorf("Expected entry %s to still exist, got error: %v", e.key, err)
})).ServeHTTP(rec, delReq)
time.Sleep(50 * time.Millisecond)
if _, err := cache3.Get("ok"); err == nil {
t.Error("expected ok cleared")
}
if _, err := cache3.Get("pk"); err != nil {
t.Fatal("posts key should remain")
}
})
}
+35 -17
View File
@@ -150,23 +150,7 @@ func shouldCompressResponse(contentType string, config *CompressionConfig) bool
}
func DecompressionMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-Encoding") == "gzip" {
gz, err := gzip.NewReader(r.Body)
if err != nil {
http.Error(w, "Invalid gzip encoding", http.StatusBadRequest)
return
}
defer gz.Close()
r.Body = io.NopCloser(gz)
r.Header.Del("Content-Encoding")
}
next.ServeHTTP(w, r)
})
}
return DecompressionMiddlewareWithConfig(nil)
}
type CompressionConfig struct {
@@ -189,3 +173,37 @@ func DefaultCompressionConfig() *CompressionConfig {
},
}
}
type DecompressionConfig struct {
MaxDecompressedSize int64
}
func DefaultDecompressionConfig() *DecompressionConfig {
return &DecompressionConfig{
MaxDecompressedSize: 1024 * 1024, // 1MB
}
}
func DecompressionMiddlewareWithConfig(config *DecompressionConfig) func(http.Handler) http.Handler {
if config == nil {
config = DefaultDecompressionConfig()
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-Encoding") == "gzip" {
gz, err := gzip.NewReader(r.Body)
if err != nil {
http.Error(w, "Invalid gzip encoding", http.StatusBadRequest)
return
}
defer gz.Close()
r.Body = io.NopCloser(io.LimitReader(gz, config.MaxDecompressedSize))
r.Header.Del("Content-Encoding")
}
next.ServeHTTP(w, r)
})
}
}
+35
View File
@@ -562,6 +562,41 @@ func TestDecompressionMiddleware(t *testing.T) {
t.Errorf("Expected empty body, got '%s'", recorder.Body.String())
}
})
t.Run("Limits decompressed size", func(t *testing.T) {
config := &DecompressionConfig{
MaxDecompressedSize: 10,
}
middleware := DecompressionMiddlewareWithConfig(config)
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
gz.Write([]byte("this is more than ten bytes of data"))
gz.Close()
request := httptest.NewRequest("POST", "/test", &buf)
request.Header.Set("Content-Encoding", "gzip")
recorder := httptest.NewRecorder()
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
}
w.WriteHeader(http.StatusOK)
w.Write(body)
}))
handler.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", recorder.Code)
}
if len(recorder.Body.String()) > 10 {
t.Errorf("Expected body to be truncated to <= 10 bytes, got %d bytes", len(recorder.Body.String()))
}
})
}
func TestShouldCompressWithConfig(t *testing.T) {
+42 -7
View File
@@ -15,8 +15,39 @@ type CORSConfig struct {
AllowCredentials bool
}
func parseAllowedOriginsEnv() []string {
raw := os.Getenv("CORS_ALLOWED_ORIGINS")
if raw == "" {
return nil
}
parts := strings.Split(raw, ",")
out := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func validateCORSConfig(config *CORSConfig) {
if config == nil {
panic("middleware.CORS: config is nil")
}
if !config.AllowCredentials {
return
}
for _, o := range config.AllowedOrigins {
if o == "*" {
panic("middleware.CORS: AllowCredentials with wildcard AllowedOrigins (*) is not permitted")
}
}
}
func NewCORSConfig() *CORSConfig {
env := os.Getenv("GOYCO_ENV")
originsEnv := parseAllowedOriginsEnv()
config := &CORSConfig{
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
@@ -27,24 +58,27 @@ func NewCORSConfig() *CORSConfig {
switch env {
case "production", "staging":
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
config.AllowCredentials = true
if len(originsEnv) > 0 {
config.AllowedOrigins = originsEnv
} else {
config.AllowedOrigins = []string{}
}
config.AllowCredentials = true
default:
config.AllowCredentials = true
if len(originsEnv) > 0 {
config.AllowedOrigins = originsEnv
} else {
config.AllowedOrigins = []string{
"http://localhost:3000",
"http://localhost:8080",
"http://127.0.0.1:3000",
"http://127.0.0.1:8080",
}
config.AllowCredentials = true
}
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins != "" {
config.AllowedOrigins = strings.Split(origins, ",")
}
}
validateCORSConfig(config)
return config
}
@@ -76,6 +110,7 @@ func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, conf
}
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
validateCORSConfig(config)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
+31 -19
View File
@@ -181,7 +181,12 @@ func TestCORSWithConfig_WildcardOrigin(t *testing.T) {
}
}
func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
func TestCORSWithConfig_WildcardWithCredentialsPanics(t *testing.T) {
defer func() {
if recover() == nil {
t.Fatal("expected panic for AllowCredentials with wildcard AllowedOrigins")
}
}()
config := &CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST"},
@@ -189,24 +194,7 @@ func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
MaxAge: 3600,
AllowCredentials: true,
}
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Origin", "http://example.com")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
}
if w.Header().Get("Access-Control-Allow-Credentials") != "" {
t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
}
_ = CORSWithConfig(config)
}
func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
@@ -393,6 +381,30 @@ func TestCORSOPTIONSRequest(t *testing.T) {
}
}
func TestCORSAllowedOriginsTrimmedFromEnv(t *testing.T) {
t.Setenv("GOYCO_ENV", "development")
t.Setenv("CORS_ALLOWED_ORIGINS", " http://localhost:3000 , https://yourdomain.com ")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := CORS(handler)
for _, origin := range []string{"http://localhost:3000", "https://yourdomain.com"} {
request := httptest.NewRequest("GET", "/api/posts", nil)
request.Header.Set("Origin", origin)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Fatalf("origin %q: status %d", origin, recorder.Code)
}
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != origin {
t.Fatalf("origin %q: Allow-Origin %q", origin, got)
}
}
}
func TestCORSAllowedOrigins(t *testing.T) {
t.Setenv("GOYCO_ENV", "development")
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")
+7 -2
View File
@@ -28,7 +28,7 @@ func SetCSRFToken(w http.ResponseWriter, r *http.Request, token string) {
Name: CSRFTokenCookieName,
Value: token,
Path: "/",
HttpOnly: true,
HttpOnly: false,
Secure: IsHTTPS(r),
SameSite: http.SameSiteLaxMode,
MaxAge: 3600,
@@ -71,7 +71,7 @@ func CSRFMiddleware() func(http.Handler) http.Handler {
return
}
if strings.HasPrefix(r.URL.Path, "/api/") {
if strings.HasPrefix(r.URL.Path, "/api/") && hasBearerToken(r) {
next.ServeHTTP(w, r)
return
}
@@ -86,6 +86,11 @@ func CSRFMiddleware() func(http.Handler) http.Handler {
}
}
func hasBearerToken(r *http.Request) bool {
auth := strings.TrimSpace(r.Header.Get("Authorization"))
return strings.HasPrefix(auth, "Bearer ")
}
func IsHTTPS(r *http.Request) bool {
if r.TLS != nil {
return true
+20 -4
View File
@@ -186,8 +186,9 @@ func TestCSRFMiddlewareAllowsValidToken(t *testing.T) {
}
}
func TestCSRFMiddlewareSkipsAPI(t *testing.T) {
func TestCSRFMiddlewareSkipsAPIWithBearerToken(t *testing.T) {
request := httptest.NewRequest("POST", "/api/test", nil)
request.Header.Set("Authorization", "Bearer valid-token")
recorder := httptest.NewRecorder()
handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -197,7 +198,22 @@ func TestCSRFMiddlewareSkipsAPI(t *testing.T) {
handler.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Errorf("API requests should skip CSRF validation, got status %d", recorder.Code)
t.Errorf("API requests with Bearer token should skip CSRF validation, got status %d", recorder.Code)
}
}
func TestCSRFMiddlewareBlocksAPIWithoutBearerToken(t *testing.T) {
request := httptest.NewRequest("POST", "/api/test", nil)
recorder := httptest.NewRecorder()
handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("API requests without Bearer token should require CSRF validation, got status %d", recorder.Code)
}
}
@@ -226,8 +242,8 @@ func TestSetCSRFToken(t *testing.T) {
t.Errorf("Expected cookie value %s, got %s", token, cookie.Value)
}
if !cookie.HttpOnly {
t.Error("CSRF token cookie should be HttpOnly")
if cookie.HttpOnly {
t.Error("CSRF token cookie must not be HttpOnly so JS can mirror it to X-CSRF-Token")
}
if cookie.SameSite != http.SameSiteLaxMode {
+2 -2
View File
@@ -327,8 +327,8 @@ func GetSecureClientIP(r *http.Request) string {
func GetKey(r *http.Request) string {
ip := GetSecureClientIP(r)
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
return fmt.Sprintf("user:%d:ip:%s", userID, ip)
if userID := GetUserIDFromContext(r.Context()); userID != nil {
return fmt.Sprintf("user:%d:ip:%s", *userID, ip)
}
return fmt.Sprintf("ip:%s", ip)
+15 -4
View File
@@ -5,12 +5,21 @@ import (
"crypto/rand"
"encoding/base64"
"fmt"
"log"
"net/http"
"strings"
)
const CSPNonceKey contextKey = "csp_nonce"
type SecurityHeadersConfig struct {
RelaxSwaggerCSP bool
}
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
return SecurityHeadersConfig{RelaxSwaggerCSP: true}
}
func GenerateCSPNonce() (string, error) {
nonceBytes := make([]byte, 16)
if _, err := rand.Read(nonceBytes); err != nil {
@@ -27,15 +36,18 @@ func GetCSPNonceFromContext(ctx context.Context) string {
}
func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
return SecurityHeadersMiddlewareWithConfig(DefaultSecurityHeadersConfig())
}
func SecurityHeadersMiddlewareWithConfig(cfg SecurityHeadersConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
isSwaggerRoute := strings.HasPrefix(r.URL.Path, "/swagger")
if isSwaggerRoute {
if isSwaggerRoute && cfg.RelaxSwaggerCSP {
csp := "default-src 'self'; " +
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
"style-src 'self' 'unsafe-inline'; " +
@@ -51,7 +63,7 @@ func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
} else {
nonce, err := GenerateCSPNonce()
if err != nil {
log.Printf("middleware security headers: CSP nonce: %v", err)
nonce = ""
}
@@ -72,7 +84,6 @@ func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
csp = "script-src 'self' 'nonce-" + nonce + "'; " +
"style-src 'self' 'nonce-" + nonce + "'; " + csp
} else {
csp = "script-src 'self'; " +
"style-src 'self'; " + csp
}
+22 -2
View File
@@ -22,7 +22,6 @@ func TestSecurityHeadersMiddleware(t *testing.T) {
expectedHeaders := map[string]string{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Referrer-Policy": "strict-origin-when-cross-origin",
"Server": "",
}
@@ -176,7 +175,6 @@ func TestSecurityHeadersMiddleware_MultipleRequests(t *testing.T) {
requiredHeaders := []string{
"X-Content-Type-Options",
"X-Frame-Options",
"X-XSS-Protection",
"Referrer-Policy",
"Content-Security-Policy",
"Permissions-Policy",
@@ -289,3 +287,25 @@ func TestCSPNonceInContext(t *testing.T) {
t.Errorf("CSP header should contain nonce from context. CSP: %s, Nonce: %s", csp, capturedNonce)
}
}
func TestSecurityHeadersMiddleware_SwaggerStrictWhenRelaxedDisabled(t *testing.T) {
handler := SecurityHeadersMiddlewareWithConfig(SecurityHeadersConfig{RelaxSwaggerCSP: false})(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/swagger/index.html", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
csp := rec.Header().Get("Content-Security-Policy")
if csp == "" {
t.Fatal("expected CSP")
}
if strings.Contains(csp, "'unsafe-eval'") || strings.Contains(csp, "'unsafe-inline'") {
t.Fatalf("unexpected relaxed CSP for swagger path: %s", csp)
}
if !strings.Contains(csp, "nonce-") {
t.Fatalf("expected nonce CSP for swagger path, got %s", csp)
}
}
+42 -11
View File
@@ -3,8 +3,10 @@ package middleware
import (
"log"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
)
@@ -52,7 +54,10 @@ func SecurityLoggingMiddleware(logger *SecurityLogger) func(http.Handler) http.H
next.ServeHTTP(rw, r)
userID := GetUserIDFromContext(r.Context())
userUID := uint(0)
if u := GetUserIDFromContext(r.Context()); u != nil {
userUID = *u
}
ip := getClientIP(r)
event := SecurityEvent{
@@ -60,7 +65,7 @@ func SecurityLoggingMiddleware(logger *SecurityLogger) func(http.Handler) http.H
UserAgent: r.UserAgent(),
Path: r.URL.Path,
Method: r.Method,
UserID: userID,
UserID: userUID,
Timestamp: start,
}
@@ -105,12 +110,15 @@ func SuspiciousActivityMiddleware(logger *SecurityLogger) func(http.Handler) htt
suspicious := false
details := ""
if containsSQLInjection(r.URL.RawQuery) || containsSQLInjection(r.URL.Path) {
pathProbe := layeredUnescape(r.URL.Path, url.PathUnescape)
queryProbe := layeredUnescape(r.URL.RawQuery, url.QueryUnescape)
if containsSQLInjection(pathProbe) || containsSQLInjection(queryProbe) {
suspicious = true
details = "Potential SQL injection attempt"
}
if containsXSS(r.URL.RawQuery) || containsXSS(r.URL.Path) {
if containsXSS(pathProbe) || containsXSS(queryProbe) {
suspicious = true
details = "Potential XSS attempt"
}
@@ -158,6 +166,18 @@ func getClientIP(r *http.Request) string {
return GetSecureClientIP(r)
}
func layeredUnescape(s string, decoder func(string) (string, error)) string {
out := s
for range 3 {
d, err := decoder(out)
if err != nil || d == out {
return out
}
out = d
}
return out
}
func containsSQLInjection(input string) bool {
sqlPatterns := []string{
"' OR '1'='1",
@@ -220,18 +240,29 @@ func isSuspiciousUserAgent(userAgent string) bool {
return false
}
var requestCounts = make(map[string]int)
var lastReset = time.Now()
type rapidRequestTracker struct {
mu sync.Mutex
counts map[string]int
lastReset time.Time
}
var rapidRequests = rapidRequestTracker{
counts: make(map[string]int),
lastReset: time.Now(),
}
func isRapidRequest(ip string) bool {
rapidRequests.mu.Lock()
defer rapidRequests.mu.Unlock()
now := time.Now()
if now.Sub(lastReset) > time.Minute {
requestCounts = make(map[string]int)
lastReset = now
if now.Sub(rapidRequests.lastReset) > time.Minute {
rapidRequests.counts = make(map[string]int)
rapidRequests.lastReset = now
}
requestCounts[ip]++
rapidRequests.counts[ip]++
return requestCounts[ip] > 100
return rapidRequests.counts[ip] > 100
}
+27 -2
View File
@@ -6,6 +6,7 @@ import (
"log"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -453,8 +454,10 @@ func TestIsSuspiciousUserAgent(t *testing.T) {
func TestIsRapidRequest(t *testing.T) {
requestCounts = make(map[string]int)
lastReset = time.Now()
rapidRequests.mu.Lock()
rapidRequests.counts = make(map[string]int)
rapidRequests.lastReset = time.Now()
rapidRequests.mu.Unlock()
ip := "192.168.1.1"
@@ -564,6 +567,28 @@ func TestSuspiciousActivityMiddleware_NoSuspiciousActivity(t *testing.T) {
}
}
func TestSuspiciousActivityMiddleware_EncodedSQLInQuery(t *testing.T) {
var buf bytes.Buffer
logger := &SecurityLogger{
logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile),
}
handler := SuspiciousActivityMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
q := url.Values{}
q.Set("s", "' OR '1'='1")
req := httptest.NewRequest("GET", "/search?"+q.Encode(), nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
out := buf.String()
if !strings.Contains(out, "SQL injection") {
t.Fatalf("expected SQL injection log, got %q", out)
}
}
func TestSuspiciousActivityMiddleware_Debug(t *testing.T) {
t.Run("SQL Detection", func(t *testing.T) {
+16 -2
View File
@@ -3,6 +3,7 @@ package server
import (
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"time"
@@ -32,12 +33,23 @@ type RouterConfig struct {
RateLimitConfig config.RateLimitConfig
}
func swaggerExposed() bool {
if strings.EqualFold(strings.TrimSpace(os.Getenv("SWAGGER_ENABLED")), "true") {
return true
}
return strings.ToLower(strings.TrimSpace(os.Getenv("GOYCO_ENV"))) != "production"
}
func NewRouter(cfg RouterConfig) http.Handler {
middleware.SetTrustProxyHeaders(cfg.RateLimitConfig.TrustProxyHeaders)
exposeSwagger := swaggerExposed()
router := chi.NewRouter()
router.Use(middleware.Logging(cfg.Debug))
router.Use(middleware.SecurityHeadersMiddleware())
router.Use(middleware.SecurityHeadersMiddlewareWithConfig(middleware.SecurityHeadersConfig{
RelaxSwaggerCSP: exposeSwagger,
}))
router.Use(middleware.HSTSMiddleware())
router.Use(middleware.CORS)
@@ -54,7 +66,7 @@ func NewRouter(cfg RouterConfig) http.Handler {
cacheConfig.CacheablePaths = append([]string{}, cfg.CacheablePaths...)
}
router.Use(middleware.CacheMiddleware(cache, cacheConfig))
router.Use(middleware.CacheInvalidationMiddleware(cache))
router.Use(middleware.CacheInvalidationMiddleware(cache, cacheConfig.CacheablePaths))
}
var dbMonitor middleware.DBMonitor
@@ -94,8 +106,10 @@ func NewRouter(cfg RouterConfig) http.Handler {
metricsRateLimited.Get("/metrics", cfg.APIHandler.GetMetrics)
}
if exposeSwagger {
swaggerRateLimited := router.With(middleware.GeneralRateLimitMiddlewareWithLimit(cfg.RateLimitConfig.GeneralLimit))
swaggerRateLimited.Get("/swagger/*", httpSwagger.Handler())
}
router.Get("/robots.txt", func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, filepath.Join(cfg.StaticDir, "robots.txt"))
+30
View File
@@ -348,6 +348,8 @@ func TestRouterWithoutPageHandler(t *testing.T) {
}
func TestSwaggerRoute(t *testing.T) {
t.Setenv("GOYCO_ENV", "development")
t.Setenv("SWAGGER_ENABLED", "")
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
@@ -360,6 +362,34 @@ func TestSwaggerRoute(t *testing.T) {
}
}
func TestSwaggerRouteHiddenInProduction(t *testing.T) {
t.Setenv("GOYCO_ENV", "production")
t.Setenv("SWAGGER_ENABLED", "")
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusNotFound {
t.Errorf("Expected 404 for swagger in production, got %d", recorder.Code)
}
}
func TestSwaggerRouteEnabledInProductionWhenForced(t *testing.T) {
t.Setenv("GOYCO_ENV", "production")
t.Setenv("SWAGGER_ENABLED", "true")
router := createTestRouter(createDefaultRouterConfig())
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK && recorder.Code != http.StatusMovedPermanently {
t.Errorf("Expected status 200 or 301 for swagger when SWAGGER_ENABLED, got %d", recorder.Code)
}
}
func TestStaticFileRoute(t *testing.T) {
cfg := createDefaultRouterConfig()
cfg.StaticDir = "../../internal/static/"