Files
goyco/internal/middleware/cache.go

206 lines
4.2 KiB
Go

package middleware
import (
"bytes"
"crypto/md5"
"fmt"
"net/http"
"strings"
"sync"
"time"
)
// CacheEntry is one cached HTTP response: the body bytes and the headers
// recorded with it, plus the time it was stored and how long it stays valid.
// InMemoryCache.Get treats an entry as expired once more than TTL has elapsed
// since Timestamp.
//
// NOTE(review): the json tags suggest entries are meant to be serialized by
// some Cache implementation; nothing in this file marshals them — confirm.
type CacheEntry struct {
// Data is the response body as captured by CacheMiddleware.
Data []byte `json:"data"`
// Headers are the response headers replayed to the client on a cache hit.
Headers http.Header `json:"headers"`
// Timestamp is when the entry was created (CacheMiddleware uses time.Now()).
Timestamp time.Time `json:"timestamp"`
// TTL is how long after Timestamp the entry remains valid.
TTL time.Duration `json:"ttl"`
}
// Cache is the storage backend consulted by CacheMiddleware and emptied by
// CacheInvalidationMiddleware. Keys are the opaque strings produced by
// generateCacheKey.
type Cache interface {
	// Set is expected to store entry under key, replacing any earlier entry.
	Set(key string, entry *CacheEntry) error
	// Get returns the entry stored under key. CacheMiddleware treats any
	// non-nil error as a cache miss.
	Get(key string) (*CacheEntry, error)
	// Delete is expected to remove the entry stored under key.
	Delete(key string) error
	// Clear is expected to remove every entry.
	Clear() error
}
// InMemoryCache is a Cache backed by a map guarded by a read/write mutex.
// Entries are removed only by Delete, Clear, or lazily when Get finds them
// expired; the map has no size bound (CacheConfig.MaxSize is not consulted
// here). Construct it with NewInMemoryCache.
type InMemoryCache struct {
// mu guards data.
mu sync.RWMutex
// data maps cache keys to their entries.
data map[string]*CacheEntry
}
// NewInMemoryCache returns an empty InMemoryCache with its backing map
// already allocated.
func NewInMemoryCache() *InMemoryCache {
	c := new(InMemoryCache)
	c.data = make(map[string]*CacheEntry)
	return c
}
// Get returns the entry stored under key.
//
// It returns an error when the key is absent (or maps to a nil entry) and when
// the entry is older than its TTL. An expired entry is evicted as a side
// effect.
//
// The read happens under the read lock, which is released before the
// eviction takes the write lock. A concurrent Set may therefore store a fresh
// entry for the same key in between. Eviction re-checks under the write lock
// and deletes only the exact stale entry that was observed, so that fresh
// entry survives.
func (cache *InMemoryCache) Get(key string) (*CacheEntry, error) {
	cache.mu.RLock()
	entry, exists := cache.data[key]
	cache.mu.RUnlock()
	if !exists || entry == nil {
		return nil, fmt.Errorf("key not found")
	}
	if time.Since(entry.Timestamp) > entry.TTL {
		cache.mu.Lock()
		if current, ok := cache.data[key]; ok && current == entry {
			delete(cache.data, key)
		}
		cache.mu.Unlock()
		return nil, fmt.Errorf("entry expired")
	}
	return entry, nil
}
// Set stores entry under key, replacing any previous entry.
//
// A nil entry is rejected with an error; storing one would make a later Get
// dereference nil. The backing map is allocated on first use, so a zero-value
// InMemoryCache works as well as one from NewInMemoryCache (writing to a nil
// map would otherwise panic).
func (cache *InMemoryCache) Set(key string, entry *CacheEntry) error {
	if entry == nil {
		return fmt.Errorf("cache entry for key %q is nil", key)
	}
	cache.mu.Lock()
	defer cache.mu.Unlock()
	if cache.data == nil {
		cache.data = make(map[string]*CacheEntry)
	}
	cache.data[key] = entry
	return nil
}
// Delete removes the entry stored under key. Removing a key that is not
// present is a no-op, and Delete always returns nil.
func (cache *InMemoryCache) Delete(key string) error {
	cache.mu.Lock()
	delete(cache.data, key)
	cache.mu.Unlock()
	return nil
}
// Clear drops every entry by swapping in a freshly allocated map. It always
// returns nil.
func (cache *InMemoryCache) Clear() error {
	fresh := make(map[string]*CacheEntry)
	cache.mu.Lock()
	cache.data = fresh
	cache.mu.Unlock()
	return nil
}
// CacheConfig controls which requests CacheMiddleware caches and for how long.
type CacheConfig struct {
// TTL is the lifetime assigned to each newly stored entry.
TTL time.Duration
// MaxSize presumably bounds the number of cached entries.
// NOTE(review): nothing in this file reads MaxSize; InMemoryCache is unbounded.
MaxSize int
// CacheablePaths lists URL path prefixes eligible for caching. A request is
// considered only if its path starts with one of them, so an empty list
// disables caching entirely.
CacheablePaths []string
// CacheableMethods lists the HTTP methods intended to be cacheable;
// DefaultCacheConfig sets it to GET.
// NOTE(review): confirm CacheMiddleware consults this field.
CacheableMethods []string
}
// DefaultCacheConfig returns a fresh configuration with:
//   - a 5-minute TTL,
//   - MaxSize 1000,
//   - GET as the only cacheable method,
//   - an empty (non-nil) CacheablePaths list.
//
// With no paths configured, isCacheablePath matches nothing, so callers must
// add path prefixes before anything is cached.
func DefaultCacheConfig() *CacheConfig {
	cfg := new(CacheConfig)
	cfg.TTL = 5 * time.Minute
	cfg.MaxSize = 1000
	cfg.CacheablePaths = make([]string, 0)
	cfg.CacheableMethods = []string{"GET"}
	return cfg
}
// CacheMiddleware returns middleware that serves stored copies of successful
// responses. A request is eligible only if both hold:
//   - its method is in config.CacheableMethods (GET only when that list is empty);
//   - its path starts with one of config.CacheablePaths.
//
// A nil config falls back to DefaultCacheConfig(). Config fields are read on
// every request.
//
// On a hit, the stored headers replace any same-named headers already on w,
// mirroring what the handler's own Header().Set would have done. The middleware
// then adds "X-Cache: HIT", writes status 200 and replays the stored body.
//
// On a miss, next runs behind a responseCapturer. If the final status is 200,
// the captured body and handler headers are stored with config.TTL. A status
// of 0 means the handler never called WriteHeader, which net/http sends as an
// implicit 200. The store is synchronous: a detached goroutine was unbounded
// and could land after a CacheInvalidationMiddleware Clear, re-caching stale
// data. A Set error is ignored (best effort); the response was already served.
func CacheMiddleware(cache Cache, config *CacheConfig) func(http.Handler) http.Handler {
	if config == nil {
		config = DefaultCacheConfig()
	}
	// methodCacheable reports whether responses to method may be cached.
	methodCacheable := func(method string) bool {
		if len(config.CacheableMethods) == 0 {
			return method == http.MethodGet
		}
		for _, m := range config.CacheableMethods {
			if m == method {
				return true
			}
		}
		return false
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !methodCacheable(r.Method) || !isCacheablePath(r.URL.Path, config.CacheablePaths) {
				next.ServeHTTP(w, r)
				return
			}
			cacheKey := generateCacheKey(r)
			if entry, err := cache.Get(cacheKey); err == nil && entry != nil {
				header := w.Header()
				for key, values := range entry.Headers {
					// Replace rather than append so a header already set by an
					// outer layer is not duplicated.
					header.Del(key)
					for _, value := range values {
						header.Add(key, value)
					}
				}
				header.Set("X-Cache", "HIT")
				w.WriteHeader(http.StatusOK)
				// Best effort: the client may already have gone away.
				_, _ = w.Write(entry.Data)
				return
			}
			capturer := &responseCapturer{
				ResponseWriter: w,
				body:           &bytes.Buffer{},
				headers:        make(http.Header),
			}
			next.ServeHTTP(capturer, r)
			status := capturer.statusCode
			if status == 0 {
				status = http.StatusOK
			}
			if status != http.StatusOK {
				return
			}
			entry := &CacheEntry{
				Data:      capturer.body.Bytes(),
				Headers:   capturer.headers,
				Timestamp: time.Now(),
				TTL:       config.TTL,
			}
			_ = cache.Set(cacheKey, entry)
		})
	}
}
// responseCapturer is the http.ResponseWriter that CacheMiddleware hands to
// the wrapped handler on a cache miss. It streams everything through to the
// real writer while recording three things for a CacheEntry:
//   - the final status code,
//   - the body bytes,
//   - the headers the handler set.
//
// Header() returns a private map (headers), so only the handler's own headers
// are recorded, not ones added by outer middleware. Those headers are copied
// onto the real writer when the status is committed. Without that copy the
// client would never see them on a cache miss.
type responseCapturer struct {
	http.ResponseWriter
	body        *bytes.Buffer
	headers     http.Header
	statusCode  int  // final status; 0 until WriteHeader or Write is called
	wroteHeader bool // whether the final (non-1xx) status has been committed
}

// Header returns the handler-visible header map, allocating it if needed.
func (rc *responseCapturer) Header() http.Header {
	if rc.headers == nil {
		rc.headers = make(http.Header)
	}
	return rc.headers
}

// commitHeaders copies the handler's headers onto the underlying writer,
// replacing any values already present under the same names.
func (rc *responseCapturer) commitHeaders() {
	dst := rc.ResponseWriter.Header()
	for key, values := range rc.headers {
		dst.Del(key)
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}

// WriteHeader copies the handler's headers onto the underlying writer and
// forwards code.
//
// Informational 1xx codes (except 101 Switching Protocols) pass straight
// through without being recorded, because a final status still follows. Only
// the first final status is recorded and forwarded, so statusCode always
// matches what the client actually received.
func (rc *responseCapturer) WriteHeader(code int) {
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		rc.commitHeaders()
		rc.ResponseWriter.WriteHeader(code)
		return
	}
	if rc.wroteHeader {
		return
	}
	rc.wroteHeader = true
	rc.statusCode = code
	rc.commitHeaders()
	rc.ResponseWriter.WriteHeader(code)
}

// Write appends b to the captured body and forwards it to the underlying
// writer. As in net/http, a Write before any WriteHeader implies status 200.
// That status is recorded so the response is eligible for caching.
func (rc *responseCapturer) Write(b []byte) (int, error) {
	if !rc.wroteHeader {
		rc.WriteHeader(http.StatusOK)
	}
	if rc.body == nil {
		rc.body = new(bytes.Buffer)
	}
	rc.body.Write(b)
	return rc.ResponseWriter.Write(b)
}
// isCacheablePath reports whether path starts with any entry of
// cacheablePaths. Matching is a plain string-prefix test, so "/api" also
// matches "/apikeys". An empty prefix matches every path, and an empty list
// matches nothing.
func isCacheablePath(path string, cacheablePaths []string) bool {
	matched := false
	for i := 0; i < len(cacheablePaths) && !matched; i++ {
		matched = strings.HasPrefix(path, cacheablePaths[i])
	}
	return matched
}
// generateCacheKey derives the cache key for r from four parts:
//   - the HTTP method,
//   - the decoded URL path,
//   - the raw query string,
//   - the value returned by GetUserIDFromContext (0 is treated as anonymous,
//     as before).
//
// Each string part is encoded with %q, so every part is self-delimiting and
// no path or query value can be crafted to collide with another request's key.
// A plain ":"-joined form would let an anonymous request for
// "/profile:user:5" produce the same key as "/profile" requested by user 5.
// MD5 is used only to produce a fixed-length key.
func generateCacheKey(r *http.Request) string {
	raw := fmt.Sprintf("%q %q %q %d",
		r.Method,
		r.URL.Path,
		r.URL.RawQuery,
		GetUserIDFromContext(r.Context()),
	)
	return fmt.Sprintf("cache:%x", md5.Sum([]byte(raw)))
}
// CacheInvalidationMiddleware returns middleware that empties cache after
// any POST, PUT, PATCH or DELETE request has been handled.
//
// The clear runs synchronously once next has returned, or panicked, via
// defer. A detached goroutine fired before the handler could finish clearing
// before the mutation was applied. A concurrent GET could then re-cache
// pre-mutation data that survived until its TTL expired.
//
// Clearing is best effort: a Clear error is ignored because the request has
// already been served.
func CacheInvalidationMiddleware(cache Cache) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
				defer func() { _ = cache.Clear() }()
			}
			next.ServeHTTP(w, r)
		})
	}
}