// Package middleware provides HTTP middleware and helpers for monitoring
// database query performance and per-request response metrics.
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
const (
|
|
dbMonitorKey contextKey = "db_monitor"
|
|
slowQueryThresholdKey contextKey = "slow_query_threshold"
|
|
)
|
|
|
|
type DBMonitor interface {
|
|
LogQuery(query string, duration time.Duration, err error)
|
|
LogSlowQuery(query string, duration time.Duration, threshold time.Duration)
|
|
GetStats() DBStats
|
|
}
|
|
|
|
// DBStats is a snapshot of the counters and timings collected by a DBMonitor.
type DBStats struct {
	// TotalQueries counts every logged query, successful or not.
	TotalQueries int64 `json:"total_queries"`
	// SlowQueries counts queries/operations reported as slow.
	SlowQueries int64 `json:"slow_queries"`
	// AverageDuration is the mean duration of successful queries.
	AverageDuration time.Duration `json:"average_duration"`
	// MaxDuration is the longest duration seen for a successful query.
	MaxDuration time.Duration `json:"max_duration"`
	// ErrorCount counts logged queries that returned a non-nil error.
	ErrorCount int64 `json:"error_count"`
	// LastQueryTime is when the most recent query was logged.
	LastQueryTime time.Time `json:"last_query_time"`
}

// InMemoryDBMonitor is a DBMonitor implementation that keeps its statistics
// in process memory, guarded by a read/write mutex.
type InMemoryDBMonitor struct {
	stats DBStats
	mu    sync.RWMutex
}

// NewInMemoryDBMonitor returns a monitor with all statistics zeroed.
func NewInMemoryDBMonitor() *InMemoryDBMonitor {
	return &InMemoryDBMonitor{}
}

// LogQuery records the outcome of one query.
//
// Every call increments TotalQueries and refreshes LastQueryTime. A non-nil
// err only bumps ErrorCount; the duration of a failed query is not folded into
// the average, maximum, or slow-query count. The query text is currently
// unused.
func (m *InMemoryDBMonitor) LogQuery(query string, duration time.Duration, err error) {
	// Queries slower than this are counted as slow.
	const slowThreshold = 100 * time.Millisecond

	m.mu.Lock()
	defer m.mu.Unlock()

	m.stats.TotalQueries++
	m.stats.LastQueryTime = time.Now()

	if err != nil {
		m.stats.ErrorCount++
		return
	}

	// Failed queries never contribute a duration sample, so the running mean
	// must be weighted by the number of successful queries rather than by
	// TotalQueries; otherwise every error drags the average towards zero.
	// succeeded is at least 1 here because this call was counted in
	// TotalQueries but not in ErrorCount.
	succeeded := m.stats.TotalQueries - m.stats.ErrorCount
	total := int64(m.stats.AverageDuration)*(succeeded-1) + int64(duration)
	m.stats.AverageDuration = time.Duration(total / succeeded)

	if duration > m.stats.MaxDuration {
		m.stats.MaxDuration = duration
	}

	if duration > slowThreshold {
		m.stats.SlowQueries++
	}
}

// LogSlowQuery increments the slow-query counter. The query, duration and
// threshold arguments are currently unused.
//
// NOTE(review): LogQuery already counts queries slower than 100ms, so an
// operation reported through both methods is counted twice — confirm whether
// that is intended.
func (m *InMemoryDBMonitor) LogSlowQuery(query string, duration time.Duration, threshold time.Duration) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.stats.SlowQueries++
}

// GetStats returns a copy of the current statistics.
func (m *InMemoryDBMonitor) GetStats() DBStats {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.stats
}
|
|
|
|
func DBMonitoringMiddleware(monitor DBMonitor, slowQueryThreshold time.Duration) func(http.Handler) http.Handler {
|
|
if slowQueryThreshold == 0 {
|
|
slowQueryThreshold = 100 * time.Millisecond
|
|
}
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
ctx := context.WithValue(r.Context(), dbMonitorKey, monitor)
|
|
ctx = context.WithValue(ctx, slowQueryThresholdKey, slowQueryThreshold)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
|
|
duration := time.Since(start)
|
|
if duration > slowQueryThreshold {
|
|
|
|
monitor.LogSlowQuery(r.URL.Path, duration, slowQueryThreshold)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type QueryLogger struct {
|
|
DB *sql.DB
|
|
Monitor DBMonitor
|
|
}
|
|
|
|
func NewQueryLogger(db *sql.DB, monitor DBMonitor) *QueryLogger {
|
|
return &QueryLogger{
|
|
DB: db,
|
|
Monitor: monitor,
|
|
}
|
|
}
|
|
|
|
func (ql *QueryLogger) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
|
start := time.Now()
|
|
rows, err := ql.DB.QueryContext(ctx, query, args...)
|
|
duration := time.Since(start)
|
|
|
|
ql.Monitor.LogQuery(query, duration, err)
|
|
return rows, err
|
|
}
|
|
|
|
func (ql *QueryLogger) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
|
start := time.Now()
|
|
row := ql.DB.QueryRowContext(ctx, query, args...)
|
|
duration := time.Since(start)
|
|
|
|
ql.Monitor.LogQuery(query, duration, nil)
|
|
return row
|
|
}
|
|
|
|
func (ql *QueryLogger) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
|
start := time.Now()
|
|
result, err := ql.DB.ExecContext(ctx, query, args...)
|
|
duration := time.Since(start)
|
|
|
|
ql.Monitor.LogQuery(query, duration, err)
|
|
return result, err
|
|
}
|
|
|
|
type DatabaseHealthChecker struct {
|
|
DB *sql.DB
|
|
Monitor DBMonitor
|
|
}
|
|
|
|
func NewDatabaseHealthChecker(db *sql.DB, monitor DBMonitor) *DatabaseHealthChecker {
|
|
return &DatabaseHealthChecker{
|
|
DB: db,
|
|
Monitor: monitor,
|
|
}
|
|
}
|
|
|
|
func (dhc *DatabaseHealthChecker) CheckHealth() map[string]any {
|
|
start := time.Now()
|
|
|
|
err := dhc.DB.Ping()
|
|
duration := time.Since(start)
|
|
|
|
health := map[string]any{
|
|
"status": "healthy",
|
|
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
|
"ping_time": duration.String(),
|
|
}
|
|
|
|
if err != nil {
|
|
health["status"] = "unhealthy"
|
|
health["error"] = err.Error()
|
|
return health
|
|
}
|
|
|
|
stats := dhc.Monitor.GetStats()
|
|
health["database_stats"] = map[string]any{
|
|
"total_queries": stats.TotalQueries,
|
|
"slow_queries": stats.SlowQueries,
|
|
"average_duration": stats.AverageDuration.String(),
|
|
"max_duration": stats.MaxDuration.String(),
|
|
"error_count": stats.ErrorCount,
|
|
"last_query_time": stats.LastQueryTime.Format(time.RFC3339),
|
|
}
|
|
|
|
return health
|
|
}
|
|
|
|
type PerformanceMetrics struct {
|
|
RequestCount int64 `json:"request_count"`
|
|
AverageResponse time.Duration `json:"average_response"`
|
|
MaxResponse time.Duration `json:"max_response"`
|
|
ErrorCount int64 `json:"error_count"`
|
|
DBStats DBStats `json:"database_stats"`
|
|
}
|
|
|
|
type MetricsCollector struct {
|
|
monitor DBMonitor
|
|
metrics PerformanceMetrics
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func NewMetricsCollector(monitor DBMonitor) *MetricsCollector {
|
|
return &MetricsCollector{
|
|
monitor: monitor,
|
|
metrics: PerformanceMetrics{},
|
|
}
|
|
}
|
|
|
|
func (mc *MetricsCollector) RecordRequest(duration time.Duration, hasError bool) {
|
|
mc.mu.Lock()
|
|
defer mc.mu.Unlock()
|
|
|
|
mc.metrics.RequestCount++
|
|
|
|
if hasError {
|
|
mc.metrics.ErrorCount++
|
|
}
|
|
|
|
if mc.metrics.RequestCount == 1 {
|
|
mc.metrics.AverageResponse = duration
|
|
} else {
|
|
|
|
totalDuration := int64(mc.metrics.AverageResponse) * (mc.metrics.RequestCount - 1)
|
|
totalDuration += int64(duration)
|
|
mc.metrics.AverageResponse = time.Duration(totalDuration / mc.metrics.RequestCount)
|
|
}
|
|
|
|
if duration > mc.metrics.MaxResponse {
|
|
mc.metrics.MaxResponse = duration
|
|
}
|
|
}
|
|
|
|
func (mc *MetricsCollector) GetMetrics() PerformanceMetrics {
|
|
mc.mu.RLock()
|
|
defer mc.mu.RUnlock()
|
|
|
|
mc.metrics.DBStats = mc.monitor.GetStats()
|
|
return mc.metrics
|
|
}
|
|
|
|
func MetricsMiddleware(collector *MetricsCollector) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
rw := &metricsResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
|
|
|
next.ServeHTTP(rw, r)
|
|
|
|
duration := time.Since(start)
|
|
hasError := rw.statusCode >= 400
|
|
collector.RecordRequest(duration, hasError)
|
|
})
|
|
}
|
|
}
|
|
|
|
// metricsResponseWriter wraps an http.ResponseWriter and remembers the status
// code passed to WriteHeader.
type metricsResponseWriter struct {
	http.ResponseWriter

	// statusCode is the most recent code given to WriteHeader.
	statusCode int
}

// WriteHeader stores code and then forwards it to the wrapped writer.
func (rw *metricsResponseWriter) WriteHeader(code int) {
	rw.statusCode = code
	rw.ResponseWriter.WriteHeader(code)
}
|
|
|
|
func GetDBMonitorFromContext(ctx context.Context) (DBMonitor, bool) {
|
|
monitor, ok := ctx.Value(dbMonitorKey).(DBMonitor)
|
|
return monitor, ok
|
|
}
|
|
|
|
func GetSlowQueryThresholdFromContext(ctx context.Context) (time.Duration, bool) {
|
|
threshold, ok := ctx.Value(slowQueryThresholdKey).(time.Duration)
|
|
return threshold, ok
|
|
}
|