Compare commits
23 Commits
0baf7053fc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| f7d43def1c | |||
| d891b33b57 | |||
| 60daeddbe4 | |||
| 537a7e3759 | |||
| 194884293f | |||
| 0fbb6f4a88 | |||
| b3f6f5b15e | |||
| 2ede636bd6 | |||
| 7c525e71cb | |||
| 620798577e | |||
| b41d3bb20c | |||
| abaf46e624 | |||
| 61875201f9 | |||
| d668567dc5 | |||
| 102f1d8400 | |||
| 98985db537 | |||
| be64e7c8d2 | |||
| 1aa256c6a8 | |||
| dccf85e038 | |||
| 4e188eb8d5 | |||
| 2adf72c138 | |||
| add60ad3c2 | |||
| 89131331a6 |
@@ -48,6 +48,9 @@ RATE_LIMIT_TRUST_PROXY=false
|
||||
# Set to: development, staging, or production
|
||||
GOYCO_ENV=development
|
||||
|
||||
# When GOYCO_ENV=production, set to true only if you intentionally want Swagger UI mounted (default: omitted/false hides it).
|
||||
SWAGGER_ENABLED=false
|
||||
|
||||
# CORS Configuration (optional, comma-separated)
|
||||
# Example: CORS_ALLOWED_ORIGINS=https://example.com,https://www.example.com
|
||||
CORS_ALLOWED_ORIGINS=
|
||||
|
||||
@@ -18,6 +18,9 @@ go.work.sum
|
||||
# env file
|
||||
.env
|
||||
|
||||
# local security audit notes (not tracked)
|
||||
SECURITY_AUDIT.md
|
||||
|
||||
# binaries
|
||||
bin/goyco
|
||||
|
||||
|
||||
@@ -171,13 +171,21 @@ server {
|
||||
}
|
||||
```
|
||||
|
||||
### Security headers and middleware ordering
|
||||
|
||||
When `RATE_LIMIT_TRUST_PROXY` is `true`, the application trusts `X-Forwarded-Proto` (among other forwarded headers) for HSTS and client IP derivation. Configure your reverse proxy to set trustworthy values and strip or overwrite any client-supplied forwarding headers before they reach the app.
|
||||
|
||||
Response caching uses `DecompressionMiddleware` before `DefaultRequestSizeLimitMiddleware`. Decompression is additionally capped internally so gzipped bodies cannot expand without limit before other limits apply.
|
||||
|
||||
In production (`GOYCO_ENV=production`), Swagger UI is not registered unless `SWAGGER_ENABLED=true`.
|
||||
|
||||
## API Documentation
|
||||
|
||||
The API is fully documented with Swagger.
|
||||
The API is fully documented with Swagger (enabled for non-production environments, or when `SWAGGER_ENABLED=true`).
|
||||
|
||||
Once running, visit:
|
||||
|
||||
- **Swagger UI**: `https://goyco.example.com/swagger/index.html`
|
||||
- **Swagger UI**: `https://goyco.example.com/swagger/index.html` (not served in production unless explicitly enabled via `SWAGGER_ENABLED`)
|
||||
|
||||
You can also use `curl` to get the API info, health check and even metrics:
|
||||
|
||||
|
||||
@@ -547,7 +547,6 @@ func TestE2E_SecurityHeadersEnhanced(t *testing.T) {
|
||||
expectedHeaders := map[string]string{
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
}
|
||||
|
||||
@@ -707,7 +706,6 @@ func TestE2E_SecurityHeaderCombinations(t *testing.T) {
|
||||
requiredHeaders := []string{
|
||||
"X-Content-Type-Options",
|
||||
"X-Frame-Options",
|
||||
"X-XSS-Protection",
|
||||
"Referrer-Policy",
|
||||
"Content-Security-Policy",
|
||||
}
|
||||
|
||||
@@ -200,17 +200,21 @@ func ParseUintParam(w http.ResponseWriter, r *http.Request, paramName, entityNam
|
||||
}
|
||||
|
||||
func RequireAuth(w http.ResponseWriter, r *http.Request) (uint, bool) {
|
||||
userID := middleware.GetUserIDFromContext(r.Context())
|
||||
if userID == 0 {
|
||||
userPtr := middleware.GetUserIDFromContext(r.Context())
|
||||
if userPtr == nil {
|
||||
SendErrorResponse(w, "Authentication required", http.StatusUnauthorized)
|
||||
return 0, false
|
||||
}
|
||||
return userID, true
|
||||
return *userPtr, true
|
||||
}
|
||||
|
||||
func NewVoteContext(r *http.Request) services.VoteContext {
|
||||
var uid uint
|
||||
if userPtr := middleware.GetUserIDFromContext(r.Context()); userPtr != nil {
|
||||
uid = *userPtr
|
||||
}
|
||||
return services.VoteContext{
|
||||
UserID: middleware.GetUserIDFromContext(r.Context()),
|
||||
UserID: uid,
|
||||
IPAddress: GetClientIP(r),
|
||||
UserAgent: r.UserAgent(),
|
||||
}
|
||||
|
||||
@@ -569,7 +569,8 @@ func TestParseUintParam(t *testing.T) {
|
||||
func TestRequireAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
setUserKey bool
|
||||
userIDValue uint
|
||||
expectedID uint
|
||||
expectedOK bool
|
||||
expectedStatus int
|
||||
@@ -577,25 +578,32 @@ func TestRequireAuth(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "authenticated user",
|
||||
userID: 123,
|
||||
setUserKey: true,
|
||||
userIDValue: 123,
|
||||
expectedID: 123,
|
||||
expectedOK: true,
|
||||
expectedStatus: 0,
|
||||
},
|
||||
{
|
||||
name: "unauthenticated user (no userID)",
|
||||
userID: 0,
|
||||
name: "unauthenticated user (missing context)",
|
||||
setUserKey: false,
|
||||
expectedID: 0,
|
||||
expectedOK: false,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Authentication required",
|
||||
},
|
||||
{
|
||||
name: "authenticated user with id zero",
|
||||
setUserKey: true,
|
||||
userIDValue: 0,
|
||||
expectedID: 0,
|
||||
expectedOK: true,
|
||||
},
|
||||
{
|
||||
name: "authenticated user with large ID",
|
||||
userID: 4294967295,
|
||||
setUserKey: true,
|
||||
userIDValue: 4294967295,
|
||||
expectedID: 4294967295,
|
||||
expectedOK: true,
|
||||
expectedStatus: 0,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -604,7 +612,10 @@ func TestRequireAuth(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
ctx := context.WithValue(r.Context(), middleware.UserIDKey, tt.userID)
|
||||
ctx := context.Background()
|
||||
if tt.setUserKey {
|
||||
ctx = context.WithValue(ctx, middleware.UserIDKey, tt.userIDValue)
|
||||
}
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
userID, ok := RequireAuth(w, r)
|
||||
|
||||
@@ -22,7 +22,6 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
|
||||
|
||||
assertHeader(t, request, "X-Content-Type-Options")
|
||||
assertHeader(t, request, "X-Frame-Options")
|
||||
assertHeader(t, request, "X-XSS-Protection")
|
||||
})
|
||||
|
||||
t.Run("CORS_Headers_Present", func(t *testing.T) {
|
||||
|
||||
@@ -73,9 +73,10 @@ func NewAuth(verifier TokenVerifier) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserIDFromContext(ctx context.Context) uint {
|
||||
func GetUserIDFromContext(ctx context.Context) *uint {
|
||||
if userID, ok := ctx.Value(UserIDKey).(uint); ok {
|
||||
return userID
|
||||
u := userID
|
||||
return &u
|
||||
}
|
||||
return 0
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ func TestNewAuthWithoutAuthorization(t *testing.T) {
|
||||
middleware := NewAuth(verifier)
|
||||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
if id := GetUserIDFromContext(r.Context()); id != 0 {
|
||||
t.Fatalf("unexpected user id %d", id)
|
||||
if id := GetUserIDFromContext(r.Context()); id != nil {
|
||||
t.Fatalf("unexpected user id %v", id)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -54,8 +54,13 @@ func TestNewAuthValidToken(t *testing.T) {
|
||||
handlerCalled := false
|
||||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
if id := GetUserIDFromContext(r.Context()); id != 99 {
|
||||
t.Fatalf("expected user id 99, got %d", id)
|
||||
id := GetUserIDFromContext(r.Context())
|
||||
if id == nil || *id != 99 {
|
||||
v := uint(0)
|
||||
if id != nil {
|
||||
v = *id
|
||||
}
|
||||
t.Fatalf("expected user id 99, got %d", v)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -131,11 +136,12 @@ func TestNewAuthVerifierError(t *testing.T) {
|
||||
func TestGetUserIDFromContext(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), UserIDKey, uint(55))
|
||||
|
||||
if id := GetUserIDFromContext(ctx); id != 55 {
|
||||
t.Fatalf("expected id 55, got %d", id)
|
||||
id := GetUserIDFromContext(ctx)
|
||||
if id == nil || *id != 55 {
|
||||
t.Fatalf("expected id 55, got %v", id)
|
||||
}
|
||||
|
||||
if id := GetUserIDFromContext(context.Background()); id != 0 {
|
||||
t.Fatalf("expected zero when id missing, got %d", id)
|
||||
if ptr := GetUserIDFromContext(context.Background()); ptr != nil {
|
||||
t.Fatalf("expected nil when id missing, got %v", ptr)
|
||||
}
|
||||
}
|
||||
|
||||
+178
-17
@@ -2,8 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"container/list"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -24,47 +27,192 @@ type Cache interface {
|
||||
Clear() error
|
||||
}
|
||||
|
||||
func applyCacheMaxSize(cache Cache, max int) {
|
||||
if im, ok := cache.(*InMemoryCache); ok {
|
||||
im.SetMaxEntries(max)
|
||||
}
|
||||
}
|
||||
|
||||
func registerIndexedCacheKey(cache Cache, cacheKey string, path string, cacheablePrefixes []string) {
|
||||
if im, ok := cache.(*InMemoryCache); ok {
|
||||
im.RegisterKeyForPath(cacheKey, path, cacheablePrefixes)
|
||||
}
|
||||
}
|
||||
|
||||
func invalidateCacheForMutation(cache Cache, mutationPath string, cacheablePrefixes []string) {
|
||||
if im, ok := cache.(*InMemoryCache); ok {
|
||||
im.InvalidateForMutationPath(mutationPath, cacheablePrefixes)
|
||||
return
|
||||
}
|
||||
_ = cache.Clear()
|
||||
}
|
||||
|
||||
type InMemoryCache struct {
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
|
||||
data map[string]*CacheEntry
|
||||
maxSize int
|
||||
|
||||
ll *list.List
|
||||
lruEl map[string]*list.Element
|
||||
|
||||
prefixKeys map[string]map[string]struct{}
|
||||
keyPrefixes map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
func NewInMemoryCache() *InMemoryCache {
|
||||
return &InMemoryCache{
|
||||
data: make(map[string]*CacheEntry),
|
||||
maxSize: 1000,
|
||||
ll: list.New(),
|
||||
lruEl: make(map[string]*list.Element),
|
||||
prefixKeys: make(map[string]map[string]struct{}),
|
||||
keyPrefixes: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
|
||||
cache.mu.RLock()
|
||||
entry, exists := cache.data[key]
|
||||
cache.mu.RUnlock()
|
||||
func (cache *InMemoryCache) SetMaxEntries(n int) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
cache.maxSize = n
|
||||
for n > 0 && len(cache.data) > n {
|
||||
cache.evictOldestLocked()
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) RegisterKeyForPath(cacheKey string, path string, cacheablePrefixes []string) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
for _, p := range matchingCachePrefixes(path, cacheablePrefixes) {
|
||||
if cache.prefixKeys[p] == nil {
|
||||
cache.prefixKeys[p] = make(map[string]struct{})
|
||||
}
|
||||
cache.prefixKeys[p][cacheKey] = struct{}{}
|
||||
|
||||
if cache.keyPrefixes[cacheKey] == nil {
|
||||
cache.keyPrefixes[cacheKey] = make(map[string]struct{})
|
||||
}
|
||||
cache.keyPrefixes[cacheKey][p] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) InvalidateForMutationPath(mutationPath string, cacheablePrefixes []string) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
var stale []string
|
||||
for _, prefix := range cacheablePrefixes {
|
||||
if !strings.HasPrefix(mutationPath, prefix) {
|
||||
continue
|
||||
}
|
||||
for key := range cache.prefixKeys[prefix] {
|
||||
if _, dup := seen[key]; dup {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
stale = append(stale, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range stale {
|
||||
cache.removeKeyLocked(key)
|
||||
}
|
||||
}
|
||||
|
||||
func matchingCachePrefixes(path string, cacheablePrefixes []string) []string {
|
||||
var out []string
|
||||
for _, p := range cacheablePrefixes {
|
||||
if strings.HasPrefix(path, p) {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
entry, exists := cache.data[key]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("key not found")
|
||||
}
|
||||
|
||||
if time.Since(entry.Timestamp) > entry.TTL {
|
||||
cache.mu.Lock()
|
||||
delete(cache.data, key)
|
||||
cache.mu.Unlock()
|
||||
cache.removeKeyLocked(key)
|
||||
return nil, fmt.Errorf("entry expired")
|
||||
}
|
||||
|
||||
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||
cache.ll.MoveToFront(el)
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
if _, exists := cache.data[key]; exists {
|
||||
cache.data[key] = entry
|
||||
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||
cache.ll.MoveToFront(el)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cache.data[key] = entry
|
||||
if cache.ll != nil {
|
||||
el := cache.ll.PushFront(key)
|
||||
cache.lruEl[key] = el
|
||||
}
|
||||
|
||||
for cache.maxSize > 0 && len(cache.data) > cache.maxSize {
|
||||
cache.evictOldestLocked()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) evictOldestLocked() {
|
||||
if cache.ll == nil || cache.ll.Len() == 0 {
|
||||
return
|
||||
}
|
||||
el := cache.ll.Back()
|
||||
if el == nil {
|
||||
return
|
||||
}
|
||||
key, _ := el.Value.(string)
|
||||
cache.removeKeyLocked(key)
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) removeKeyLocked(key string) {
|
||||
delete(cache.data, key)
|
||||
|
||||
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||
cache.ll.Remove(el)
|
||||
}
|
||||
delete(cache.lruEl, key)
|
||||
|
||||
for p := range cache.keyPrefixes[key] {
|
||||
if m, ok := cache.prefixKeys[p]; ok {
|
||||
delete(m, key)
|
||||
if len(m) == 0 {
|
||||
delete(cache.prefixKeys, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(cache.keyPrefixes, key)
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Delete(key string) error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
delete(cache.data, key)
|
||||
if _, ok := cache.data[key]; !ok {
|
||||
return fmt.Errorf("key not found")
|
||||
}
|
||||
cache.removeKeyLocked(key)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -72,6 +220,10 @@ func (cache *InMemoryCache) Clear() error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
cache.data = make(map[string]*CacheEntry)
|
||||
cache.prefixKeys = make(map[string]map[string]struct{})
|
||||
cache.keyPrefixes = make(map[string]map[string]struct{})
|
||||
cache.lruEl = make(map[string]*list.Element)
|
||||
cache.ll = list.New()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -95,6 +247,7 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
|
||||
if config == nil {
|
||||
config = DefaultCacheConfig()
|
||||
}
|
||||
applyCacheMaxSize(cache, config.MaxSize)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -141,8 +294,15 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
|
||||
TTL: config.TTL,
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
prefixes := config.CacheablePaths
|
||||
|
||||
go func() {
|
||||
cache.Set(cacheKey, entry)
|
||||
if err := cache.Set(cacheKey, entry); err != nil {
|
||||
log.Printf("middleware cache Set: %v", err)
|
||||
return
|
||||
}
|
||||
registerIndexedCacheKey(cache, cacheKey, path, prefixes)
|
||||
}()
|
||||
}
|
||||
})
|
||||
@@ -190,20 +350,21 @@ func generateCacheKey(r *http.Request) string {
|
||||
key += "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
|
||||
key += fmt.Sprintf(":user:%d", userID)
|
||||
if userID := GetUserIDFromContext(r.Context()); userID != nil {
|
||||
key += fmt.Sprintf(":user:%d", *userID)
|
||||
}
|
||||
|
||||
hash := md5.Sum([]byte(key))
|
||||
return fmt.Sprintf("cache:%x", hash)
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return "cache:" + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func CacheInvalidationMiddleware(cache Cache) func(http.Handler) http.Handler {
|
||||
func CacheInvalidationMiddleware(cache Cache, cacheablePrefixes []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" {
|
||||
mPath := r.URL.Path
|
||||
go func() {
|
||||
cache.Clear()
|
||||
invalidateCacheForMutation(cache, mPath, cacheablePrefixes)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -97,6 +97,29 @@ func TestInMemoryCache(t *testing.T) {
|
||||
t.Error("Expected error for expired entry")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LRU evicts oldest at max size", func(t *testing.T) {
|
||||
c := NewInMemoryCache()
|
||||
c.SetMaxEntries(2)
|
||||
entry := func(b byte) *CacheEntry {
|
||||
return &CacheEntry{Data: []byte{b}, Headers: make(http.Header), Timestamp: time.Now(), TTL: time.Hour}
|
||||
}
|
||||
_ = c.Set("k1", entry('a'))
|
||||
_ = c.Set("k2", entry('b'))
|
||||
if _, err := c.Get("k1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = c.Set("k3", entry('c'))
|
||||
if _, err := c.Get("k1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := c.Get("k3"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := c.Get("k2"); err == nil {
|
||||
t.Fatal("expected k2 evicted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheMiddleware(t *testing.T) {
|
||||
@@ -255,10 +278,10 @@ func TestCacheKeyGeneration(t *testing.T) {
|
||||
query string
|
||||
expected string
|
||||
}{
|
||||
{"GET", "/test", "", "cache:e2b43a77e8b6707afcc1571382ca7c73"},
|
||||
{"GET", "/test", "param=value", "cache:067b4b550d6cee93dfb106d6912ef91b"},
|
||||
{"POST", "/test", "", "cache:fb3126bb69b4d21769b5fa4d78318b0e"},
|
||||
{"PUT", "/users/123", "", "cache:40b0b7a2306bfd4998d6219c1ef29783"},
|
||||
{"GET", "/test", "", "cache:dbbdf14ce9e8333532d3760e4e1254e9a4f9b4bd7e98446754bfc23420d5e7c9"},
|
||||
{"GET", "/test", "param=value", "cache:da0e5eaf04e82e40b49ebb8f0a1c85954a207119d7e2423a9c24a94ddb189f71"},
|
||||
{"POST", "/test", "", "cache:719d94211ce99e5e0d039a4a7dfa57409eadf2573544454005c1fd4f3fce988f"},
|
||||
{"PUT", "/users/123", "", "cache:168e0c53c01e3f92badb40db057805a786749b1fd9be4d1562f34ba6cfac77fe"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -587,7 +610,6 @@ func TestCacheMiddlewarePreservesSecurityHeaders(t *testing.T) {
|
||||
securityHeaders := []string{
|
||||
"X-Content-Type-Options",
|
||||
"X-Frame-Options",
|
||||
"X-XSS-Protection",
|
||||
"Referrer-Policy",
|
||||
"Content-Security-Policy",
|
||||
"Permissions-Policy",
|
||||
@@ -698,31 +720,24 @@ func TestCacheMiddlewarePreservesHSTSHeader(t *testing.T) {
|
||||
|
||||
func TestCacheInvalidationMiddleware(t *testing.T) {
|
||||
cache := NewInMemoryCache()
|
||||
prefixes := []string{"/api/posts", "/api/other"}
|
||||
|
||||
entries := []struct {
|
||||
key string
|
||||
entry *CacheEntry
|
||||
}{
|
||||
{"cache:abc123", &CacheEntry{Data: []byte("data1"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
|
||||
{"cache:def456", &CacheEntry{Data: []byte("data2"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
|
||||
{"cache:ghi789", &CacheEntry{Data: []byte("data3"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}},
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
if err := cache.Set(e.key, e.entry); err != nil {
|
||||
setIndexed := func(key string, entry *CacheEntry, path string) {
|
||||
if err := cache.Set(key, entry); err != nil {
|
||||
t.Fatalf("Failed to set cache entry: %v", err)
|
||||
}
|
||||
cache.RegisterKeyForPath(key, path, prefixes)
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
if _, err := cache.Get(e.key); err != nil {
|
||||
t.Fatalf("Expected entry %s to exist, got error: %v", e.key, err)
|
||||
}
|
||||
}
|
||||
postsEntry := &CacheEntry{Data: []byte("posts"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||
otherEntry := &CacheEntry{Data: []byte("other"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||
|
||||
middleware := CacheInvalidationMiddleware(cache)
|
||||
setIndexed("postsKey", postsEntry, "/api/posts/top")
|
||||
setIndexed("otherKey", otherEntry, "/api/other/x")
|
||||
|
||||
t.Run("POST clears cache", func(t *testing.T) {
|
||||
middleware := CacheInvalidationMiddleware(cache, prefixes)
|
||||
|
||||
t.Run("POST under posts prefix invalidates posts keys only", func(t *testing.T) {
|
||||
request := httptest.NewRequest("POST", "/api/posts", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -732,80 +747,80 @@ func TestCacheInvalidationMiddleware(t *testing.T) {
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
for _, e := range entries {
|
||||
if _, err := cache.Get(e.key); err == nil {
|
||||
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
|
||||
if _, err := cache.Get("postsKey"); err == nil {
|
||||
t.Error("expected postsKey cleared")
|
||||
}
|
||||
if _, err := cache.Get("otherKey"); err != nil {
|
||||
t.Errorf("expected otherKey to remain: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
for _, e := range entries {
|
||||
if err := cache.Set(e.key, e.entry); err != nil {
|
||||
t.Fatalf("Failed to repopulate cache: %v", err)
|
||||
}
|
||||
}
|
||||
setIndexed("postsKey", postsEntry, "/api/posts/top")
|
||||
wildEntry := &CacheEntry{Data: []byte("wild"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||
_ = cache.Set("untracked", wildEntry)
|
||||
|
||||
t.Run("PUT clears cache", func(t *testing.T) {
|
||||
t.Run("mutation does not wipe untracked keys", func(t *testing.T) {
|
||||
request := httptest.NewRequest("PUT", "/api/posts/1", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(recorder, request)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
for _, e := range entries {
|
||||
if _, err := cache.Get(e.key); err == nil {
|
||||
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
|
||||
}
|
||||
if _, err := cache.Get("untracked"); err != nil {
|
||||
t.Fatal("untracked key should remain")
|
||||
}
|
||||
})
|
||||
|
||||
for _, e := range entries {
|
||||
if err := cache.Set(e.key, e.entry); err != nil {
|
||||
t.Fatalf("Failed to repopulate cache: %v", err)
|
||||
t.Run("GET does not invalidate", func(t *testing.T) {
|
||||
cache2 := NewInMemoryCache()
|
||||
setIndexed := func(key string, path string) {
|
||||
e := &CacheEntry{Data: []byte("d"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||
if err := cache2.Set(key, e); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cache2.RegisterKeyForPath(key, path, prefixes)
|
||||
}
|
||||
setIndexed("gk", "/api/posts/1")
|
||||
|
||||
t.Run("DELETE clears cache", func(t *testing.T) {
|
||||
request := httptest.NewRequest("DELETE", "/api/posts/1", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mw := CacheInvalidationMiddleware(cache2, prefixes)
|
||||
req := httptest.NewRequest("GET", "/api/posts", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(recorder, request)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
for _, e := range entries {
|
||||
if _, err := cache.Get(e.key); err == nil {
|
||||
t.Errorf("Expected entry %s to be cleared, but it still exists", e.key)
|
||||
}
|
||||
})).ServeHTTP(rec, req)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if _, err := cache2.Get("gk"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GET does not clear cache", func(t *testing.T) {
|
||||
|
||||
for _, e := range entries {
|
||||
if err := cache.Set(e.key, e.entry); err != nil {
|
||||
t.Fatalf("Failed to repopulate cache: %v", err)
|
||||
t.Run("DELETE under other prefix", func(t *testing.T) {
|
||||
cache3 := NewInMemoryCache()
|
||||
ep := &CacheEntry{Data: []byte("p"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||
eo := &CacheEntry{Data: []byte("o"), Headers: make(http.Header), Timestamp: time.Now(), TTL: 5 * time.Minute}
|
||||
if err := cache3.Set("pk", ep); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cache3.RegisterKeyForPath("pk", "/api/posts/1", prefixes)
|
||||
if err := cache3.Set("ok", eo); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cache3.RegisterKeyForPath("ok", "/api/other/y", prefixes)
|
||||
|
||||
request := httptest.NewRequest("GET", "/api/posts", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mw := CacheInvalidationMiddleware(cache3, prefixes)
|
||||
delReq := httptest.NewRequest("DELETE", "/api/other/y", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(recorder, request)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
for _, e := range entries {
|
||||
if _, err := cache.Get(e.key); err != nil {
|
||||
t.Errorf("Expected entry %s to still exist, got error: %v", e.key, err)
|
||||
})).ServeHTTP(rec, delReq)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if _, err := cache3.Get("ok"); err == nil {
|
||||
t.Error("expected ok cleared")
|
||||
}
|
||||
if _, err := cache3.Get("pk"); err != nil {
|
||||
t.Fatal("posts key should remain")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,8 +15,39 @@ type CORSConfig struct {
|
||||
AllowCredentials bool
|
||||
}
|
||||
|
||||
func parseAllowedOriginsEnv() []string {
|
||||
raw := os.Getenv("CORS_ALLOWED_ORIGINS")
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func validateCORSConfig(config *CORSConfig) {
|
||||
if config == nil {
|
||||
panic("middleware.CORS: config is nil")
|
||||
}
|
||||
if !config.AllowCredentials {
|
||||
return
|
||||
}
|
||||
for _, o := range config.AllowedOrigins {
|
||||
if o == "*" {
|
||||
panic("middleware.CORS: AllowCredentials with wildcard AllowedOrigins (*) is not permitted")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewCORSConfig() *CORSConfig {
|
||||
env := os.Getenv("GOYCO_ENV")
|
||||
originsEnv := parseAllowedOriginsEnv()
|
||||
|
||||
config := &CORSConfig{
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
@@ -27,24 +58,27 @@ func NewCORSConfig() *CORSConfig {
|
||||
|
||||
switch env {
|
||||
case "production", "staging":
|
||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
|
||||
config.AllowCredentials = true
|
||||
if len(originsEnv) > 0 {
|
||||
config.AllowedOrigins = originsEnv
|
||||
} else {
|
||||
config.AllowedOrigins = []string{}
|
||||
}
|
||||
config.AllowCredentials = true
|
||||
default:
|
||||
config.AllowCredentials = true
|
||||
if len(originsEnv) > 0 {
|
||||
config.AllowedOrigins = originsEnv
|
||||
} else {
|
||||
config.AllowedOrigins = []string{
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:8080",
|
||||
}
|
||||
config.AllowCredentials = true
|
||||
}
|
||||
|
||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins != "" {
|
||||
config.AllowedOrigins = strings.Split(origins, ",")
|
||||
}
|
||||
}
|
||||
|
||||
validateCORSConfig(config)
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -76,6 +110,7 @@ func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, conf
|
||||
}
|
||||
|
||||
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
|
||||
validateCORSConfig(config)
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
@@ -181,7 +181,12 @@ func TestCORSWithConfig_WildcardOrigin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
|
||||
func TestCORSWithConfig_WildcardWithCredentialsPanics(t *testing.T) {
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatal("expected panic for AllowCredentials with wildcard AllowedOrigins")
|
||||
}
|
||||
}()
|
||||
config := &CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
@@ -189,24 +194,7 @@ func TestCORSWithConfig_WildcardWithCredentials(t *testing.T) {
|
||||
MaxAge: 3600,
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
handler := CORSWithConfig(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be 'http://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Credentials") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Credentials to be empty with wildcard, got '%s'", w.Header().Get("Access-Control-Allow-Credentials"))
|
||||
}
|
||||
_ = CORSWithConfig(config)
|
||||
}
|
||||
|
||||
func TestCORSWithConfig_NoOriginHeader(t *testing.T) {
|
||||
@@ -393,6 +381,30 @@ func TestCORSOPTIONSRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSAllowedOriginsTrimmedFromEnv(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "development")
|
||||
t.Setenv("CORS_ALLOWED_ORIGINS", " http://localhost:3000 , https://yourdomain.com ")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := CORS(handler)
|
||||
|
||||
for _, origin := range []string{"http://localhost:3000", "https://yourdomain.com"} {
|
||||
request := httptest.NewRequest("GET", "/api/posts", nil)
|
||||
request.Header.Set("Origin", origin)
|
||||
recorder := httptest.NewRecorder()
|
||||
middleware.ServeHTTP(recorder, request)
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("origin %q: status %d", origin, recorder.Code)
|
||||
}
|
||||
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != origin {
|
||||
t.Fatalf("origin %q: Allow-Origin %q", origin, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSAllowedOrigins(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "development")
|
||||
t.Setenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://yourdomain.com")
|
||||
|
||||
@@ -28,7 +28,7 @@ func SetCSRFToken(w http.ResponseWriter, r *http.Request, token string) {
|
||||
Name: CSRFTokenCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
HttpOnly: false,
|
||||
Secure: IsHTTPS(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 3600,
|
||||
|
||||
@@ -242,8 +242,8 @@ func TestSetCSRFToken(t *testing.T) {
|
||||
t.Errorf("Expected cookie value %s, got %s", token, cookie.Value)
|
||||
}
|
||||
|
||||
if !cookie.HttpOnly {
|
||||
t.Error("CSRF token cookie should be HttpOnly")
|
||||
if cookie.HttpOnly {
|
||||
t.Error("CSRF token cookie must not be HttpOnly so JS can mirror it to X-CSRF-Token")
|
||||
}
|
||||
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
|
||||
@@ -327,8 +327,8 @@ func GetSecureClientIP(r *http.Request) string {
|
||||
func GetKey(r *http.Request) string {
|
||||
ip := GetSecureClientIP(r)
|
||||
|
||||
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
|
||||
return fmt.Sprintf("user:%d:ip:%s", userID, ip)
|
||||
if userID := GetUserIDFromContext(r.Context()); userID != nil {
|
||||
return fmt.Sprintf("user:%d:ip:%s", *userID, ip)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ip:%s", ip)
|
||||
|
||||
@@ -5,12 +5,21 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const CSPNonceKey contextKey = "csp_nonce"
|
||||
|
||||
type SecurityHeadersConfig struct {
|
||||
RelaxSwaggerCSP bool
|
||||
}
|
||||
|
||||
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
|
||||
return SecurityHeadersConfig{RelaxSwaggerCSP: true}
|
||||
}
|
||||
|
||||
func GenerateCSPNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(nonceBytes); err != nil {
|
||||
@@ -27,15 +36,18 @@ func GetCSPNonceFromContext(ctx context.Context) string {
|
||||
}
|
||||
|
||||
func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
|
||||
return SecurityHeadersMiddlewareWithConfig(DefaultSecurityHeadersConfig())
|
||||
}
|
||||
|
||||
func SecurityHeadersMiddlewareWithConfig(cfg SecurityHeadersConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
isSwaggerRoute := strings.HasPrefix(r.URL.Path, "/swagger")
|
||||
if isSwaggerRoute {
|
||||
if isSwaggerRoute && cfg.RelaxSwaggerCSP {
|
||||
csp := "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
@@ -51,7 +63,7 @@ func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
|
||||
} else {
|
||||
nonce, err := GenerateCSPNonce()
|
||||
if err != nil {
|
||||
|
||||
log.Printf("middleware security headers: CSP nonce: %v", err)
|
||||
nonce = ""
|
||||
}
|
||||
|
||||
@@ -72,7 +84,6 @@ func SecurityHeadersMiddleware() func(http.Handler) http.Handler {
|
||||
csp = "script-src 'self' 'nonce-" + nonce + "'; " +
|
||||
"style-src 'self' 'nonce-" + nonce + "'; " + csp
|
||||
} else {
|
||||
|
||||
csp = "script-src 'self'; " +
|
||||
"style-src 'self'; " + csp
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ func TestSecurityHeadersMiddleware(t *testing.T) {
|
||||
expectedHeaders := map[string]string{
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
"Server": "",
|
||||
}
|
||||
@@ -176,7 +175,6 @@ func TestSecurityHeadersMiddleware_MultipleRequests(t *testing.T) {
|
||||
requiredHeaders := []string{
|
||||
"X-Content-Type-Options",
|
||||
"X-Frame-Options",
|
||||
"X-XSS-Protection",
|
||||
"Referrer-Policy",
|
||||
"Content-Security-Policy",
|
||||
"Permissions-Policy",
|
||||
@@ -289,3 +287,25 @@ func TestCSPNonceInContext(t *testing.T) {
|
||||
t.Errorf("CSP header should contain nonce from context. CSP: %s, Nonce: %s", csp, capturedNonce)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersMiddleware_SwaggerStrictWhenRelaxedDisabled(t *testing.T) {
|
||||
handler := SecurityHeadersMiddlewareWithConfig(SecurityHeadersConfig{RelaxSwaggerCSP: false})(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/swagger/index.html", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
csp := rec.Header().Get("Content-Security-Policy")
|
||||
if csp == "" {
|
||||
t.Fatal("expected CSP")
|
||||
}
|
||||
if strings.Contains(csp, "'unsafe-eval'") || strings.Contains(csp, "'unsafe-inline'") {
|
||||
t.Fatalf("unexpected relaxed CSP for swagger path: %s", csp)
|
||||
}
|
||||
if !strings.Contains(csp, "nonce-") {
|
||||
t.Fatalf("expected nonce CSP for swagger path, got %s", csp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -53,7 +54,10 @@ func SecurityLoggingMiddleware(logger *SecurityLogger) func(http.Handler) http.H
|
||||
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
userID := GetUserIDFromContext(r.Context())
|
||||
userUID := uint(0)
|
||||
if u := GetUserIDFromContext(r.Context()); u != nil {
|
||||
userUID = *u
|
||||
}
|
||||
ip := getClientIP(r)
|
||||
|
||||
event := SecurityEvent{
|
||||
@@ -61,7 +65,7 @@ func SecurityLoggingMiddleware(logger *SecurityLogger) func(http.Handler) http.H
|
||||
UserAgent: r.UserAgent(),
|
||||
Path: r.URL.Path,
|
||||
Method: r.Method,
|
||||
UserID: userID,
|
||||
UserID: userUID,
|
||||
Timestamp: start,
|
||||
}
|
||||
|
||||
@@ -106,12 +110,15 @@ func SuspiciousActivityMiddleware(logger *SecurityLogger) func(http.Handler) htt
|
||||
suspicious := false
|
||||
details := ""
|
||||
|
||||
if containsSQLInjection(r.URL.RawQuery) || containsSQLInjection(r.URL.Path) {
|
||||
pathProbe := layeredUnescape(r.URL.Path, url.PathUnescape)
|
||||
queryProbe := layeredUnescape(r.URL.RawQuery, url.QueryUnescape)
|
||||
|
||||
if containsSQLInjection(pathProbe) || containsSQLInjection(queryProbe) {
|
||||
suspicious = true
|
||||
details = "Potential SQL injection attempt"
|
||||
}
|
||||
|
||||
if containsXSS(r.URL.RawQuery) || containsXSS(r.URL.Path) {
|
||||
if containsXSS(pathProbe) || containsXSS(queryProbe) {
|
||||
suspicious = true
|
||||
details = "Potential XSS attempt"
|
||||
}
|
||||
@@ -159,6 +166,18 @@ func getClientIP(r *http.Request) string {
|
||||
return GetSecureClientIP(r)
|
||||
}
|
||||
|
||||
func layeredUnescape(s string, decoder func(string) (string, error)) string {
|
||||
out := s
|
||||
for range 3 {
|
||||
d, err := decoder(out)
|
||||
if err != nil || d == out {
|
||||
return out
|
||||
}
|
||||
out = d
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func containsSQLInjection(input string) bool {
|
||||
sqlPatterns := []string{
|
||||
"' OR '1'='1",
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -566,6 +567,28 @@ func TestSuspiciousActivityMiddleware_NoSuspiciousActivity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuspiciousActivityMiddleware_EncodedSQLInQuery(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := &SecurityLogger{
|
||||
logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile),
|
||||
}
|
||||
|
||||
handler := SuspiciousActivityMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("s", "' OR '1'='1")
|
||||
req := httptest.NewRequest("GET", "/search?"+q.Encode(), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "SQL injection") {
|
||||
t.Fatalf("expected SQL injection log, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuspiciousActivityMiddleware_Debug(t *testing.T) {
|
||||
|
||||
t.Run("SQL Detection", func(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -32,12 +33,23 @@ type RouterConfig struct {
|
||||
RateLimitConfig config.RateLimitConfig
|
||||
}
|
||||
|
||||
func swaggerExposed() bool {
|
||||
if strings.EqualFold(strings.TrimSpace(os.Getenv("SWAGGER_ENABLED")), "true") {
|
||||
return true
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(os.Getenv("GOYCO_ENV"))) != "production"
|
||||
}
|
||||
|
||||
func NewRouter(cfg RouterConfig) http.Handler {
|
||||
middleware.SetTrustProxyHeaders(cfg.RateLimitConfig.TrustProxyHeaders)
|
||||
|
||||
exposeSwagger := swaggerExposed()
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(middleware.Logging(cfg.Debug))
|
||||
router.Use(middleware.SecurityHeadersMiddleware())
|
||||
router.Use(middleware.SecurityHeadersMiddlewareWithConfig(middleware.SecurityHeadersConfig{
|
||||
RelaxSwaggerCSP: exposeSwagger,
|
||||
}))
|
||||
router.Use(middleware.HSTSMiddleware())
|
||||
router.Use(middleware.CORS)
|
||||
|
||||
@@ -54,7 +66,7 @@ func NewRouter(cfg RouterConfig) http.Handler {
|
||||
cacheConfig.CacheablePaths = append([]string{}, cfg.CacheablePaths...)
|
||||
}
|
||||
router.Use(middleware.CacheMiddleware(cache, cacheConfig))
|
||||
router.Use(middleware.CacheInvalidationMiddleware(cache))
|
||||
router.Use(middleware.CacheInvalidationMiddleware(cache, cacheConfig.CacheablePaths))
|
||||
}
|
||||
|
||||
var dbMonitor middleware.DBMonitor
|
||||
@@ -94,8 +106,10 @@ func NewRouter(cfg RouterConfig) http.Handler {
|
||||
metricsRateLimited.Get("/metrics", cfg.APIHandler.GetMetrics)
|
||||
}
|
||||
|
||||
if exposeSwagger {
|
||||
swaggerRateLimited := router.With(middleware.GeneralRateLimitMiddlewareWithLimit(cfg.RateLimitConfig.GeneralLimit))
|
||||
swaggerRateLimited.Get("/swagger/*", httpSwagger.Handler())
|
||||
}
|
||||
|
||||
router.Get("/robots.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.ServeFile(w, r, filepath.Join(cfg.StaticDir, "robots.txt"))
|
||||
|
||||
@@ -348,6 +348,8 @@ func TestRouterWithoutPageHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSwaggerRoute(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "development")
|
||||
t.Setenv("SWAGGER_ENABLED", "")
|
||||
router := createTestRouter(createDefaultRouterConfig())
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
||||
@@ -360,6 +362,34 @@ func TestSwaggerRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerRouteHiddenInProduction(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "production")
|
||||
t.Setenv("SWAGGER_ENABLED", "")
|
||||
router := createTestRouter(createDefaultRouterConfig())
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected 404 for swagger in production, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerRouteEnabledInProductionWhenForced(t *testing.T) {
|
||||
t.Setenv("GOYCO_ENV", "production")
|
||||
t.Setenv("SWAGGER_ENABLED", "true")
|
||||
router := createTestRouter(createDefaultRouterConfig())
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK && recorder.Code != http.StatusMovedPermanently {
|
||||
t.Errorf("Expected status 200 or 301 for swagger when SWAGGER_ENABLED, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticFileRoute(t *testing.T) {
|
||||
cfg := createDefaultRouterConfig()
|
||||
cfg.StaticDir = "../../internal/static/"
|
||||
|
||||
Reference in New Issue
Block a user