From be64e7c8d2c0783bd430475b6f8ce8867579d339 Mon Sep 17 00:00:00 2001 From: Kharec Date: Wed, 6 May 2026 20:13:56 +0200 Subject: [PATCH] fix(middleware): SHA-256 keys, LRU cache, and prefix-scoped invalidation --- internal/middleware/cache.go | 199 +++++++++++++++++++++++++++++++---- 1 file changed, 180 insertions(+), 19 deletions(-) diff --git a/internal/middleware/cache.go b/internal/middleware/cache.go index 8d87461..2f93732 100644 --- a/internal/middleware/cache.go +++ b/internal/middleware/cache.go @@ -2,8 +2,11 @@ package middleware import ( "bytes" - "crypto/md5" + "container/list" + "crypto/sha256" + "encoding/hex" "fmt" + "log" "net/http" "strings" "sync" @@ -24,47 +27,192 @@ type Cache interface { Clear() error } +func applyCacheMaxSize(cache Cache, max int) { + if im, ok := cache.(*InMemoryCache); ok { + im.SetMaxEntries(max) + } +} + +func registerIndexedCacheKey(cache Cache, cacheKey string, path string, cacheablePrefixes []string) { + if im, ok := cache.(*InMemoryCache); ok { + im.RegisterKeyForPath(cacheKey, path, cacheablePrefixes) + } +} + +func invalidateCacheForMutation(cache Cache, mutationPath string, cacheablePrefixes []string) { + if im, ok := cache.(*InMemoryCache); ok { + im.InvalidateForMutationPath(mutationPath, cacheablePrefixes) + return + } + _ = cache.Clear() +} + type InMemoryCache struct { - mu sync.RWMutex - data map[string]*CacheEntry + mu sync.Mutex + + data map[string]*CacheEntry + maxSize int + + ll *list.List + lruEl map[string]*list.Element + + prefixKeys map[string]map[string]struct{} + keyPrefixes map[string]map[string]struct{} } func NewInMemoryCache() *InMemoryCache { return &InMemoryCache{ - data: make(map[string]*CacheEntry), + data: make(map[string]*CacheEntry), + maxSize: 1000, + ll: list.New(), + lruEl: make(map[string]*list.Element), + prefixKeys: make(map[string]map[string]struct{}), + keyPrefixes: make(map[string]map[string]struct{}), } } -func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) { - cache.mu.RLock() - entry, exists := cache.data[key] - cache.mu.RUnlock() +func (cache *InMemoryCache) SetMaxEntries(n int) { + cache.mu.Lock() + defer cache.mu.Unlock() + cache.maxSize = n + for n > 0 && len(cache.data) > n { + cache.evictOldestLocked() + } +} +func (cache *InMemoryCache) RegisterKeyForPath(cacheKey string, path string, cacheablePrefixes []string) { + cache.mu.Lock() + defer cache.mu.Unlock() + for _, p := range matchingCachePrefixes(path, cacheablePrefixes) { + if cache.prefixKeys[p] == nil { + cache.prefixKeys[p] = make(map[string]struct{}) + } + cache.prefixKeys[p][cacheKey] = struct{}{} + + if cache.keyPrefixes[cacheKey] == nil { + cache.keyPrefixes[cacheKey] = make(map[string]struct{}) + } + cache.keyPrefixes[cacheKey][p] = struct{}{} + } +} + +func (cache *InMemoryCache) InvalidateForMutationPath(mutationPath string, cacheablePrefixes []string) { + cache.mu.Lock() + defer cache.mu.Unlock() + + seen := make(map[string]struct{}) + var stale []string + for _, prefix := range cacheablePrefixes { + if !strings.HasPrefix(mutationPath, prefix) { + continue + } + for key := range cache.prefixKeys[prefix] { + if _, dup := seen[key]; dup { + continue + } + seen[key] = struct{}{} + stale = append(stale, key) + } + } + + for _, key := range stale { + cache.removeKeyLocked(key) + } +} + +func matchingCachePrefixes(path string, cacheablePrefixes []string) []string { + var out []string + for _, p := range cacheablePrefixes { + if strings.HasPrefix(path, p) { + out = append(out, p) + } + } + return out +} + +func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) { + cache.mu.Lock() + defer cache.mu.Unlock() + + entry, exists := cache.data[key] if !exists { return nil, fmt.Errorf("key not found") } if time.Since(entry.Timestamp) > entry.TTL { - cache.mu.Lock() - delete(cache.data, key) - cache.mu.Unlock() + cache.removeKeyLocked(key) return nil, fmt.Errorf("entry expired") } + if el, ok := cache.lruEl[key]; ok && cache.ll != nil { + cache.ll.MoveToFront(el) + } + return entry, nil } func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error { cache.mu.Lock() defer cache.mu.Unlock() + + if _, exists := cache.data[key]; exists { + cache.data[key] = entry + if el, ok := cache.lruEl[key]; ok && cache.ll != nil { + cache.ll.MoveToFront(el) + } + return nil + } + cache.data[key] = entry + if cache.ll != nil { + el := cache.ll.PushFront(key) + cache.lruEl[key] = el + } + + for cache.maxSize > 0 && len(cache.data) > cache.maxSize { + cache.evictOldestLocked() + } return nil } +func (cache *InMemoryCache) evictOldestLocked() { + if cache.ll == nil || cache.ll.Len() == 0 { + return + } + el := cache.ll.Back() + if el == nil { + return + } + key, _ := el.Value.(string) + cache.removeKeyLocked(key) +} + +func (cache *InMemoryCache) removeKeyLocked(key string) { + delete(cache.data, key) + + if el, ok := cache.lruEl[key]; ok && cache.ll != nil { + cache.ll.Remove(el) + } + delete(cache.lruEl, key) + + for p := range cache.keyPrefixes[key] { + if m, ok := cache.prefixKeys[p]; ok { + delete(m, key) + if len(m) == 0 { + delete(cache.prefixKeys, p) + } + } + } + delete(cache.keyPrefixes, key) +} + func (cache *InMemoryCache) Delete(key string) error { cache.mu.Lock() defer cache.mu.Unlock() - delete(cache.data, key) + if _, ok := cache.data[key]; !ok { + return fmt.Errorf("key not found") + } + cache.removeKeyLocked(key) return nil } @@ -72,6 +220,10 @@ func (cache *InMemoryCache) Clear() error { cache.mu.Lock() defer cache.mu.Unlock() cache.data = make(map[string]*CacheEntry) + cache.prefixKeys = make(map[string]map[string]struct{}) + cache.keyPrefixes = make(map[string]map[string]struct{}) + cache.lruEl = make(map[string]*list.Element) + cache.ll = list.New() return nil } @@ -95,6 +247,7 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H if config == nil { config = DefaultCacheConfig() } + applyCacheMaxSize(cache, config.MaxSize) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -141,8 +294,15 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H TTL: config.TTL, } + path := r.URL.Path + prefixes := config.CacheablePaths + go func() { - cache.Set(cacheKey, entry) + if err := cache.Set(cacheKey, entry); err != nil { + log.Printf("middleware cache Set: %v", err) + return + } + registerIndexedCacheKey(cache, cacheKey, path, prefixes) }() } }) @@ -190,20 +350,21 @@ func generateCacheKey(r *http.Request) string { key += "?" + r.URL.RawQuery } - if userID := GetUserIDFromContext(r.Context()); userID != 0 { - key += fmt.Sprintf(":user:%d", userID) + if userID := GetUserIDFromContext(r.Context()); userID != nil { + key += fmt.Sprintf(":user:%d", *userID) } - hash := md5.Sum([]byte(key)) - return fmt.Sprintf("cache:%x", hash) + sum := sha256.Sum256([]byte(key)) + return "cache:" + hex.EncodeToString(sum[:]) } -func CacheInvalidationMiddleware(cache Cache) func(http.Handler) http.Handler { +func CacheInvalidationMiddleware(cache Cache, cacheablePrefixes []string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" { + mPath := r.URL.Path go func() { - cache.Clear() + invalidateCacheForMutation(cache, mPath, cacheablePrefixes) }() }