Compare commits
29 Commits
8f255a4fe6
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| f7d43def1c | |||
| d891b33b57 | |||
| 60daeddbe4 | |||
| 537a7e3759 | |||
| 194884293f | |||
| 0fbb6f4a88 | |||
| b3f6f5b15e | |||
| 2ede636bd6 | |||
| 7c525e71cb | |||
| 620798577e | |||
| b41d3bb20c | |||
| abaf46e624 | |||
| 61875201f9 | |||
| d668567dc5 | |||
| 102f1d8400 | |||
| 98985db537 | |||
| be64e7c8d2 | |||
| 1aa256c6a8 | |||
| dccf85e038 | |||
| 4e188eb8d5 | |||
| 2adf72c138 | |||
| add60ad3c2 | |||
| 89131331a6 | |||
| 0baf7053fc | |||
| 5d145613d2 | |||
| 12db6409ce | |||
| 5fc208c9da | |||
| ab17ff8b79 | |||
| 8990f5afb7 |
@@ -48,6 +48,9 @@ RATE_LIMIT_TRUST_PROXY=false
|
|||||||
# Set to: development, staging, or production
|
# Set to: development, staging, or production
|
||||||
GOYCO_ENV=development
|
GOYCO_ENV=development
|
||||||
|
|
||||||
|
# When GOYCO_ENV=production, set to true only if you intentionally want Swagger UI mounted (default: omitted/false hides it).
|
||||||
|
SWAGGER_ENABLED=false
|
||||||
|
|
||||||
# CORS Configuration (optional, comma-separated)
|
# CORS Configuration (optional, comma-separated)
|
||||||
# Example: CORS_ALLOWED_ORIGINS=https://example.com,https://www.example.com
|
# Example: CORS_ALLOWED_ORIGINS=https://example.com,https://www.example.com
|
||||||
CORS_ALLOWED_ORIGINS=
|
CORS_ALLOWED_ORIGINS=
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ go.work.sum
|
|||||||
# env file
|
# env file
|
||||||
.env
|
.env
|
||||||
|
|
||||||
|
# local security audit notes (not tracked)
|
||||||
|
SECURITY_AUDIT.md
|
||||||
|
|
||||||
# binaries
|
# binaries
|
||||||
bin/goyco
|
bin/goyco
|
||||||
|
|
||||||
|
|||||||
@@ -171,13 +171,21 @@ server {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Security headers and middleware ordering
|
||||||
|
|
||||||
|
When `RATE_LIMIT_TRUST_PROXY` is `true`, the application trusts `X-Forwarded-Proto` (among other forwarded headers) for HSTS and client IP derivation. Configure your reverse proxy to set trustworthy values and strip or overwrite any client-supplied forwarding headers before they reach the app.
|
||||||
|
|
||||||
|
Response caching uses `DecompressionMiddleware` before `DefaultRequestSizeLimitMiddleware`. Decompression is additionally capped internally so gzipped bodies cannot expand without limit before other limits apply.
|
||||||
|
|
||||||
|
In production (`GOYCO_ENV=production`), Swagger UI is not registered unless `SWAGGER_ENABLED=true`.
|
||||||
|
|
||||||
## API Documentation
|
## API Documentation
|
||||||
|
|
||||||
The API is fully documented with Swagger.
|
The API is fully documented with Swagger (enabled for non-production environments, or when `SWAGGER_ENABLED=true`).
|
||||||
|
|
||||||
Once running, visit:
|
Once running, visit:
|
||||||
|
|
||||||
- **Swagger UI**: `https://goyco.example.com/swagger/index.html`
|
- **Swagger UI**: `https://goyco.example.com/swagger/index.html` (not served in production unless explicitly enabled via `SWAGGER_ENABLED`)
|
||||||
|
|
||||||
You can also use `curl` to get the API info, health check and even metrics:
|
You can also use `curl` to get the API info, health check and even metrics:
|
||||||
|
|
||||||
|
|||||||
@@ -547,7 +547,6 @@ func TestE2E_SecurityHeadersEnhanced(t *testing.T) {
|
|||||||
expectedHeaders := map[string]string{
|
expectedHeaders := map[string]string{
|
||||||
"X-Content-Type-Options": "nosniff",
|
"X-Content-Type-Options": "nosniff",
|
||||||
"X-Frame-Options": "DENY",
|
"X-Frame-Options": "DENY",
|
||||||
"X-XSS-Protection": "1; mode=block",
|
|
||||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -707,7 +706,6 @@ func TestE2E_SecurityHeaderCombinations(t *testing.T) {
|
|||||||
requiredHeaders := []string{
|
requiredHeaders := []string{
|
||||||
"X-Content-Type-Options",
|
"X-Content-Type-Options",
|
||||||
"X-Frame-Options",
|
"X-Frame-Options",
|
||||||
"X-XSS-Protection",
|
|
||||||
"Referrer-Policy",
|
"Referrer-Policy",
|
||||||
"Content-Security-Policy",
|
"Content-Security-Policy",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -200,17 +200,21 @@ func ParseUintParam(w http.ResponseWriter, r *http.Request, paramName, entityNam
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RequireAuth(w http.ResponseWriter, r *http.Request) (uint, bool) {
|
func RequireAuth(w http.ResponseWriter, r *http.Request) (uint, bool) {
|
||||||
userID := middleware.GetUserIDFromContext(r.Context())
|
userPtr := middleware.GetUserIDFromContext(r.Context())
|
||||||
if userID == 0 {
|
if userPtr == nil {
|
||||||
SendErrorResponse(w, "Authentication required", http.StatusUnauthorized)
|
SendErrorResponse(w, "Authentication required", http.StatusUnauthorized)
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
return userID, true
|
return *userPtr, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewVoteContext(r *http.Request) services.VoteContext {
|
func NewVoteContext(r *http.Request) services.VoteContext {
|
||||||
|
var uid uint
|
||||||
|
if userPtr := middleware.GetUserIDFromContext(r.Context()); userPtr != nil {
|
||||||
|
uid = *userPtr
|
||||||
|
}
|
||||||
return services.VoteContext{
|
return services.VoteContext{
|
||||||
UserID: middleware.GetUserIDFromContext(r.Context()),
|
UserID: uid,
|
||||||
IPAddress: GetClientIP(r),
|
IPAddress: GetClientIP(r),
|
||||||
UserAgent: r.UserAgent(),
|
UserAgent: r.UserAgent(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -569,7 +569,8 @@ func TestParseUintParam(t *testing.T) {
|
|||||||
func TestRequireAuth(t *testing.T) {
|
func TestRequireAuth(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
userID uint
|
setUserKey bool
|
||||||
|
userIDValue uint
|
||||||
expectedID uint
|
expectedID uint
|
||||||
expectedOK bool
|
expectedOK bool
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
@@ -577,25 +578,32 @@ func TestRequireAuth(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "authenticated user",
|
name: "authenticated user",
|
||||||
userID: 123,
|
setUserKey: true,
|
||||||
|
userIDValue: 123,
|
||||||
expectedID: 123,
|
expectedID: 123,
|
||||||
expectedOK: true,
|
expectedOK: true,
|
||||||
expectedStatus: 0,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unauthenticated user (no userID)",
|
name: "unauthenticated user (missing context)",
|
||||||
userID: 0,
|
setUserKey: false,
|
||||||
expectedID: 0,
|
expectedID: 0,
|
||||||
expectedOK: false,
|
expectedOK: false,
|
||||||
expectedStatus: http.StatusUnauthorized,
|
expectedStatus: http.StatusUnauthorized,
|
||||||
expectedError: "Authentication required",
|
expectedError: "Authentication required",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "authenticated user with id zero",
|
||||||
|
setUserKey: true,
|
||||||
|
userIDValue: 0,
|
||||||
|
expectedID: 0,
|
||||||
|
expectedOK: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "authenticated user with large ID",
|
name: "authenticated user with large ID",
|
||||||
userID: 4294967295,
|
setUserKey: true,
|
||||||
|
userIDValue: 4294967295,
|
||||||
expectedID: 4294967295,
|
expectedID: 4294967295,
|
||||||
expectedOK: true,
|
expectedOK: true,
|
||||||
expectedStatus: 0,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,7 +612,10 @@ func TestRequireAuth(t *testing.T) {
|
|||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), middleware.UserIDKey, tt.userID)
|
ctx := context.Background()
|
||||||
|
if tt.setUserKey {
|
||||||
|
ctx = context.WithValue(ctx, middleware.UserIDKey, tt.userIDValue)
|
||||||
|
}
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
userID, ok := RequireAuth(w, r)
|
userID, ok := RequireAuth(w, r)
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
|
|||||||
|
|
||||||
assertHeader(t, request, "X-Content-Type-Options")
|
assertHeader(t, request, "X-Content-Type-Options")
|
||||||
assertHeader(t, request, "X-Frame-Options")
|
assertHeader(t, request, "X-Frame-Options")
|
||||||
assertHeader(t, request, "X-XSS-Protection")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("CORS_Headers_Present", func(t *testing.T) {
|
t.Run("CORS_Headers_Present", func(t *testing.T) {
|
||||||
|
|||||||
@@ -73,9 +73,10 @@ func NewAuth(verifier TokenVerifier) func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserIDFromContext(ctx context.Context) uint {
|
func GetUserIDFromContext(ctx context.Context) *uint {
|
||||||
if userID, ok := ctx.Value(UserIDKey).(uint); ok {
|
if userID, ok := ctx.Value(UserIDKey).(uint); ok {
|
||||||
return userID
|
u := userID
|
||||||
|
return &u
|
||||||
}
|
}
|
||||||
return 0
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,8 +28,8 @@ func TestNewAuthWithoutAuthorization(t *testing.T) {
|
|||||||
middleware := NewAuth(verifier)
|
middleware := NewAuth(verifier)
|
||||||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
called = true
|
called = true
|
||||||
if id := GetUserIDFromContext(r.Context()); id != 0 {
|
if id := GetUserIDFromContext(r.Context()); id != nil {
|
||||||
t.Fatalf("unexpected user id %d", id)
|
t.Fatalf("unexpected user id %v", id)
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -54,8 +54,13 @@ func TestNewAuthValidToken(t *testing.T) {
|
|||||||
handlerCalled := false
|
handlerCalled := false
|
||||||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
handlerCalled = true
|
handlerCalled = true
|
||||||
if id := GetUserIDFromContext(r.Context()); id != 99 {
|
id := GetUserIDFromContext(r.Context())
|
||||||
t.Fatalf("expected user id 99, got %d", id)
|
if id == nil || *id != 99 {
|
||||||
|
v := uint(0)
|
||||||
|
if id != nil {
|
||||||
|
v = *id
|
||||||
|
}
|
||||||
|
t.Fatalf("expected user id 99, got %d", v)
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -131,11 +136,12 @@ func TestNewAuthVerifierError(t *testing.T) {
|
|||||||
func TestGetUserIDFromContext(t *testing.T) {
|
func TestGetUserIDFromContext(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), UserIDKey, uint(55))
|
ctx := context.WithValue(context.Background(), UserIDKey, uint(55))
|
||||||
|
|
||||||
if id := GetUserIDFromContext(ctx); id != 55 {
|
id := GetUserIDFromContext(ctx)
|
||||||
t.Fatalf("expected id 55, got %d", id)
|
if id == nil || *id != 55 {
|
||||||
|
t.Fatalf("expected id 55, got %v", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
if id := GetUserIDFromContext(context.Background()); id != 0 {
|
if ptr := GetUserIDFromContext(context.Background()); ptr != nil {
|
||||||
t.Fatalf("expected zero when id missing, got %d", id)
|
t.Fatalf("expected nil when id missing, got %v", ptr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+178
-17
@@ -2,8 +2,11 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/md5"
|
"container/list"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -24,47 +27,192 @@ type Cache interface {
|
|||||||
Clear() error
|
Clear() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyCacheMaxSize(cache Cache, max int) {
|
||||||
|
if im, ok := cache.(*InMemoryCache); ok {
|
||||||
|
im.SetMaxEntries(max)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerIndexedCacheKey(cache Cache, cacheKey string, path string, cacheablePrefixes []string) {
|
||||||
|
if im, ok := cache.(*InMemoryCache); ok {
|
||||||
|
im.RegisterKeyForPath(cacheKey, path, cacheablePrefixes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func invalidateCacheForMutation(cache Cache, mutationPath string, cacheablePrefixes []string) {
|
||||||
|
if im, ok := cache.(*InMemoryCache); ok {
|
||||||
|
im.InvalidateForMutationPath(mutationPath, cacheablePrefixes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = cache.Clear()
|
||||||
|
}
|
||||||
|
|
||||||
type InMemoryCache struct {
|
type InMemoryCache struct {
|
||||||
mu sync.RWMutex
|
mu sync.Mutex
|
||||||
|
|
||||||
data map[string]*CacheEntry
|
data map[string]*CacheEntry
|
||||||
|
maxSize int
|
||||||
|
|
||||||
|
ll *list.List
|
||||||
|
lruEl map[string]*list.Element
|
||||||
|
|
||||||
|
prefixKeys map[string]map[string]struct{}
|
||||||
|
keyPrefixes map[string]map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInMemoryCache() *InMemoryCache {
|
func NewInMemoryCache() *InMemoryCache {
|
||||||
return &InMemoryCache{
|
return &InMemoryCache{
|
||||||
data: make(map[string]*CacheEntry),
|
data: make(map[string]*CacheEntry),
|
||||||
|
maxSize: 1000,
|
||||||
|
ll: list.New(),
|
||||||
|
lruEl: make(map[string]*list.Element),
|
||||||
|
prefixKeys: make(map[string]map[string]struct{}),
|
||||||
|
keyPrefixes: make(map[string]map[string]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
|
func (cache *InMemoryCache) SetMaxEntries(n int) {
|
||||||
cache.mu.RLock()
|
cache.mu.Lock()
|
||||||
entry, exists := cache.data[key]
|
defer cache.mu.Unlock()
|
||||||
cache.mu.RUnlock()
|
cache.maxSize = n
|
||||||
|
for n > 0 && len(cache.data) > n {
|
||||||
|
cache.evictOldestLocked()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *InMemoryCache) RegisterKeyForPath(cacheKey string, path string, cacheablePrefixes []string) {
|
||||||
|
cache.mu.Lock()
|
||||||
|
defer cache.mu.Unlock()
|
||||||
|
for _, p := range matchingCachePrefixes(path, cacheablePrefixes) {
|
||||||
|
if cache.prefixKeys[p] == nil {
|
||||||
|
cache.prefixKeys[p] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
cache.prefixKeys[p][cacheKey] = struct{}{}
|
||||||
|
|
||||||
|
if cache.keyPrefixes[cacheKey] == nil {
|
||||||
|
cache.keyPrefixes[cacheKey] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
cache.keyPrefixes[cacheKey][p] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *InMemoryCache) InvalidateForMutationPath(mutationPath string, cacheablePrefixes []string) {
|
||||||
|
cache.mu.Lock()
|
||||||
|
defer cache.mu.Unlock()
|
||||||
|
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
var stale []string
|
||||||
|
for _, prefix := range cacheablePrefixes {
|
||||||
|
if !strings.HasPrefix(mutationPath, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for key := range cache.prefixKeys[prefix] {
|
||||||
|
if _, dup := seen[key]; dup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
stale = append(stale, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range stale {
|
||||||
|
cache.removeKeyLocked(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchingCachePrefixes(path string, cacheablePrefixes []string) []string {
|
||||||
|
var out []string
|
||||||
|
for _, p := range cacheablePrefixes {
|
||||||
|
if strings.HasPrefix(path, p) {
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
|
||||||
|
cache.mu.Lock()
|
||||||
|
defer cache.mu.Unlock()
|
||||||
|
|
||||||
|
entry, exists := cache.data[key]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("key not found")
|
return nil, fmt.Errorf("key not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Since(entry.Timestamp) > entry.TTL {
|
if time.Since(entry.Timestamp) > entry.TTL {
|
||||||
cache.mu.Lock()
|
cache.removeKeyLocked(key)
|
||||||
delete(cache.data, key)
|
|
||||||
cache.mu.Unlock()
|
|
||||||
return nil, fmt.Errorf("entry expired")
|
return nil, fmt.Errorf("entry expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||||
|
cache.ll.MoveToFront(el)
|
||||||
|
}
|
||||||
|
|
||||||
return entry, nil
|
return entry, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
|
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
|
||||||
cache.mu.Lock()
|
cache.mu.Lock()
|
||||||
defer cache.mu.Unlock()
|
defer cache.mu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := cache.data[key]; exists {
|
||||||
cache.data[key] = entry
|
cache.data[key] = entry
|
||||||
|
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||||
|
cache.ll.MoveToFront(el)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.data[key] = entry
|
||||||
|
if cache.ll != nil {
|
||||||
|
el := cache.ll.PushFront(key)
|
||||||
|
cache.lruEl[key] = el
|
||||||
|
}
|
||||||
|
|
||||||
|
for cache.maxSize > 0 && len(cache.data) > cache.maxSize {
|
||||||
|
cache.evictOldestLocked()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *InMemoryCache) evictOldestLocked() {
|
||||||
|
if cache.ll == nil || cache.ll.Len() == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
el := cache.ll.Back()
|
||||||
|
if el == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key, _ := el.Value.(string)
|
||||||
|
cache.removeKeyLocked(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cache *InMemoryCache) removeKeyLocked(key string) {
|
||||||
|
delete(cache.data, key)
|
||||||
|
|
||||||
|
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||||
|
cache.ll.Remove(el)
|
||||||
|
}
|
||||||
|
delete(cache.lruEl, key)
|
||||||
|
|
||||||
|
for p := range cache.keyPrefixes[key] {
|
||||||
|
if m, ok := cache.prefixKeys[p]; ok {
|
||||||
|
delete(m, key)
|
||||||
|
if len(m) == 0 {
|
||||||
|
delete(cache.prefixKeys, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(cache.keyPrefixes, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cache *InMemoryCache) Delete(key string) error {
|
func (cache *InMemoryCache) Delete(key string) error {
|
||||||
cache.mu.Lock()
|
cache.mu.Lock()
|
||||||
defer cache.mu.Unlock()
|
defer cache.mu.Unlock()
|
||||||
delete(cache.data, key)
|
if _, ok := cache.data[key]; !ok {
|
||||||
|
return fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
cache.removeKeyLocked(key)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,6 +220,10 @@ func (cache *InMemoryCache) Clear() error {
|
|||||||
cache.mu.Lock()
|
cache.mu.Lock()
|
||||||
defer cache.mu.Unlock()
|
defer cache.mu.Unlock()
|
||||||
cache.data = make(map[string]*CacheEntry)
|
cache.data = make(map[string]*CacheEntry)
|
||||||
|
cache.prefixKeys = make(map[string]map[string]struct{})
|
||||||
|
cache.keyPrefixes = make(map[string]map[string]struct{})
|
||||||
|
cache.lruEl = make(map[string]*list.Element)
|
||||||
|
cache.ll = list.New()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,6 +247,7 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
|
|||||||
if config == nil {
|
if config == nil {
|
||||||
config = DefaultCacheConfig()
|
config = DefaultCacheConfig()
|
||||||
}
|
}
|
||||||
|
applyCacheMaxSize(cache, config.MaxSize)
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -141,8 +294,15 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
|
|||||||
TTL: config.TTL,
|
TTL: config.TTL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
path := r.URL.Path
|
||||||
|
prefixes := config.CacheablePaths
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
cache.Set(cacheKey, entry)
|
if err := cache.Set(cacheKey, entry); err != nil {
|
||||||
|
log.Printf("middleware cache Set: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registerIndexedCacheKey(cache, cacheKey, path, prefixes)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -190,20 +350,21 @@ func generateCacheKey(r *http.Request) string {
|
|||||||
key += "?" + r.URL.RawQuery
|
key += "?" + r.URL.RawQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
|
if userID := GetUserIDFromContext(r.Context()); userID != nil {
|
||||||
key += fmt.Sprintf(":user:%d", userID)
|
key += fmt.Sprintf(":user:%d", *userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := md5.Sum([]byte(key))
|
sum := sha256.Sum256([]byte(key))
|
||||||
return fmt.Sprintf("cache:%x", hash)
|
return "cache:" + hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheInvalidationMiddleware(cache Cache) func(http.Handler) http.Handler {
|
func CacheInvalidationMiddleware(cache Cache, cacheablePrefixes []string) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" {
|
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" {
|
||||||
|
mPath := r.URL.Path
|
||||||
go func() {
|
go func() {
|
||||||
cache.Clear()
|
invalidateCacheForMutation(cache, mPath, cacheablePrefixes)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,29 @@ func TestInMemoryCache(t *testing.T) {
|
|||||||
t.Error("Expected error for expired entry")
|
t.Error("Expected error for expired entry")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("LRU evicts oldest at max size", func(t *testing.T) {
|
||||||
|
c := NewInMemoryCache()
|
||||||
|
c.SetMaxEntries(2)
|
||||||
|
entry := func(b byte) *CacheEntry {
|
||||||
|
return &CacheEntry{Data: []byte{b}, Headers: make(http.Header), Timestamp: time.Now(), TTL: time.Hour}
|
||||||
|
}
|
||||||
|
_ = c.Set("k1", entry('a'))
|
||||||
|
_ = c.Set("k2", entry('b'))
|
||||||
|
if _, err := c.Get("k1"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_ = c.Set("k3", entry('c'))
|
||||||
|
if _, err := c.Get("k1"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := c.Get("k3"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := c.Get("k2"); err == nil {
|
||||||
|
t.Fatal("expected k2 evicted")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCacheMiddleware(t *testing.T) {
|
func TestCacheMiddleware(t *testing.T) {
|
||||||
@@ -255,10 +278,10 @@ func TestCacheKeyGeneration(t *testing.T) {
|
|||||||
query string
|
query string
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{"GET", "/test", "", "cache:e2b43a77e8b6707afcc1571382ca7c73"},
|
{"GET", "/test", "", "cache:dbbdf14ce9e8333532d3760e4e1254e9a4f9b4bd7e98446754bfc23420d5e7c9"},
|
||||||
{"GET", "/test", "param=value", "cache:067b4b550d6cee93dfb106d6912ef91b"},
|
{"GET", "/test", "param=value", "cache:da0e5eaf04e82e40b49ebb8f0a1c85954a207119d7e2423a9c24a94ddb189f71"},
|
||||||
{"POST", "/test", "", "cache:fb3126bb69b4d21769b5fa4d78318b0e"},
|
{"POST", "/test", "", "cache:719d94211ce99e5e0d039a4a7dfa57409eadf2573544454005c1fd4f3fce988f"},
|
||||||
{"PUT", "/users/123", "", "cache:40b0b7a2306bfd4998d6219c1ef29783"},
|
{"PUT", "/users/123", "", "cache:168e0c53c01e3f92badb40db057805a786749b1fd9be4d1562f34ba6cfac77fe"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -587,7 +610,6 @@ func TestCacheMiddlewarePreservesSecurityHeaders(t *testing.T) {
|
|||||||
securityHeaders := []string{
|
securityHeaders := []string{
|
||||||
"X-Content-Type-Options",
|
"X-Content-Type-Options",
|
||||||
"X-Frame-Options",
|
"X-Frame-Options",
|
||||||
"X-XSS-Protection",
|
|
||||||
"Referrer-Policy",
|
"Referrer-Policy",
|
||||||
"Content-Security-Policy",
|
"Content-Security-Policy",
|
||||||
"Permissions-Policy",
|
"Permissions-Policy",
|
||||||
@@ -698,31 +720,24 @@ func TestCacheMiddlewarePreservesHSTSHeader(t *testing.T) {
|
|||||||
|
|
||||||
func TestCacheInvalidationMiddleware(t *testing.T) {
|
func TestCacheInvalidationMiddleware(t *testing.T) {
|
||||||
cache := NewInMemoryCache()
|
cache := NewInMemoryCache()
|
||||||
|
prefixes := []string{"/api/posts", "/api/other"}
|
||||||
|
|
||||||
entries := []struct {
|
setIndexed := func(key string, entry *CacheEntry, path string) {
|
||||||
key string
|
if err := cache.Set(key, entry); err != nil {
|
||||||
entry *CacheEntry
|
|
||||||
}{
|
|
||||||
{"cache:abc123", &CacheEntry{Data: []byte("data1"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
|
|
||||||
{"cache:def456", &CacheEntry{Data: []byte("data2"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
|
|
||||||
{"cache:ghi789", &CacheEntry{Data: []byte("data3"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, e := range entries {
|
|
||||||
if err := cache.Set(e.key, e.entry); err != nil {
|
|
||||||
t.Fatalf("Failed to set cache entry: %v", err)
|
t.Fatalf("Failed to set cache entry: %v", err)
|
||||||
}
|
}
|
||||||
|
cache.RegisterKeyForPath(key, path, prefixes)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, e := range entries {
|
postsEntry := &CacheEntry{Data: []byte("posts"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||||
if _, err := cache.Get(e.key); err != nil {
|
otherEntry := &CacheEntry{Data: []byte("other"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||||
t.Fatalf("Expected entry %s to exist, got error: %v", e.key, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
middleware := CacheInvalidationMiddleware(cache)
|
setIndexed("postsKey", postsEntry, "/api/posts/top")
|
||||||
|
setIndexed("otherKey", otherEntry, "/api/other/x")
|
||||||
|
|
||||||
t.Run("POST clears cache", func(t *testing.T) {
|
middleware := CacheInvalidationMiddleware(cache, prefixes)
|
||||||
|
|
||||||
|
t.Run("POST under posts prefix invalidates posts keys only", func(t *testing.T) {
|
||||||
request := httptest.NewRequest("POST", "/api/posts", nil)
|
request := httptest.NewRequest("POST", "/api/posts", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -732,80 +747,80 @@ func TestCacheInvalidationMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
for _, e := range entries {
|
if _, err := cache.Get("postsKey"); err == nil {
|
||||||
if _, err := cache.Get(e.key); err == nil {
|
t.Error("expected postsKey cleared")
|
||||||
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
|
|
||||||
}
|
}
|
||||||
|
if _, err := cache.Get("otherKey"); err != nil {
|
||||||
|
t.Errorf("expected otherKey to remain: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, e := range entries {
|
setIndexed("postsKey", postsEntry, "/api/posts/top")
|
||||||
if err := cache.Set(e.key, e.entry); err != nil {
|
wildEntry := &CacheEntry{Data: []byte("wild"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||||
t.Fatalf("Failed to repopulate cache: %v", err)
|
_ = cache.Set("untracked", wildEntry)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("PUT clears cache", func(t *testing.T) {
|
t.Run("mutation does not wipe untracked keys", func(t *testing.T) {
|
||||||
request := httptest.NewRequest("PUT", "/api/posts/1", nil)
|
request := httptest.NewRequest("PUT", "/api/posts/1", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})).ServeHTTP(recorder, request)
|
})).ServeHTTP(recorder, request)
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
for _, e := range entries {
|
if _, err := cache.Get("untracked"); err != nil {
|
||||||
if _, err := cache.Get(e.key); err == nil {
|
t.Fatal("untracked key should remain")
|
||||||
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, e := range entries {
|
t.Run("GET does not invalidate", func(t *testing.T) {
|
||||||
if err := cache.Set(e.key, e.entry); err != nil {
|
cache2 := NewInMemoryCache()
|
||||||
t.Fatalf("Failed to repopulate cache: %v", err)
|
setIndexed := func(key string, path string) {
|
||||||
|
e := &CacheEntry{Data: []byte("d"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||||
|
if err := cache2.Set(key, e); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
cache2.RegisterKeyForPath(key, path, prefixes)
|
||||||
}
|
}
|
||||||
|
setIndexed("gk", "/api/posts/1")
|
||||||
|
|
||||||
t.Run("DELETE clears cache", func(t *testing.T) {
|
mw := CacheInvalidationMiddleware(cache2, prefixes)
|
||||||
request := httptest.NewRequest("DELETE", "/api/posts/1", nil)
|
req := httptest.NewRequest("GET", "/api/posts", nil)
|
||||||
recorder := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})).ServeHTTP(recorder, request)
|
})).ServeHTTP(rec, req)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
time.Sleep(100 * time.Millisecond)
|
if _, err := cache2.Get("gk"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
for _, e := range entries {
|
|
||||||
if _, err := cache.Get(e.key); err == nil {
|
|
||||||
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("GET does not clear cache", func(t *testing.T) {
|
t.Run("DELETE under other prefix", func(t *testing.T) {
|
||||||
|
cache3 := NewInMemoryCache()
|
||||||
for _, e := range entries {
|
ep := &CacheEntry{Data: []byte("p"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||||
if err := cache.Set(e.key, e.entry); err != nil {
|
eo := &CacheEntry{Data: []byte("o"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||||
t.Fatalf("Failed to repopulate cache: %v", err)
|
if err := cache3.Set("pk", ep); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
cache3.RegisterKeyForPath("pk", "/api/posts/1", prefixes)
|
||||||
|
if err := cache3.Set("ok", eo); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
cache3.RegisterKeyForPath("ok", "/api/other/y", prefixes)
|
||||||
|
|
||||||
request := httptest.NewRequest("GET", "/api/posts", nil)
|
mw := CacheInvalidationMiddleware(cache3, prefixes)
|
||||||
recorder := httptest.NewRecorder()
|
delReq := httptest.NewRequest("DELETE", "/api/other/y", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})).ServeHTTP(recorder, request)
|
})).ServeHTTP(rec, delReq)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
time.Sleep(100 * time.Millisecond)
|
if _, err := cache3.Get("ok"); err == nil {
|
||||||
|
t.Error("expected ok cleared")
|
||||||
for _, e := range entries {
|
|
||||||
if _, err := cache.Get(e.key); err != nil {
|
|
||||||
t.Errorf("Expected entry %s to still exist, got error: %v", e.key, err)
|
|
||||||
}
|
}
|
||||||
|
if _, err := cache3.Get("pk"); err != nil {
|
||||||
|
t.Fatal("posts key should remain")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -150,23 +150,7 @@ func shouldCompressResponse(contentType string, config *CompressionConfig) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DecompressionMiddleware() func(http.Handler) http.Handler {
|
func DecompressionMiddleware() func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return DecompressionMiddlewareWithConfig(nil)
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Header.Get("Content-Encoding") == "gzip" {
|
|
||||||
gz, err := gzip.NewReader(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Invalid gzip encoding", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer gz.Close()
|
|
||||||
|
|
||||||
r.Body = io.NopCloser(gz)
|
|
||||||
r.Header.Del("Content-Encoding")
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompressionConfig struct {
|
type CompressionConfig struct {
|
||||||
@@ -189,3 +173,37 @@ func DefaultCompressionConfig() *CompressionConfig {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DecompressionConfig struct {
|
||||||
|
MaxDecompressedSize int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultDecompressionConfig() *DecompressionConfig {
|
||||||
|
return &DecompressionConfig{
|
||||||
|
MaxDecompressedSize: 1024 * 1024, // 1MB
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecompressionMiddlewareWithConfig(config *DecompressionConfig) func(http.Handler) http.Handler {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultDecompressionConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("Content-Encoding") == "gzip" {
|
||||||
|
gz, err := gzip.NewReader(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid gzip encoding", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer gz.Close()
|
||||||
|
|
||||||
|
r.Body = io.NopCloser(io.LimitReader(gz, config.MaxDecompressedSize))
|
||||||
|
r.Header.Del("Content-Encoding")
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -562,6 +562,41 @@ func TestDecompressionMiddleware(t *testing.T) {
|
|||||||
t.Errorf("Expected empty body, got '%s'", recorder.Body.String())
|
t.Errorf("Expected empty body, got '%s'", recorder.Body.String())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Limits decompressed size", func(t *testing.T) {
|
||||||
|
config := &DecompressionConfig{
|
||||||
|
MaxDecompressedSize: 10,
|
||||||
|
}
|
||||||
|
middleware := DecompressionMiddlewareWithConfig(config)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := gzip.NewWriter(&buf)
|
||||||
|
gz.Write([]byte("this is more than ten bytes of data"))
|
||||||
|
gz.Close()
|
||||||
|
|
||||||
|
request := httptest.NewRequest("POST", "/test", &buf)
|
||||||
|
request.Header.Set("Content-Encoding", "gzip")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read request body: %v", err)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(body)
|
||||||
|
}))
|
||||||
|
|
||||||
|
handler.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(recorder.Body.String()) > 10 {
|
||||||
|
t.Errorf("Expected body to be truncated to <= 10 bytes, got %d bytes", len(recorder.Body.String()))
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldCompressWithConfig(t *testing.T) {
|
func TestShouldCompressWithConfig(t *testing.T) {
|
||||||
|
|||||||
@@ -15,8 +15,39 @@ type CORSConfig struct {
|
|||||||
AllowCredentials bool
|
AllowCredentials bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseAllowedOriginsEnv() []string {
|
||||||
|
raw := os.Getenv("CORS_ALLOWED_ORIGINS")
|
||||||
|
if raw == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parts := strings.Split(raw, ",")
|
||||||
|
out := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p != "" {
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCORSConfig(config *CORSConfig) {
|
||||||
|
if config == nil {
|
||||||
|
panic("middleware.CORS: config is nil")
|
||||||
|
}
|
||||||
|
if !config.AllowCredentials {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, o := range config.AllowedOrigins {
|
||||||
|
if o == "*" {
|
||||||
|
panic("middleware.CORS: AllowCredentials with wildcard AllowedOrigins (*) is not permitted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewCORSConfig() *CORSConfig {
|
func NewCORSConfig() *CORSConfig {
|
||||||
env := os.Getenv("GOYCO_ENV")
|
env := os.Getenv("GOYCO_ENV")
|
||||||
|
originsEnv := parseAllowedOriginsEnv()
|
||||||
|
|
||||||
config := &CORSConfig{
|
config := &CORSConfig{
|
||||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||||
@@ -27,24 +58,27 @@ func NewCORSConfig() *CORSConfig {
|
|||||||
|
|
||||||
switch env {
|
switch env {
|
||||||
case "production", "staging":
|
case "production", "staging":
|
||||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
|
config.AllowCredentials = true
|
||||||
|
if len(originsEnv) > 0 {
|
||||||
|
config.AllowedOrigins = originsEnv
|
||||||
|
} else {
|
||||||
config.AllowedOrigins = []string{}
|
config.AllowedOrigins = []string{}
|
||||||
}
|
}
|
||||||
config.AllowCredentials = true
|
|
||||||
default:
|
default:
|
||||||
|
config.AllowCredentials = true
|
||||||
|
if len(originsEnv) > 0 {
|
||||||
|
config.AllowedOrigins = originsEnv
|
||||||
|
} else {
|
||||||
config.AllowedOrigins = []string{
|
config.AllowedOrigins = []string{
|
||||||
"http://localhost:3000",
|
"http://localhost:3000",
|
||||||
"http://localhost:8080",
|
"http://localhost:8080",
|
||||||
"http://127.0.0.1:3000",
|
"http://127.0.0.1:3000",
|
||||||
"http://127.0.0.1:8080",
|
"http://127.0.0.1:8080",
|
||||||
}
|
}
|
||||||
config.AllowCredentials = true
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins != "" {
|
|
||||||
config.AllowedOrigins = strings.Split(origins, ",")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
validateCORSConfig(config)
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,6 +110,7 @@ func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, conf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
|
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
|
||||||
|
validateCORSConfig(config)
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
|
|||||||
@@ -181,7 +181,12 @@ func TestCORSWithConfig_WildcardOrigin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
|
func TestCORSWithConfig_WildcardWithCredentialsPanics(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if recover() == nil {
|
||||||
|
t.Fatal("expected panic for AllowCredentials with wildcard AllowedOrigins")
|
||||||
|
}
|
||||||
|
}()
|
||||||
config := &CORSConfig{
|
config := &CORSConfig{
|
||||||
AllowedOrigins: []string{"*"},
|
AllowedOrigins: []string{"*"},
|
||||||
AllowedMethods: []string{"GET", "POST"},
|
AllowedMethods: []string{"GET", "POST"},
|
||||||
@@ -189,24 +194,7 @@ func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
|
|||||||
MaxAge: 3600,
|
MaxAge: 3600,
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
}
|
}
|
||||||
|
_ = CORSWithConfig(config)
|
||||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req.Header.Set("Origin", "http://example.com")
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
|
||||||
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if w.Header().Get("Access-Control-Allow-Credentials") != "" {
|
|
||||||
t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
|
func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
|
||||||
@@ -393,6 +381,30 @@ func TestCORSOPTIONSRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCORSAllowedOriginsTrimmedFromEnv(t *testing.T) {
|
||||||
|
t.Setenv("GOYCO_ENV", "development")
|
||||||
|
t.Setenv("CORS_ALLOWED_ORIGINS", " http://localhost:3000 , https://yourdomain.com ")
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware := CORS(handler)
|
||||||
|
|
||||||
|
for _, origin := range []string{"http://localhost:3000", "https://yourdomain.com"} {
|
||||||
|
request := httptest.NewRequest("GET", "/api/posts", nil)
|
||||||
|
request.Header.Set("Origin", origin)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
middleware.ServeHTTP(recorder, request)
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("origin %q: status %d", origin, recorder.Code)
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != origin {
|
||||||
|
t.Fatalf("origin %q: Allow-Origin %q", origin, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCORSAllowedOrigins(t *testing.T) {
|
func TestCORSAllowedOrigins(t *testing.T) {
|
||||||
t.Setenv("GOYCO_ENV", "development")
|
t.Setenv("GOYCO_ENV", "development")
|
||||||
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")
|
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func SetCSRFToken(w http.ResponseWriter, r *http.Request, token string) {
|
|||||||
Name: CSRFTokenCookieName,
|
Name: CSRFTokenCookieName,
|
||||||
Value: token,
|
Value: token,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: false,
|
||||||
Secure: IsHTTPS(r),
|
Secure: IsHTTPS(r),
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 3600,
|
MaxAge: 3600,
|
||||||
@@ -71,7 +71,7 @@ func CSRFMiddleware() func(http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(r.URL.Path, "/api/") {
|
if strings.HasPrefix(r.URL.Path, "/api/") && hasBearerToken(r) {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -86,6 +86,11 @@ func CSRFMiddleware() func(http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasBearerToken(r *http.Request) bool {
|
||||||
|
auth := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||||
|
return strings.HasPrefix(auth, "Bearer ")
|
||||||
|
}
|
||||||
|
|
||||||
func IsHTTPS(r *http.Request) bool {
|
func IsHTTPS(r *http.Request) bool {
|
||||||
if r.TLS != nil {
|
if r.TLS != nil {
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -186,8 +186,9 @@ func TestCSRFMiddlewareAllowsValidToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFMiddlewareSkipsAPI(t *testing.T) {
|
func TestCSRFMiddlewareSkipsAPIWithBearerToken(t *testing.T) {
|
||||||
request := httptest.NewRequest("POST", "/api/test", nil)
|
request := httptest.NewRequest("POST", "/api/test", nil)
|
||||||
|
request.Header.Set("Authorization", "Bearer valid-token")
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -197,7 +198,22 @@ func TestCSRFMiddlewareSkipsAPI(t *testing.T) {
|
|||||||
handler.ServeHTTP(recorder, request)
|
handler.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
if recorder.Code != http.StatusOK {
|
if recorder.Code != http.StatusOK {
|
||||||
t.Errorf("API requests should skip CSRF validation, got status %d", recorder.Code)
|
t.Errorf("API requests with Bearer token should skip CSRF validation, got status %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRFMiddlewareBlocksAPIWithoutBearerToken(t *testing.T) {
|
||||||
|
request := httptest.NewRequest("POST", "/api/test", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
handler.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("API requests without Bearer token should require CSRF validation, got status %d", recorder.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,8 +242,8 @@ func TestSetCSRFToken(t *testing.T) {
|
|||||||
t.Errorf("Expected cookie value %s, got %s", token, cookie.Value)
|
t.Errorf("Expected cookie value %s, got %s", token, cookie.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cookie.HttpOnly {
|
if cookie.HttpOnly {
|
||||||
t.Error("CSRF token cookie should be HttpOnly")
|
t.Error("CSRF token cookie must not be HttpOnly so JS can mirror it to X-CSRF-Token")
|
||||||
}
|
}
|
||||||
|
|
||||||
if cookie.SameSite != http.SameSiteLaxMode {
|
if cookie.SameSite != http.SameSiteLaxMode {
|
||||||
|
|||||||
@@ -327,8 +327,8 @@ func GetSecureClientIP(r *http.Request) string {
|
|||||||
func GetKey(r *http.Request) string {
|
func GetKey(r *http.Request) string {
|
||||||
ip := GetSecureClientIP(r)
|
ip := GetSecureClientIP(r)
|
||||||
|
|
||||||
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
|
if userID := GetUserIDFromContext(r.Context()); userID != nil {
|
||||||
return fmt.Sprintf("user:%d:ip:%s", userID, ip)
|
return fmt.Sprintf("user:%d:ip:%s", *userID, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("ip:%s", ip)
|
return fmt.Sprintf("ip:%s", ip)
|
||||||
|
|||||||
@@ -5,12 +5,21 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const CSPNonceKey contextKey = "csp_nonce"
|
const CSPNonceKey contextKey = "csp_nonce"
|
||||||
|
|
||||||
|
type SecurityHeadersConfig struct {
|
||||||
|
RelaxSwaggerCSP bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
|
||||||
|
return SecurityHeadersConfig{RelaxSwaggerCSP: true}
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateCSPNonce() (string, error) {
|
func GenerateCSPNonce() (string, error) {
|
||||||
nonceBytes := make([]byte, 16)
|
nonceBytes := make([]byte, 16)
|
||||||
if _, err := rand.Read(nonceBytes); err != nil {
|
if _, err := rand.Read(nonceBytes); err != nil {
|
||||||
@@ -27,15 +36,18 @@ func GetCSPNonceFromContext(ctx context.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
|
func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
|
||||||
|
return SecurityHeadersMiddlewareWithConfig(DefaultSecurityHeadersConfig())
|
||||||
|
}
|
||||||
|
|
||||||
|
func SecurityHeadersMiddlewareWithConfig(cfg SecurityHeadersConfig) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
w.Header().Set("X-Frame-Options", "DENY")
|
w.Header().Set("X-Frame-Options", "DENY")
|
||||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
|
||||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||||
|
|
||||||
isSwaggerRoute := strings.HasPrefix(r.URL.Path, "/swagger")
|
isSwaggerRoute := strings.HasPrefix(r.URL.Path, "/swagger")
|
||||||
if isSwaggerRoute {
|
if isSwaggerRoute && cfg.RelaxSwaggerCSP {
|
||||||
csp := "default-src 'self'; " +
|
csp := "default-src 'self'; " +
|
||||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
|
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
|
||||||
"style-src 'self' 'unsafe-inline'; " +
|
"style-src 'self' 'unsafe-inline'; " +
|
||||||
@@ -51,7 +63,7 @@ func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
|
|||||||
} else {
|
} else {
|
||||||
nonce, err := GenerateCSPNonce()
|
nonce, err := GenerateCSPNonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("middleware security headers: CSP nonce: %v", err)
|
||||||
nonce = ""
|
nonce = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +84,6 @@ func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
|
|||||||
csp = "script-src 'self' 'nonce-" + nonce + "'; " +
|
csp = "script-src 'self' 'nonce-" + nonce + "'; " +
|
||||||
"style-src 'self' 'nonce-" + nonce + "'; " + csp
|
"style-src 'self' 'nonce-" + nonce + "'; " + csp
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
csp = "script-src 'self'; " +
|
csp = "script-src 'self'; " +
|
||||||
"style-src 'self'; " + csp
|
"style-src 'self'; " + csp
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ func TestSecurityHeadersMiddleware(t *testing.T) {
|
|||||||
expectedHeaders := map[string]string{
|
expectedHeaders := map[string]string{
|
||||||
"X-Content-Type-Options": "nosniff",
|
"X-Content-Type-Options": "nosniff",
|
||||||
"X-Frame-Options": "DENY",
|
"X-Frame-Options": "DENY",
|
||||||
"X-XSS-Protection": "1; mode=block",
|
|
||||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||||
"Server": "",
|
"Server": "",
|
||||||
}
|
}
|
||||||
@@ -176,7 +175,6 @@ func TestSecurityHeadersMiddleware_MultipleRequests(t *testing.T) {
|
|||||||
requiredHeaders := []string{
|
requiredHeaders := []string{
|
||||||
"X-Content-Type-Options",
|
"X-Content-Type-Options",
|
||||||
"X-Frame-Options",
|
"X-Frame-Options",
|
||||||
"X-XSS-Protection",
|
|
||||||
"Referrer-Policy",
|
"Referrer-Policy",
|
||||||
"Content-Security-Policy",
|
"Content-Security-Policy",
|
||||||
"Permissions-Policy",
|
"Permissions-Policy",
|
||||||
@@ -289,3 +287,25 @@ func TestCSPNonceInContext(t *testing.T) {
|
|||||||
t.Errorf("CSP header should contain nonce from context. CSP: %s, Nonce: %s", csp, capturedNonce)
|
t.Errorf("CSP header should contain nonce from context. CSP: %s, Nonce: %s", csp, capturedNonce)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSecurityHeadersMiddleware_SwaggerStrictWhenRelaxedDisabled(t *testing.T) {
|
||||||
|
handler := SecurityHeadersMiddlewareWithConfig(SecurityHeadersConfig{RelaxSwaggerCSP: false})(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/swagger/index.html", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
csp := rec.Header().Get("Content-Security-Policy")
|
||||||
|
if csp == "" {
|
||||||
|
t.Fatal("expected CSP")
|
||||||
|
}
|
||||||
|
if strings.Contains(csp, "'unsafe-eval'") || strings.Contains(csp, "'unsafe-inline'") {
|
||||||
|
t.Fatalf("unexpected relaxed CSP for swagger path: %s", csp)
|
||||||
|
}
|
||||||
|
if !strings.Contains(csp, "nonce-") {
|
||||||
|
t.Fatalf("expected nonce CSP for swagger path, got %s", csp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,7 +54,10 @@ func SecurityLoggingMiddleware(logger *SecurityLogger) func(http.Handler) http.H
|
|||||||
|
|
||||||
next.ServeHTTP(rw, r)
|
next.ServeHTTP(rw, r)
|
||||||
|
|
||||||
userID := GetUserIDFromContext(r.Context())
|
userUID := uint(0)
|
||||||
|
if u := GetUserIDFromContext(r.Context()); u != nil {
|
||||||
|
userUID = *u
|
||||||
|
}
|
||||||
ip := getClientIP(r)
|
ip := getClientIP(r)
|
||||||
|
|
||||||
event := SecurityEvent{
|
event := SecurityEvent{
|
||||||
@@ -60,7 +65,7 @@ func SecurityLoggingMiddleware(logger *SecurityLogger) func(http.Handler) http.H
|
|||||||
UserAgent: r.UserAgent(),
|
UserAgent: r.UserAgent(),
|
||||||
Path: r.URL.Path,
|
Path: r.URL.Path,
|
||||||
Method: r.Method,
|
Method: r.Method,
|
||||||
UserID: userID,
|
UserID: userUID,
|
||||||
Timestamp: start,
|
Timestamp: start,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,12 +110,15 @@ func SuspiciousActivityMiddleware(logger *SecurityLogger) func(http.Handler) htt
|
|||||||
suspicious := false
|
suspicious := false
|
||||||
details := ""
|
details := ""
|
||||||
|
|
||||||
if containsSQLInjection(r.URL.RawQuery) || containsSQLInjection(r.URL.Path) {
|
pathProbe := layeredUnescape(r.URL.Path, url.PathUnescape)
|
||||||
|
queryProbe := layeredUnescape(r.URL.RawQuery, url.QueryUnescape)
|
||||||
|
|
||||||
|
if containsSQLInjection(pathProbe) || containsSQLInjection(queryProbe) {
|
||||||
suspicious = true
|
suspicious = true
|
||||||
details = "Potential SQL injection attempt"
|
details = "Potential SQL injection attempt"
|
||||||
}
|
}
|
||||||
|
|
||||||
if containsXSS(r.URL.RawQuery) || containsXSS(r.URL.Path) {
|
if containsXSS(pathProbe) || containsXSS(queryProbe) {
|
||||||
suspicious = true
|
suspicious = true
|
||||||
details = "Potential XSS attempt"
|
details = "Potential XSS attempt"
|
||||||
}
|
}
|
||||||
@@ -158,6 +166,18 @@ func getClientIP(r *http.Request) string {
|
|||||||
return GetSecureClientIP(r)
|
return GetSecureClientIP(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func layeredUnescape(s string, decoder func(string) (string, error)) string {
|
||||||
|
out := s
|
||||||
|
for range 3 {
|
||||||
|
d, err := decoder(out)
|
||||||
|
if err != nil || d == out {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
out = d
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func containsSQLInjection(input string) bool {
|
func containsSQLInjection(input string) bool {
|
||||||
sqlPatterns := []string{
|
sqlPatterns := []string{
|
||||||
"' OR '1'='1",
|
"' OR '1'='1",
|
||||||
@@ -220,18 +240,29 @@ func isSuspiciousUserAgent(userAgent string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestCounts = make(map[string]int)
|
type rapidRequestTracker struct {
|
||||||
var lastReset = time.Now()
|
mu sync.Mutex
|
||||||
|
counts map[string]int
|
||||||
|
lastReset time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var rapidRequests = rapidRequestTracker{
|
||||||
|
counts: make(map[string]int),
|
||||||
|
lastReset: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
func isRapidRequest(ip string) bool {
|
func isRapidRequest(ip string) bool {
|
||||||
|
rapidRequests.mu.Lock()
|
||||||
|
defer rapidRequests.mu.Unlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
if now.Sub(lastReset) > time.Minute {
|
if now.Sub(rapidRequests.lastReset) > time.Minute {
|
||||||
requestCounts = make(map[string]int)
|
rapidRequests.counts = make(map[string]int)
|
||||||
lastReset = now
|
rapidRequests.lastReset = now
|
||||||
}
|
}
|
||||||
|
|
||||||
requestCounts[ip]++
|
rapidRequests.counts[ip]++
|
||||||
|
|
||||||
return requestCounts[ip] > 100
|
return rapidRequests.counts[ip] > 100
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -453,8 +454,10 @@ func TestIsSuspiciousUserAgent(t *testing.T) {
|
|||||||
|
|
||||||
func TestIsRapidRequest(t *testing.T) {
|
func TestIsRapidRequest(t *testing.T) {
|
||||||
|
|
||||||
requestCounts = make(map[string]int)
|
rapidRequests.mu.Lock()
|
||||||
lastReset = time.Now()
|
rapidRequests.counts = make(map[string]int)
|
||||||
|
rapidRequests.lastReset = time.Now()
|
||||||
|
rapidRequests.mu.Unlock()
|
||||||
|
|
||||||
ip := "192.168.1.1"
|
ip := "192.168.1.1"
|
||||||
|
|
||||||
@@ -564,6 +567,28 @@ func TestSuspiciousActivityMiddleware_NoSuspiciousActivity(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSuspiciousActivityMiddleware_EncodedSQLInQuery(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := &SecurityLogger{
|
||||||
|
logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile),
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := SuspiciousActivityMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
q := url.Values{}
|
||||||
|
q.Set("s", "' OR '1'='1")
|
||||||
|
req := httptest.NewRequest("GET", "/search?"+q.Encode(), nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
out := buf.String()
|
||||||
|
if !strings.Contains(out, "SQL injection") {
|
||||||
|
t.Fatalf("expected SQL injection log, got %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSuspiciousActivityMiddleware_Debug(t *testing.T) {
|
func TestSuspiciousActivityMiddleware_Debug(t *testing.T) {
|
||||||
|
|
||||||
t.Run("SQL Detection", func(t *testing.T) {
|
t.Run("SQL Detection", func(t *testing.T) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -32,12 +33,23 @@ type RouterConfig struct {
|
|||||||
RateLimitConfig config.RateLimitConfig
|
RateLimitConfig config.RateLimitConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func swaggerExposed() bool {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(os.Getenv("SWAGGER_ENABLED")), "true") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return strings.ToLower(strings.TrimSpace(os.Getenv("GOYCO_ENV"))) != "production"
|
||||||
|
}
|
||||||
|
|
||||||
func NewRouter(cfg RouterConfig) http.Handler {
|
func NewRouter(cfg RouterConfig) http.Handler {
|
||||||
middleware.SetTrustProxyHeaders(cfg.RateLimitConfig.TrustProxyHeaders)
|
middleware.SetTrustProxyHeaders(cfg.RateLimitConfig.TrustProxyHeaders)
|
||||||
|
|
||||||
|
exposeSwagger := swaggerExposed()
|
||||||
|
|
||||||
router := chi.NewRouter()
|
router := chi.NewRouter()
|
||||||
router.Use(middleware.Logging(cfg.Debug))
|
router.Use(middleware.Logging(cfg.Debug))
|
||||||
router.Use(middleware.SecurityHeadersMiddleware())
|
router.Use(middleware.SecurityHeadersMiddlewareWithConfig(middleware.SecurityHeadersConfig{
|
||||||
|
RelaxSwaggerCSP: exposeSwagger,
|
||||||
|
}))
|
||||||
router.Use(middleware.HSTSMiddleware())
|
router.Use(middleware.HSTSMiddleware())
|
||||||
router.Use(middleware.CORS)
|
router.Use(middleware.CORS)
|
||||||
|
|
||||||
@@ -54,7 +66,7 @@ func NewRouter(cfg RouterConfig) http.Handler {
|
|||||||
cacheConfig.CacheablePaths = append([]string{}, cfg.CacheablePaths...)
|
cacheConfig.CacheablePaths = append([]string{}, cfg.CacheablePaths...)
|
||||||
}
|
}
|
||||||
router.Use(middleware.CacheMiddleware(cache, cacheConfig))
|
router.Use(middleware.CacheMiddleware(cache, cacheConfig))
|
||||||
router.Use(middleware.CacheInvalidationMiddleware(cache))
|
router.Use(middleware.CacheInvalidationMiddleware(cache, cacheConfig.CacheablePaths))
|
||||||
}
|
}
|
||||||
|
|
||||||
var dbMonitor middleware.DBMonitor
|
var dbMonitor middleware.DBMonitor
|
||||||
@@ -94,8 +106,10 @@ func NewRouter(cfg RouterConfig) http.Handler {
|
|||||||
metricsRateLimited.Get("/metrics", cfg.APIHandler.GetMetrics)
|
metricsRateLimited.Get("/metrics", cfg.APIHandler.GetMetrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if exposeSwagger {
|
||||||
swaggerRateLimited := router.With(middleware.GeneralRateLimitMiddlewareWithLimit(cfg.RateLimitConfig.GeneralLimit))
|
swaggerRateLimited := router.With(middleware.GeneralRateLimitMiddlewareWithLimit(cfg.RateLimitConfig.GeneralLimit))
|
||||||
swaggerRateLimited.Get("/swagger/*", httpSwagger.Handler())
|
swaggerRateLimited.Get("/swagger/*", httpSwagger.Handler())
|
||||||
|
}
|
||||||
|
|
||||||
router.Get("/robots.txt", func(w http.ResponseWriter, r *http.Request) {
|
router.Get("/robots.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.ServeFile(w, r, filepath.Join(cfg.StaticDir, "robots.txt"))
|
http.ServeFile(w, r, filepath.Join(cfg.StaticDir, "robots.txt"))
|
||||||
|
|||||||
@@ -348,6 +348,8 @@ func TestRouterWithoutPageHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSwaggerRoute(t *testing.T) {
|
func TestSwaggerRoute(t *testing.T) {
|
||||||
|
t.Setenv("GOYCO_ENV", "development")
|
||||||
|
t.Setenv("SWAGGER_ENABLED", "")
|
||||||
router := createTestRouter(createDefaultRouterConfig())
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
||||||
@@ -360,6 +362,34 @@ func TestSwaggerRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSwaggerRouteHiddenInProduction(t *testing.T) {
|
||||||
|
t.Setenv("GOYCO_ENV", "production")
|
||||||
|
t.Setenv("SWAGGER_ENABLED", "")
|
||||||
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("Expected 404 for swagger in production, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSwaggerRouteEnabledInProductionWhenForced(t *testing.T) {
|
||||||
|
t.Setenv("GOYCO_ENV", "production")
|
||||||
|
t.Setenv("SWAGGER_ENABLED", "true")
|
||||||
|
router := createTestRouter(createDefaultRouterConfig())
|
||||||
|
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK && recorder.Code != http.StatusMovedPermanently {
|
||||||
|
t.Errorf("Expected status 200 or 301 for swagger when SWAGGER_ENABLED, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStaticFileRoute(t *testing.T) {
|
func TestStaticFileRoute(t *testing.T) {
|
||||||
cfg := createDefaultRouterConfig()
|
cfg := createDefaultRouterConfig()
|
||||||
cfg.StaticDir = "../../internal/static/"
|
cfg.StaticDir = "../../internal/static/"
|
||||||
|
|||||||
Reference in New Issue
Block a user