fix(middleware): SHA-256 keys, LRU cache, and prefix-scoped invalidation

This commit is contained in:
2026-05-06 20:13:56 +02:00
parent 1aa256c6a8
commit be64e7c8d2
+177 -16
View File
@@ -2,8 +2,11 @@ package middleware
import ( import (
"bytes" "bytes"
"crypto/md5" "container/list"
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"log"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@@ -24,47 +27,192 @@ type Cache interface {
Clear() error Clear() error
} }
func applyCacheMaxSize(cache Cache, max int) {
if im, ok := cache.(*InMemoryCache); ok {
im.SetMaxEntries(max)
}
}
func registerIndexedCacheKey(cache Cache, cacheKey string, path string, cacheablePrefixes []string) {
if im, ok := cache.(*InMemoryCache); ok {
im.RegisterKeyForPath(cacheKey, path, cacheablePrefixes)
}
}
func invalidateCacheForMutation(cache Cache, mutationPath string, cacheablePrefixes []string) {
if im, ok := cache.(*InMemoryCache); ok {
im.InvalidateForMutationPath(mutationPath, cacheablePrefixes)
return
}
_ = cache.Clear()
}
type InMemoryCache struct { type InMemoryCache struct {
mu sync.RWMutex mu sync.Mutex
data map[string]*CacheEntry data map[string]*CacheEntry
maxSize int
ll *list.List
lruEl map[string]*list.Element
prefixKeys map[string]map[string]struct{}
keyPrefixes map[string]map[string]struct{}
} }
func NewInMemoryCache() *InMemoryCache { func NewInMemoryCache() *InMemoryCache {
return &InMemoryCache{ return &InMemoryCache{
data: make(map[string]*CacheEntry), data: make(map[string]*CacheEntry),
maxSize: 1000,
ll: list.New(),
lruEl: make(map[string]*list.Element),
prefixKeys: make(map[string]map[string]struct{}),
keyPrefixes: make(map[string]map[string]struct{}),
} }
} }
func (cache *InMemoryCache) SetMaxEntries(n int) {
cache.mu.Lock()
defer cache.mu.Unlock()
cache.maxSize = n
for n > 0 && len(cache.data) > n {
cache.evictOldestLocked()
}
}
func (cache *InMemoryCache) RegisterKeyForPath(cacheKey string, path string, cacheablePrefixes []string) {
cache.mu.Lock()
defer cache.mu.Unlock()
for _, p := range matchingCachePrefixes(path, cacheablePrefixes) {
if cache.prefixKeys[p] == nil {
cache.prefixKeys[p] = make(map[string]struct{})
}
cache.prefixKeys[p][cacheKey] = struct{}{}
if cache.keyPrefixes[cacheKey] == nil {
cache.keyPrefixes[cacheKey] = make(map[string]struct{})
}
cache.keyPrefixes[cacheKey][p] = struct{}{}
}
}
func (cache *InMemoryCache) InvalidateForMutationPath(mutationPath string, cacheablePrefixes []string) {
cache.mu.Lock()
defer cache.mu.Unlock()
seen := make(map[string]struct{})
var stale []string
for _, prefix := range cacheablePrefixes {
if !strings.HasPrefix(mutationPath, prefix) {
continue
}
for key := range cache.prefixKeys[prefix] {
if _, dup := seen[key]; dup {
continue
}
seen[key] = struct{}{}
stale = append(stale, key)
}
}
for _, key := range stale {
cache.removeKeyLocked(key)
}
}
func matchingCachePrefixes(path string, cacheablePrefixes []string) []string {
var out []string
for _, p := range cacheablePrefixes {
if strings.HasPrefix(path, p) {
out = append(out, p)
}
}
return out
}
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) { func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
cache.mu.RLock() cache.mu.Lock()
entry, exists := cache.data[key] defer cache.mu.Unlock()
cache.mu.RUnlock()
entry, exists := cache.data[key]
if !exists { if !exists {
return nil, fmt.Errorf("key not found") return nil, fmt.Errorf("key not found")
} }
if time.Since(entry.Timestamp) > entry.TTL { if time.Since(entry.Timestamp) > entry.TTL {
cache.mu.Lock() cache.removeKeyLocked(key)
delete(cache.data, key)
cache.mu.Unlock()
return nil, fmt.Errorf("entry expired") return nil, fmt.Errorf("entry expired")
} }
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
cache.ll.MoveToFront(el)
}
return entry, nil return entry, nil
} }
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error { func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
cache.mu.Lock() cache.mu.Lock()
defer cache.mu.Unlock() defer cache.mu.Unlock()
if _, exists := cache.data[key]; exists {
cache.data[key] = entry cache.data[key] = entry
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
cache.ll.MoveToFront(el)
}
return nil return nil
} }
cache.data[key] = entry
if cache.ll != nil {
el := cache.ll.PushFront(key)
cache.lruEl[key] = el
}
for cache.maxSize > 0 && len(cache.data) > cache.maxSize {
cache.evictOldestLocked()
}
return nil
}
func (cache *InMemoryCache) evictOldestLocked() {
if cache.ll == nil || cache.ll.Len() == 0 {
return
}
el := cache.ll.Back()
if el == nil {
return
}
key, _ := el.Value.(string)
cache.removeKeyLocked(key)
}
func (cache *InMemoryCache) removeKeyLocked(key string) {
delete(cache.data, key)
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
cache.ll.Remove(el)
}
delete(cache.lruEl, key)
for p := range cache.keyPrefixes[key] {
if m, ok := cache.prefixKeys[p]; ok {
delete(m, key)
if len(m) == 0 {
delete(cache.prefixKeys, p)
}
}
}
delete(cache.keyPrefixes, key)
}
func (cache *InMemoryCache) Delete(key string) error { func (cache *InMemoryCache) Delete(key string) error {
cache.mu.Lock() cache.mu.Lock()
defer cache.mu.Unlock() defer cache.mu.Unlock()
delete(cache.data, key) if _, ok := cache.data[key]; !ok {
return fmt.Errorf("key not found")
}
cache.removeKeyLocked(key)
return nil return nil
} }
@@ -72,6 +220,10 @@ func (cache *InMemoryCache) Clear() error {
cache.mu.Lock() cache.mu.Lock()
defer cache.mu.Unlock() defer cache.mu.Unlock()
cache.data = make(map[string]*CacheEntry) cache.data = make(map[string]*CacheEntry)
cache.prefixKeys = make(map[string]map[string]struct{})
cache.keyPrefixes = make(map[string]map[string]struct{})
cache.lruEl = make(map[string]*list.Element)
cache.ll = list.New()
return nil return nil
} }
@@ -95,6 +247,7 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
if config == nil { if config == nil {
config = DefaultCacheConfig() config = DefaultCacheConfig()
} }
applyCacheMaxSize(cache, config.MaxSize)
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -141,8 +294,15 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
TTL: config.TTL, TTL: config.TTL,
} }
path := r.URL.Path
prefixes := config.CacheablePaths
go func() { go func() {
cache.Set(cacheKey, entry) if err := cache.Set(cacheKey, entry); err != nil {
log.Printf("middleware cache Set: %v", err)
return
}
registerIndexedCacheKey(cache, cacheKey, path, prefixes)
}() }()
} }
}) })
@@ -190,20 +350,21 @@ func generateCacheKey(r *http.Request) string {
key += "?" + r.URL.RawQuery key += "?" + r.URL.RawQuery
} }
if userID := GetUserIDFromContext(r.Context()); userID != 0 { if userID := GetUserIDFromContext(r.Context()); userID != nil {
key += fmt.Sprintf(":user:%d", userID) key += fmt.Sprintf(":user:%d", *userID)
} }
hash := md5.Sum([]byte(key)) sum := sha256.Sum256([]byte(key))
return fmt.Sprintf("cache:%x", hash) return "cache:" + hex.EncodeToString(sum[:])
} }
func CacheInvalidationMiddleware(cache Cache) func(http.Handler) http.Handler { func CacheInvalidationMiddleware(cache Cache, cacheablePrefixes []string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" { if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" {
mPath := r.URL.Path
go func() { go func() {
cache.Clear() invalidateCacheForMutation(cache, mPath, cacheablePrefixes)
}() }()
} }