375 lines
8.4 KiB
Go
375 lines
8.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"container/list"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// CacheEntry is a single stored HTTP response together with the bookkeeping
// needed to decide whether it is still fresh.
type CacheEntry struct {
	// Data is the response body exactly as the handler wrote it.
	Data []byte `json:"data"`
	// Headers are the response headers captured alongside the body.
	Headers http.Header `json:"headers"`
	// Timestamp is the moment the entry was created.
	Timestamp time.Time `json:"timestamp"`
	// TTL is how long after Timestamp the entry remains valid.
	TTL time.Duration `json:"ttl"`
}
|
|
|
|
type Cache interface {
|
|
Get(key string) (*CacheEntry, error)
|
|
Set(key string, entry *CacheEntry) error
|
|
Delete(key string) error
|
|
Clear() error
|
|
}
|
|
|
|
func applyCacheMaxSize(cache Cache, max int) {
|
|
if im, ok := cache.(*InMemoryCache); ok {
|
|
im.SetMaxEntries(max)
|
|
}
|
|
}
|
|
|
|
func registerIndexedCacheKey(cache Cache, cacheKey string, path string, cacheablePrefixes []string) {
|
|
if im, ok := cache.(*InMemoryCache); ok {
|
|
im.RegisterKeyForPath(cacheKey, path, cacheablePrefixes)
|
|
}
|
|
}
|
|
|
|
func invalidateCacheForMutation(cache Cache, mutationPath string, cacheablePrefixes []string) {
|
|
if im, ok := cache.(*InMemoryCache); ok {
|
|
im.InvalidateForMutationPath(mutationPath, cacheablePrefixes)
|
|
return
|
|
}
|
|
_ = cache.Clear()
|
|
}
|
|
|
|
type InMemoryCache struct {
|
|
mu sync.Mutex
|
|
|
|
data map[string]*CacheEntry
|
|
maxSize int
|
|
|
|
ll *list.List
|
|
lruEl map[string]*list.Element
|
|
|
|
prefixKeys map[string]map[string]struct{}
|
|
keyPrefixes map[string]map[string]struct{}
|
|
}
|
|
|
|
func NewInMemoryCache() *InMemoryCache {
|
|
return &InMemoryCache{
|
|
data: make(map[string]*CacheEntry),
|
|
maxSize: 1000,
|
|
ll: list.New(),
|
|
lruEl: make(map[string]*list.Element),
|
|
prefixKeys: make(map[string]map[string]struct{}),
|
|
keyPrefixes: make(map[string]map[string]struct{}),
|
|
}
|
|
}
|
|
|
|
func (cache *InMemoryCache) SetMaxEntries(n int) {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
cache.maxSize = n
|
|
for n > 0 && len(cache.data) > n {
|
|
cache.evictOldestLocked()
|
|
}
|
|
}
|
|
|
|
func (cache *InMemoryCache) RegisterKeyForPath(cacheKey string, path string, cacheablePrefixes []string) {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
for _, p := range matchingCachePrefixes(path, cacheablePrefixes) {
|
|
if cache.prefixKeys[p] == nil {
|
|
cache.prefixKeys[p] = make(map[string]struct{})
|
|
}
|
|
cache.prefixKeys[p][cacheKey] = struct{}{}
|
|
|
|
if cache.keyPrefixes[cacheKey] == nil {
|
|
cache.keyPrefixes[cacheKey] = make(map[string]struct{})
|
|
}
|
|
cache.keyPrefixes[cacheKey][p] = struct{}{}
|
|
}
|
|
}
|
|
|
|
func (cache *InMemoryCache) InvalidateForMutationPath(mutationPath string, cacheablePrefixes []string) {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
|
|
seen := make(map[string]struct{})
|
|
var stale []string
|
|
for _, prefix := range cacheablePrefixes {
|
|
if !strings.HasPrefix(mutationPath, prefix) {
|
|
continue
|
|
}
|
|
for key := range cache.prefixKeys[prefix] {
|
|
if _, dup := seen[key]; dup {
|
|
continue
|
|
}
|
|
seen[key] = struct{}{}
|
|
stale = append(stale, key)
|
|
}
|
|
}
|
|
|
|
for _, key := range stale {
|
|
cache.removeKeyLocked(key)
|
|
}
|
|
}
|
|
|
|
// matchingCachePrefixes returns, in their original order, the entries of
// cacheablePrefixes that path starts with. It returns nil when none match.
func matchingCachePrefixes(path string, cacheablePrefixes []string) []string {
	var matched []string
	for i := range cacheablePrefixes {
		if prefix := cacheablePrefixes[i]; strings.HasPrefix(path, prefix) {
			matched = append(matched, prefix)
		}
	}
	return matched
}
|
|
|
|
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
|
|
entry, exists := cache.data[key]
|
|
if !exists {
|
|
return nil, fmt.Errorf("key not found")
|
|
}
|
|
|
|
if time.Since(entry.Timestamp) > entry.TTL {
|
|
cache.removeKeyLocked(key)
|
|
return nil, fmt.Errorf("entry expired")
|
|
}
|
|
|
|
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
|
cache.ll.MoveToFront(el)
|
|
}
|
|
|
|
return entry, nil
|
|
}
|
|
|
|
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
|
|
if _, exists := cache.data[key]; exists {
|
|
cache.data[key] = entry
|
|
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
|
cache.ll.MoveToFront(el)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
cache.data[key] = entry
|
|
if cache.ll != nil {
|
|
el := cache.ll.PushFront(key)
|
|
cache.lruEl[key] = el
|
|
}
|
|
|
|
for cache.maxSize > 0 && len(cache.data) > cache.maxSize {
|
|
cache.evictOldestLocked()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cache *InMemoryCache) evictOldestLocked() {
|
|
if cache.ll == nil || cache.ll.Len() == 0 {
|
|
return
|
|
}
|
|
el := cache.ll.Back()
|
|
if el == nil {
|
|
return
|
|
}
|
|
key, _ := el.Value.(string)
|
|
cache.removeKeyLocked(key)
|
|
}
|
|
|
|
func (cache *InMemoryCache) removeKeyLocked(key string) {
|
|
delete(cache.data, key)
|
|
|
|
if el, ok := cache.lruEl[key]; ok && cache.ll != nil {
|
|
cache.ll.Remove(el)
|
|
}
|
|
delete(cache.lruEl, key)
|
|
|
|
for p := range cache.keyPrefixes[key] {
|
|
if m, ok := cache.prefixKeys[p]; ok {
|
|
delete(m, key)
|
|
if len(m) == 0 {
|
|
delete(cache.prefixKeys, p)
|
|
}
|
|
}
|
|
}
|
|
delete(cache.keyPrefixes, key)
|
|
}
|
|
|
|
func (cache *InMemoryCache) Delete(key string) error {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
if _, ok := cache.data[key]; !ok {
|
|
return fmt.Errorf("key not found")
|
|
}
|
|
cache.removeKeyLocked(key)
|
|
return nil
|
|
}
|
|
|
|
func (cache *InMemoryCache) Clear() error {
|
|
cache.mu.Lock()
|
|
defer cache.mu.Unlock()
|
|
cache.data = make(map[string]*CacheEntry)
|
|
cache.prefixKeys = make(map[string]map[string]struct{})
|
|
cache.keyPrefixes = make(map[string]map[string]struct{})
|
|
cache.lruEl = make(map[string]*list.Element)
|
|
cache.ll = list.New()
|
|
return nil
|
|
}
|
|
|
|
// CacheConfig holds the settings for CacheMiddleware.
type CacheConfig struct {
	// TTL is how long each stored response stays valid.
	TTL time.Duration
	// MaxSize caps the number of entries when the cache is an
	// *InMemoryCache; zero or a negative value leaves it unbounded.
	MaxSize int
	// CacheablePaths lists URL path prefixes whose responses may be cached.
	CacheablePaths []string
	// CacheableMethods lists the request methods intended to be cacheable.
	CacheableMethods []string
}
|
|
|
|
func DefaultCacheConfig() *CacheConfig {
|
|
return &CacheConfig{
|
|
TTL: 5 * time.Minute,
|
|
MaxSize: 1000,
|
|
CacheablePaths: []string{},
|
|
CacheableMethods: []string{"GET"},
|
|
}
|
|
}
|
|
|
|
func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.Handler {
|
|
if config == nil {
|
|
config = DefaultCacheConfig()
|
|
}
|
|
applyCacheMaxSize(cache, config.MaxSize)
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
if !isCacheablePath(r.URL.Path, config.CacheablePaths) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
cacheKey := generateCacheKey(r)
|
|
|
|
if entry, err := cache.Get(cacheKey); err == nil {
|
|
for key, values := range entry.Headers {
|
|
if len(values) > 0 {
|
|
w.Header().Set(key, values[0])
|
|
for i := 1; i < len(values); i++ {
|
|
w.Header().Add(key, values[i])
|
|
}
|
|
}
|
|
}
|
|
w.Header().Set("X-Cache", "HIT")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(entry.Data)
|
|
return
|
|
}
|
|
|
|
capturer := &responseCapturer{
|
|
ResponseWriter: w,
|
|
body: &bytes.Buffer{},
|
|
headers: make(http.Header),
|
|
}
|
|
|
|
next.ServeHTTP(capturer, r)
|
|
|
|
if capturer.statusCode == http.StatusOK {
|
|
entry := &CacheEntry{
|
|
Data: capturer.body.Bytes(),
|
|
Headers: capturer.headers,
|
|
Timestamp: time.Now(),
|
|
TTL: config.TTL,
|
|
}
|
|
|
|
path := r.URL.Path
|
|
prefixes := config.CacheablePaths
|
|
|
|
go func() {
|
|
if err := cache.Set(cacheKey, entry); err != nil {
|
|
log.Printf("middleware cache Set: %v", err)
|
|
return
|
|
}
|
|
registerIndexedCacheKey(cache, cacheKey, path, prefixes)
|
|
}()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// responseCapturer wraps the real ResponseWriter while a cache miss is being
// served. The handler's headers go into headers, every body byte is
// mirrored into body, and the status passed to WriteHeader is kept in
// statusCode, so CacheMiddleware can store the response afterwards.
type responseCapturer struct {
	// ResponseWriter is the destination of the live response.
	http.ResponseWriter
	// body receives a copy of everything written.
	body *bytes.Buffer
	// headers holds the headers staged by the handler.
	headers http.Header
	// statusCode stays 0 until WriteHeader records a status.
	statusCode int
}
|
|
|
|
func (rc *responseCapturer) WriteHeader(code int) {
|
|
rc.statusCode = code
|
|
for key, values := range rc.headers {
|
|
for _, value := range values {
|
|
rc.ResponseWriter.Header().Add(key, value)
|
|
}
|
|
}
|
|
rc.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (rc *responseCapturer) Write(b []byte) (int, error) {
|
|
rc.body.Write(b)
|
|
return rc.ResponseWriter.Write(b)
|
|
}
|
|
|
|
func (rc *responseCapturer) Header() http.Header {
|
|
return rc.headers
|
|
}
|
|
|
|
// isCacheablePath reports whether path starts with at least one entry of
// cacheablePaths. An empty list matches nothing; an empty prefix matches
// every path.
func isCacheablePath(path string, cacheablePaths []string) bool {
	matched := false
	for i := 0; i < len(cacheablePaths) && !matched; i++ {
		matched = strings.HasPrefix(path, cacheablePaths[i])
	}
	return matched
}
|
|
|
|
func generateCacheKey(r *http.Request) string {
|
|
key := fmt.Sprintf("%s:%s", r.Method, r.URL.Path)
|
|
if r.URL.RawQuery != "" {
|
|
key += "?" + r.URL.RawQuery
|
|
}
|
|
|
|
if userID := GetUserIDFromContext(r.Context()); userID != nil {
|
|
key += fmt.Sprintf(":user:%d", *userID)
|
|
}
|
|
|
|
sum := sha256.Sum256([]byte(key))
|
|
return "cache:" + hex.EncodeToString(sum[:])
|
|
}
|
|
|
|
func CacheInvalidationMiddleware(cache Cache, cacheablePrefixes []string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE" {
|
|
mPath := r.URL.Path
|
|
go func() {
|
|
invalidateCacheForMutation(cache, mPath, cacheablePrefixes)
|
|
}()
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|