394 lines
8.1 KiB
Go
394 lines
8.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Default tuning values used by NewRateLimiter when the caller does not
// supply an explicit configuration.
const (
	// DefaultMaxKeys is the number of distinct keys a limiter tracks before
	// Allow starts evicting the least recently accessed entry.
	DefaultMaxKeys = 10000

	// DefaultCleanupInterval is how often the background goroutine started
	// by StartCleanup runs Cleanup.
	DefaultCleanupInterval = 5 * time.Minute

	// DefaultMaxStaleAge is how long an entry with no in-window requests may
	// stay untouched before it is dropped.
	DefaultMaxStaleAge = 10 * time.Minute
)
|
|
|
|
// TrustProxyHeaders controls whether GetSecureClientIP honors the
// X-Forwarded-For and X-Real-IP request headers. It defaults to false, in
// which case only the connection's RemoteAddr is used.
//
// The variable is read and written without synchronization, so it should be
// set before the server starts handling requests.
var TrustProxyHeaders bool
|
|
|
|
func SetTrustProxyHeaders(value bool) {
|
|
TrustProxyHeaders = value
|
|
}
|
|
|
|
// limiterKey identifies a shared RateLimiter in the package registry. Two
// middleware instances configured with the same window and limit resolve to
// the same key and therefore the same limiter.
type limiterKey struct {
	window time.Duration // length of the sliding window
	limit  int           // maximum requests allowed per window
}
|
|
|
|
var (
|
|
limiterRegistry = make(map[limiterKey]*RateLimiter)
|
|
registryMutex sync.RWMutex
|
|
registryCleanup []*RateLimiter
|
|
cleanupMutex sync.Mutex
|
|
)
|
|
|
|
func getOrCreateLimiter(window time.Duration, limit int) *RateLimiter {
|
|
key := limiterKey{window: window, limit: limit}
|
|
|
|
registryMutex.RLock()
|
|
if limiter, exists := limiterRegistry[key]; exists {
|
|
registryMutex.RUnlock()
|
|
return limiter
|
|
}
|
|
registryMutex.RUnlock()
|
|
|
|
registryMutex.Lock()
|
|
defer registryMutex.Unlock()
|
|
|
|
if limiter, exists := limiterRegistry[key]; exists {
|
|
return limiter
|
|
}
|
|
|
|
limiter := NewRateLimiter(window, limit)
|
|
limiterRegistry[key] = limiter
|
|
|
|
cleanupMutex.Lock()
|
|
registryCleanup = append(registryCleanup, limiter)
|
|
cleanupMutex.Unlock()
|
|
|
|
return limiter
|
|
}
|
|
|
|
func StopAllRateLimiters() {
|
|
cleanupMutex.Lock()
|
|
defer cleanupMutex.Unlock()
|
|
|
|
for _, limiter := range registryCleanup {
|
|
limiter.StopCleanup()
|
|
}
|
|
registryCleanup = nil
|
|
|
|
registryMutex.Lock()
|
|
limiterRegistry = make(map[limiterKey]*RateLimiter)
|
|
registryMutex.Unlock()
|
|
}
|
|
|
|
// clock abstracts the source of the current time so a limiter can be driven
// by something other than the wall clock.
type clock interface {
	// Now returns the current time as seen by this clock.
	Now() time.Time
}
|
|
|
|
// realClock is the clock implementation backed by the system time.
type realClock struct{}

// Now returns time.Now().
func (*realClock) Now() time.Time {
	return time.Now()
}
|
|
|
|
// keyEntry is the per-key state tracked by a RateLimiter.
type keyEntry struct {
	// requests holds the timestamps of admitted requests, appended in the
	// order they were admitted.
	requests []time.Time

	// lastAccess is the time the key was last seen by Allow; it drives
	// staleness checks and least-recently-used eviction.
	lastAccess time.Time
}
|
|
|
|
type RateLimiter struct {
|
|
entries map[string]*keyEntry
|
|
mutex sync.RWMutex
|
|
window time.Duration
|
|
limit int
|
|
maxKeys int
|
|
cleanupInterval time.Duration
|
|
maxStaleAge time.Duration
|
|
stopCleanup chan struct{}
|
|
cleanupOnce sync.Once
|
|
stopOnce sync.Once
|
|
clock clock
|
|
}
|
|
|
|
func NewRateLimiter(window time.Duration, limit int) *RateLimiter {
|
|
return NewRateLimiterWithConfig(window, limit, DefaultMaxKeys, DefaultCleanupInterval, DefaultMaxStaleAge)
|
|
}
|
|
|
|
func NewRateLimiterWithConfig(window time.Duration, limit int, maxKeys int, cleanupInterval time.Duration, maxStaleAge time.Duration) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
entries: make(map[string]*keyEntry),
|
|
window: window,
|
|
limit: limit,
|
|
maxKeys: maxKeys,
|
|
cleanupInterval: cleanupInterval,
|
|
maxStaleAge: maxStaleAge,
|
|
stopCleanup: make(chan struct{}),
|
|
clock: &realClock{},
|
|
}
|
|
|
|
rl.StartCleanup()
|
|
|
|
return rl
|
|
}
|
|
|
|
func newRateLimiterWithClock(window time.Duration, limit int, c clock) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
entries: make(map[string]*keyEntry),
|
|
window: window,
|
|
limit: limit,
|
|
maxKeys: DefaultMaxKeys,
|
|
cleanupInterval: DefaultCleanupInterval,
|
|
maxStaleAge: DefaultMaxStaleAge,
|
|
stopCleanup: make(chan struct{}),
|
|
clock: c,
|
|
}
|
|
|
|
return rl
|
|
}
|
|
|
|
func (rl *RateLimiter) Allow(key string) bool {
|
|
rl.mutex.Lock()
|
|
defer rl.mutex.Unlock()
|
|
|
|
now := rl.clock.Now()
|
|
cutoff := now.Add(-rl.window)
|
|
|
|
var entry *keyEntry
|
|
var exists bool
|
|
|
|
if entry, exists = rl.entries[key]; exists {
|
|
|
|
isStale := now.Sub(entry.lastAccess) > rl.maxStaleAge
|
|
|
|
var validRequests []time.Time
|
|
for _, reqTime := range entry.requests {
|
|
if reqTime.After(cutoff) {
|
|
validRequests = append(validRequests, reqTime)
|
|
}
|
|
}
|
|
entry.requests = validRequests
|
|
|
|
if len(entry.requests) == 0 && isStale {
|
|
delete(rl.entries, key)
|
|
exists = false
|
|
} else {
|
|
|
|
entry.lastAccess = now
|
|
}
|
|
}
|
|
|
|
if !exists {
|
|
|
|
if len(rl.entries) >= rl.maxKeys {
|
|
|
|
rl.evictLRU()
|
|
}
|
|
|
|
entry = &keyEntry{
|
|
requests: []time.Time{now},
|
|
lastAccess: now,
|
|
}
|
|
rl.entries[key] = entry
|
|
return true
|
|
}
|
|
|
|
requestCount := len(entry.requests)
|
|
if requestCount >= rl.limit {
|
|
return false
|
|
}
|
|
|
|
entry.requests = append(entry.requests, now)
|
|
entry.lastAccess = now
|
|
return true
|
|
}
|
|
|
|
func (rl *RateLimiter) evictLRU() {
|
|
if len(rl.entries) == 0 {
|
|
return
|
|
}
|
|
|
|
var oldestKey string
|
|
var oldestTime time.Time
|
|
first := true
|
|
|
|
for key, entry := range rl.entries {
|
|
if first || entry.lastAccess.Before(oldestTime) {
|
|
oldestKey = key
|
|
oldestTime = entry.lastAccess
|
|
first = false
|
|
}
|
|
}
|
|
|
|
if oldestKey != "" {
|
|
delete(rl.entries, oldestKey)
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) GetRemainingTime(key string) time.Duration {
|
|
rl.mutex.RLock()
|
|
defer rl.mutex.RUnlock()
|
|
|
|
if entry, exists := rl.entries[key]; exists && len(entry.requests) > 0 {
|
|
oldestRequest := entry.requests[0]
|
|
return rl.window - rl.clock.Now().Sub(oldestRequest)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (rl *RateLimiter) Cleanup() {
|
|
rl.mutex.Lock()
|
|
defer rl.mutex.Unlock()
|
|
|
|
now := rl.clock.Now()
|
|
cutoff := now.Add(-rl.window)
|
|
staleCutoff := now.Add(-rl.maxStaleAge)
|
|
|
|
for key, entry := range rl.entries {
|
|
|
|
var validRequests []time.Time
|
|
for _, reqTime := range entry.requests {
|
|
if reqTime.After(cutoff) {
|
|
validRequests = append(validRequests, reqTime)
|
|
}
|
|
}
|
|
entry.requests = validRequests
|
|
|
|
if len(entry.requests) == 0 && entry.lastAccess.Before(staleCutoff) {
|
|
delete(rl.entries, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) StartCleanup() {
|
|
rl.cleanupOnce.Do(func() {
|
|
go func() {
|
|
ticker := time.NewTicker(rl.cleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
rl.Cleanup()
|
|
case <-rl.stopCleanup:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
})
|
|
}
|
|
|
|
func (rl *RateLimiter) StopCleanup() {
|
|
rl.stopOnce.Do(func() {
|
|
close(rl.stopCleanup)
|
|
})
|
|
}
|
|
|
|
func (rl *RateLimiter) GetSize() int {
|
|
rl.mutex.RLock()
|
|
defer rl.mutex.RUnlock()
|
|
return len(rl.entries)
|
|
}
|
|
|
|
func GetSecureClientIP(r *http.Request) string {
|
|
if TrustProxyHeaders {
|
|
|
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
|
|
ips := strings.Split(xff, ",")
|
|
if len(ips) > 0 {
|
|
ip := strings.TrimSpace(ips[0])
|
|
if net.ParseIP(ip) != nil {
|
|
return ip
|
|
}
|
|
}
|
|
}
|
|
|
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
|
ip := strings.TrimSpace(xri)
|
|
if net.ParseIP(ip) != nil {
|
|
return ip
|
|
}
|
|
}
|
|
}
|
|
|
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
|
|
if net.ParseIP(r.RemoteAddr) != nil {
|
|
return r.RemoteAddr
|
|
}
|
|
|
|
return r.RemoteAddr
|
|
}
|
|
|
|
if net.ParseIP(ip) != nil {
|
|
return ip
|
|
}
|
|
|
|
return ip
|
|
}
|
|
|
|
func GetKey(r *http.Request) string {
|
|
ip := GetSecureClientIP(r)
|
|
|
|
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
|
|
return fmt.Sprintf("user:%d:ip:%s", userID, ip)
|
|
}
|
|
|
|
return fmt.Sprintf("ip:%s", ip)
|
|
}
|
|
|
|
func RateLimitMiddleware(window time.Duration, limit int) func(http.Handler) http.Handler {
|
|
limiter := getOrCreateLimiter(window, limit)
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
key := GetKey(r)
|
|
|
|
if !limiter.Allow(key) {
|
|
remainingTime := limiter.GetRemainingTime(key)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", remainingTime.Seconds()))
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
|
|
response := map[string]any{
|
|
"error": "Rate limit exceeded",
|
|
"message": fmt.Sprintf("Too many requests. Please try again in %d seconds.", int(remainingTime.Seconds())),
|
|
"retry_after": remainingTime.Seconds(),
|
|
}
|
|
|
|
jsonData, err := json.Marshal(response)
|
|
if err != nil {
|
|
jsonData = []byte(`{"error":"Rate limit exceeded"}`)
|
|
}
|
|
|
|
w.Write(jsonData)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
func AuthRateLimitMiddleware() func(http.Handler) http.Handler {
|
|
return RateLimitMiddleware(1*time.Minute, 5)
|
|
}
|
|
|
|
func AuthRateLimitMiddlewareWithLimit(limit int) func(http.Handler) http.Handler {
|
|
return RateLimitMiddleware(1*time.Minute, limit)
|
|
}
|
|
|
|
func GeneralRateLimitMiddleware() func(http.Handler) http.Handler {
|
|
return RateLimitMiddleware(1*time.Minute, 100)
|
|
}
|
|
|
|
func GeneralRateLimitMiddlewareWithLimit(limit int) func(http.Handler) http.Handler {
|
|
return RateLimitMiddleware(1*time.Minute, limit)
|
|
}
|
|
|
|
func HealthRateLimitMiddleware(limit int) func(http.Handler) http.Handler {
|
|
return RateLimitMiddleware(1*time.Minute, limit)
|
|
}
|
|
|
|
func MetricsRateLimitMiddleware(limit int) func(http.Handler) http.Handler {
|
|
return RateLimitMiddleware(1*time.Minute, limit)
|
|
}
|