fix(middleware): SHA-256 keys, LRU cache, and prefix-scoped invalidation
This commit is contained in:
+178
-17
@@ -2,8 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"container/list"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -24,47 +27,192 @@ type Cache interface {
|
||||
Clear() error
|
||||
}
|
||||
|
||||
func applyCacheMaxSize(cache Cache, max int) {
|
||||
if im, ok := cache.(*InMemoryCache); ok {
|
||||
im.SetMaxEntries(max)
|
||||
}
|
||||
}
|
||||
|
||||
func registerIndexedCacheKey(cache Cache, cacheKey string, path string, cacheablePrefixes []string) {
|
||||
if im, ok := cache.(*InMemoryCache); ok {
|
||||
im.RegisterKeyForPath(cacheKey, path, cacheablePrefixes)
|
||||
}
|
||||
}
|
||||
|
||||
func invalidateCacheForMutation(cache Cache, mutationPath string, cacheablePrefixes []string) {
|
||||
if im, ok := cache.(*InMemoryCache); ok {
|
||||
im.InvalidateForMutationPath(mutationPath, cacheablePrefixes)
|
||||
return
|
||||
}
|
||||
_ = cache.Clear()
|
||||
}
|
||||
|
||||
type InMemoryCache struct {
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
|
||||
data map[string]*CacheEntry
|
||||
maxSize int
|
||||
|
||||
ll *list.List
|
||||
lruEl map[string]*list.Element
|
||||
|
||||
prefixKeys map[string]map[string]struct{}
|
||||
keyPrefixes map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
func NewInMemoryCache() *InMemoryCache {
|
||||
return &InMemoryCache{
|
||||
data: make(map[string]*CacheEntry),
|
||||
maxSize: 1000,
|
||||
ll: list.New(),
|
||||
lruEl: make(map[string]*list.Element),
|
||||
prefixKeys: make(map[string]map[string]struct{}),
|
||||
keyPrefixes: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
|
||||
cache.mu.RLock()
|
||||
entry, exists := cache.data[key]
|
||||
cache.mu.RUnlock()
|
||||
func (cache *InMemoryCache) SetMaxEntries(n int) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
cache.maxSize = n
|
||||
for n > 0 && len(cache.data) > n {
|
||||
cache.evictOldestLocked()
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) RegisterKeyForPath(cacheKey string, path string, cacheablePrefixes []string) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
for _, p := range matchingCachePrefixes(path, cacheablePrefixes) {
|
||||
if cache.prefixKeys[p] == nil {
|
||||
cache.prefixKeys[p] = make(map[string]struct{})
|
||||
}
|
||||
cache.prefixKeys[p][cacheKey] = struct{}{}
|
||||
|
||||
if cache.keyPrefixes[cacheKey] == nil {
|
||||
cache.keyPrefixes[cacheKey] = make(map[string]struct{})
|
||||
}
|
||||
cache.keyPrefixes[cacheKey][p] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) InvalidateForMutationPath(mutationPath string, cacheablePrefixes []string) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
var stale []string
|
||||
for _, prefix := range cacheablePrefixes {
|
||||
if !strings.HasPrefix(mutationPath, prefix) {
|
||||
continue
|
||||
}
|
||||
for key := range cache.prefixKeys[prefix] {
|
||||
if _, dup := seen[key]; dup {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
stale = append(stale, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range stale {
|
||||
cache.removeKeyLocked(key)
|
||||
}
|
||||
}
|
||||
|
||||
func matchingCachePrefixes(path string, cacheablePrefixes []string) []string {
|
||||
var out []string
|
||||
for _, p := range cacheablePrefixes {
|
||||
if strings.HasPrefix(path, p) {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
entry, exists := cache.data[key]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("key not found")
|
||||
}
|
||||
|
||||
if time.Since(entry.Timestamp) > entry.TTL {
|
||||
cache.mu.Lock()
|
||||
delete(cache.data, key)
|
||||
cache.mu.Unlock()
|
||||
cache.removeKeyLocked(key)
|
||||
return nil, fmt.Errorf("entry expired")
|
||||
}
|
||||
|
||||
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||
cache.ll.MoveToFront(el)
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
|
||||
if _, exists := cache.data[key]; exists {
|
||||
cache.data[key] = entry
|
||||
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||
cache.ll.MoveToFront(el)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cache.data[key] = entry
|
||||
if cache.ll != nil {
|
||||
el := cache.ll.PushFront(key)
|
||||
cache.lruEl[key] = el
|
||||
}
|
||||
|
||||
for cache.maxSize > 0 && len(cache.data) > cache.maxSize {
|
||||
cache.evictOldestLocked()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) evictOldestLocked() {
|
||||
if cache.ll == nil || cache.ll.Len() == 0 {
|
||||
return
|
||||
}
|
||||
el := cache.ll.Back()
|
||||
if el == nil {
|
||||
return
|
||||
}
|
||||
key, _ := el.Value.(string)
|
||||
cache.removeKeyLocked(key)
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) removeKeyLocked(key string) {
|
||||
delete(cache.data, key)
|
||||
|
||||
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
||||
cache.ll.Remove(el)
|
||||
}
|
||||
delete(cache.lruEl, key)
|
||||
|
||||
for p := range cache.keyPrefixes[key] {
|
||||
if m, ok := cache.prefixKeys[p]; ok {
|
||||
delete(m, key)
|
||||
if len(m) == 0 {
|
||||
delete(cache.prefixKeys, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(cache.keyPrefixes, key)
|
||||
}
|
||||
|
||||
func (cache *InMemoryCache) Delete(key string) error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
delete(cache.data, key)
|
||||
if _, ok := cache.data[key]; !ok {
|
||||
return fmt.Errorf("key not found")
|
||||
}
|
||||
cache.removeKeyLocked(key)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -72,6 +220,10 @@ func (cache *InMemoryCache) Clear() error {
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
cache.data = make(map[string]*CacheEntry)
|
||||
cache.prefixKeys = make(map[string]map[string]struct{})
|
||||
cache.keyPrefixes = make(map[string]map[string]struct{})
|
||||
cache.lruEl = make(map[string]*list.Element)
|
||||
cache.ll = list.New()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -95,6 +247,7 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
|
||||
if config == nil {
|
||||
config = DefaultCacheConfig()
|
||||
}
|
||||
applyCacheMaxSize(cache, config.MaxSize)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -141,8 +294,15 @@ func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.H
|
||||
TTL: config.TTL,
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
prefixes := config.CacheablePaths
|
||||
|
||||
go func() {
|
||||
cache.Set(cacheKey, entry)
|
||||
if err := cache.Set(cacheKey, entry); err != nil {
|
||||
log.Printf("middleware cache Set: %v", err)
|
||||
return
|
||||
}
|
||||
registerIndexedCacheKey(cache, cacheKey, path, prefixes)
|
||||
}()
|
||||
}
|
||||
})
|
||||
@@ -190,20 +350,21 @@ func generateCacheKey(r *http.Request) string {
|
||||
key += "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
|
||||
key += fmt.Sprintf(":user:%d", userID)
|
||||
if userID := GetUserIDFromContext(r.Context()); userID != nil {
|
||||
key += fmt.Sprintf(":user:%d", *userID)
|
||||
}
|
||||
|
||||
hash := md5.Sum([]byte(key))
|
||||
return fmt.Sprintf("cache:%x", hash)
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return "cache:" + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func CacheInvalidationMiddleware(cache Cache) func(http.Handler) http.Handler {
|
||||
func CacheInvalidationMiddleware(cache Cache, cacheablePrefixes []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" {
|
||||
mPath := r.URL.Path
|
||||
go func() {
|
||||
cache.Clear()
|
||||
invalidateCacheForMutation(cache, mPath, cacheablePrefixes)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user