To gitea and beyond, let's go(-yco)
This commit is contained in:
393
internal/middleware/ratelimit.go
Normal file
393
internal/middleware/ratelimit.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxKeys = 10000
|
||||
|
||||
DefaultCleanupInterval = 5 * time.Minute
|
||||
|
||||
DefaultMaxStaleAge = 10 * time.Minute
|
||||
)
|
||||
|
||||
var TrustProxyHeaders = false
|
||||
|
||||
func SetTrustProxyHeaders(value bool) {
|
||||
TrustProxyHeaders = value
|
||||
}
|
||||
|
||||
type limiterKey struct {
|
||||
window time.Duration
|
||||
limit int
|
||||
}
|
||||
|
||||
var (
|
||||
limiterRegistry = make(map[limiterKey]*RateLimiter)
|
||||
registryMutex sync.RWMutex
|
||||
registryCleanup []*RateLimiter
|
||||
cleanupMutex sync.Mutex
|
||||
)
|
||||
|
||||
func getOrCreateLimiter(window time.Duration, limit int) *RateLimiter {
|
||||
key := limiterKey{window: window, limit: limit}
|
||||
|
||||
registryMutex.RLock()
|
||||
if limiter, exists := limiterRegistry[key]; exists {
|
||||
registryMutex.RUnlock()
|
||||
return limiter
|
||||
}
|
||||
registryMutex.RUnlock()
|
||||
|
||||
registryMutex.Lock()
|
||||
defer registryMutex.Unlock()
|
||||
|
||||
if limiter, exists := limiterRegistry[key]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
limiter := NewRateLimiter(window, limit)
|
||||
limiterRegistry[key] = limiter
|
||||
|
||||
cleanupMutex.Lock()
|
||||
registryCleanup = append(registryCleanup, limiter)
|
||||
cleanupMutex.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func StopAllRateLimiters() {
|
||||
cleanupMutex.Lock()
|
||||
defer cleanupMutex.Unlock()
|
||||
|
||||
for _, limiter := range registryCleanup {
|
||||
limiter.StopCleanup()
|
||||
}
|
||||
registryCleanup = nil
|
||||
|
||||
registryMutex.Lock()
|
||||
limiterRegistry = make(map[limiterKey]*RateLimiter)
|
||||
registryMutex.Unlock()
|
||||
}
|
||||
|
||||
type clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
type realClock struct{}
|
||||
|
||||
func (c *realClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
type keyEntry struct {
|
||||
requests []time.Time
|
||||
lastAccess time.Time
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
entries map[string]*keyEntry
|
||||
mutex sync.RWMutex
|
||||
window time.Duration
|
||||
limit int
|
||||
maxKeys int
|
||||
cleanupInterval time.Duration
|
||||
maxStaleAge time.Duration
|
||||
stopCleanup chan struct{}
|
||||
cleanupOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
clock clock
|
||||
}
|
||||
|
||||
func NewRateLimiter(window time.Duration, limit int) *RateLimiter {
|
||||
return NewRateLimiterWithConfig(window, limit, DefaultMaxKeys, DefaultCleanupInterval, DefaultMaxStaleAge)
|
||||
}
|
||||
|
||||
func NewRateLimiterWithConfig(window time.Duration, limit int, maxKeys int, cleanupInterval time.Duration, maxStaleAge time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: make(map[string]*keyEntry),
|
||||
window: window,
|
||||
limit: limit,
|
||||
maxKeys: maxKeys,
|
||||
cleanupInterval: cleanupInterval,
|
||||
maxStaleAge: maxStaleAge,
|
||||
stopCleanup: make(chan struct{}),
|
||||
clock: &realClock{},
|
||||
}
|
||||
|
||||
rl.StartCleanup()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func newRateLimiterWithClock(window time.Duration, limit int, c clock) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: make(map[string]*keyEntry),
|
||||
window: window,
|
||||
limit: limit,
|
||||
maxKeys: DefaultMaxKeys,
|
||||
cleanupInterval: DefaultCleanupInterval,
|
||||
maxStaleAge: DefaultMaxStaleAge,
|
||||
stopCleanup: make(chan struct{}),
|
||||
clock: c,
|
||||
}
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Allow(key string) bool {
|
||||
rl.mutex.Lock()
|
||||
defer rl.mutex.Unlock()
|
||||
|
||||
now := rl.clock.Now()
|
||||
cutoff := now.Add(-rl.window)
|
||||
|
||||
var entry *keyEntry
|
||||
var exists bool
|
||||
|
||||
if entry, exists = rl.entries[key]; exists {
|
||||
|
||||
isStale := now.Sub(entry.lastAccess) > rl.maxStaleAge
|
||||
|
||||
var validRequests []time.Time
|
||||
for _, reqTime := range entry.requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
entry.requests = validRequests
|
||||
|
||||
if len(entry.requests) == 0 && isStale {
|
||||
delete(rl.entries, key)
|
||||
exists = false
|
||||
} else {
|
||||
|
||||
entry.lastAccess = now
|
||||
}
|
||||
}
|
||||
|
||||
if !exists {
|
||||
|
||||
if len(rl.entries) >= rl.maxKeys {
|
||||
|
||||
rl.evictLRU()
|
||||
}
|
||||
|
||||
entry = &keyEntry{
|
||||
requests: []time.Time{now},
|
||||
lastAccess: now,
|
||||
}
|
||||
rl.entries[key] = entry
|
||||
return true
|
||||
}
|
||||
|
||||
requestCount := len(entry.requests)
|
||||
if requestCount >= rl.limit {
|
||||
return false
|
||||
}
|
||||
|
||||
entry.requests = append(entry.requests, now)
|
||||
entry.lastAccess = now
|
||||
return true
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) evictLRU() {
|
||||
if len(rl.entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var oldestKey string
|
||||
var oldestTime time.Time
|
||||
first := true
|
||||
|
||||
for key, entry := range rl.entries {
|
||||
if first || entry.lastAccess.Before(oldestTime) {
|
||||
oldestKey = key
|
||||
oldestTime = entry.lastAccess
|
||||
first = false
|
||||
}
|
||||
}
|
||||
|
||||
if oldestKey != "" {
|
||||
delete(rl.entries, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) GetRemainingTime(key string) time.Duration {
|
||||
rl.mutex.RLock()
|
||||
defer rl.mutex.RUnlock()
|
||||
|
||||
if entry, exists := rl.entries[key]; exists && len(entry.requests) > 0 {
|
||||
oldestRequest := entry.requests[0]
|
||||
return rl.window - rl.clock.Now().Sub(oldestRequest)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Cleanup() {
|
||||
rl.mutex.Lock()
|
||||
defer rl.mutex.Unlock()
|
||||
|
||||
now := rl.clock.Now()
|
||||
cutoff := now.Add(-rl.window)
|
||||
staleCutoff := now.Add(-rl.maxStaleAge)
|
||||
|
||||
for key, entry := range rl.entries {
|
||||
|
||||
var validRequests []time.Time
|
||||
for _, reqTime := range entry.requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
entry.requests = validRequests
|
||||
|
||||
if len(entry.requests) == 0 && entry.lastAccess.Before(staleCutoff) {
|
||||
delete(rl.entries, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) StartCleanup() {
|
||||
rl.cleanupOnce.Do(func() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(rl.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
rl.Cleanup()
|
||||
case <-rl.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) StopCleanup() {
|
||||
rl.stopOnce.Do(func() {
|
||||
close(rl.stopCleanup)
|
||||
})
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) GetSize() int {
|
||||
rl.mutex.RLock()
|
||||
defer rl.mutex.RUnlock()
|
||||
return len(rl.entries)
|
||||
}
|
||||
|
||||
func GetSecureClientIP(r *http.Request) string {
|
||||
if TrustProxyHeaders {
|
||||
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
ip := strings.TrimSpace(ips[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
ip := strings.TrimSpace(xri)
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
|
||||
if net.ParseIP(r.RemoteAddr) != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
func GetKey(r *http.Request) string {
|
||||
ip := GetSecureClientIP(r)
|
||||
|
||||
if userID := GetUserIDFromContext(r.Context()); userID != 0 {
|
||||
return fmt.Sprintf("user:%d:ip:%s", userID, ip)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ip:%s", ip)
|
||||
}
|
||||
|
||||
func RateLimitMiddleware(window time.Duration, limit int) func(http.Handler) http.Handler {
|
||||
limiter := getOrCreateLimiter(window, limit)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := GetKey(r)
|
||||
|
||||
if !limiter.Allow(key) {
|
||||
remainingTime := limiter.GetRemainingTime(key)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", remainingTime.Seconds()))
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
|
||||
response := map[string]any{
|
||||
"error": "Rate limit exceeded",
|
||||
"message": fmt.Sprintf("Too many requests. Please try again in %d seconds.", int(remainingTime.Seconds())),
|
||||
"retry_after": remainingTime.Seconds(),
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
jsonData = []byte(`{"error":"Rate limit exceeded"}`)
|
||||
}
|
||||
|
||||
w.Write(jsonData)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func AuthRateLimitMiddleware() func(http.Handler) http.Handler {
|
||||
return RateLimitMiddleware(1*time.Minute, 5)
|
||||
}
|
||||
|
||||
func AuthRateLimitMiddlewareWithLimit(limit int) func(http.Handler) http.Handler {
|
||||
return RateLimitMiddleware(1*time.Minute, limit)
|
||||
}
|
||||
|
||||
func GeneralRateLimitMiddleware() func(http.Handler) http.Handler {
|
||||
return RateLimitMiddleware(1*time.Minute, 100)
|
||||
}
|
||||
|
||||
func GeneralRateLimitMiddlewareWithLimit(limit int) func(http.Handler) http.Handler {
|
||||
return RateLimitMiddleware(1*time.Minute, limit)
|
||||
}
|
||||
|
||||
func HealthRateLimitMiddleware(limit int) func(http.Handler) http.Handler {
|
||||
return RateLimitMiddleware(1*time.Minute, limit)
|
||||
}
|
||||
|
||||
func MetricsRateLimitMiddleware(limit int) func(http.Handler) http.Handler {
|
||||
return RateLimitMiddleware(1*time.Minute, limit)
|
||||
}
|
||||
Reference in New Issue
Block a user