Files
goyco/internal/middleware/db_monitoring.go

278 lines
6.7 KiB
Go

package middleware
import (
"context"
"database/sql"
"net/http"
"sync"
"time"
)
// Context keys under which DBMonitoringMiddleware stores per-request
// monitoring values. contextKey is presumably declared elsewhere in this
// package.
const (
	// slowQueryThresholdKey holds the time.Duration slow threshold in effect.
	slowQueryThresholdKey contextKey = "slow_query_threshold"

	// dbMonitorKey holds the DBMonitor attached to the request.
	dbMonitorKey contextKey = "db_monitor"
)
// DBMonitor receives database timing events and reports aggregated
// statistics. DBMonitoringMiddleware places one monitor into every request
// context, so implementations should tolerate concurrent use.
type DBMonitor interface {
	// GetStats returns a snapshot of the statistics gathered so far.
	GetStats() DBStats

	// LogQuery records one query's execution time and its error, if any.
	LogQuery(query string, duration time.Duration, err error)

	// LogSlowQuery records an operation whose duration exceeded threshold.
	LogSlowQuery(query string, duration time.Duration, threshold time.Duration)
}
// DBStats is a point-in-time snapshot of the statistics collected by a
// DBMonitor. The field comments describe how InMemoryDBMonitor fills them in.
type DBStats struct {
	// TotalQueries is the number of queries reported, successful or not.
	TotalQueries int64 `json:"total_queries"`
	// SlowQueries counts successful queries slower than the slow threshold
	// plus every LogSlowQuery call.
	SlowQueries int64 `json:"slow_queries"`
	// AverageDuration is the mean duration of successful queries.
	AverageDuration time.Duration `json:"average_duration"`
	// MaxDuration is the longest successful query duration seen.
	MaxDuration time.Duration `json:"max_duration"`
	// ErrorCount is the number of queries reported with a non-nil error.
	ErrorCount int64 `json:"error_count"`
	// LastQueryTime is when LogQuery was last called.
	LastQueryTime time.Time `json:"last_query_time"`
}

// InMemoryDBMonitor is a DBMonitor that keeps its statistics in process
// memory. Every method takes the internal mutex, so one instance may be
// shared by concurrent requests. The zero value is ready to use.
type InMemoryDBMonitor struct {
	stats DBStats
	// timedQueries and totalDuration back AverageDuration. Keeping an exact
	// running sum avoids the truncation drift of rebuilding the total from a
	// rounded average. Counting only successful queries stops failed queries,
	// whose durations are not recorded, from dragging the mean toward zero.
	timedQueries  int64
	totalDuration time.Duration
	mu            sync.RWMutex
}

// NewInMemoryDBMonitor returns an empty, ready-to-use monitor.
func NewInMemoryDBMonitor() *InMemoryDBMonitor {
	return &InMemoryDBMonitor{
		stats: DBStats{},
	}
}

// LogQuery records one query.
//
// Every call increments TotalQueries and sets LastQueryTime. A failed query
// (err != nil) only increments ErrorCount and is excluded from the duration
// statistics. A successful query updates AverageDuration and MaxDuration,
// and increments SlowQueries when it took longer than 100ms.
func (m *InMemoryDBMonitor) LogQuery(query string, duration time.Duration, err error) {
	// This threshold is fixed here. It is independent of the threshold given
	// to DBMonitoringMiddleware.
	const slowThreshold = 100 * time.Millisecond

	m.mu.Lock()
	defer m.mu.Unlock()

	m.stats.TotalQueries++
	m.stats.LastQueryTime = time.Now()
	if err != nil {
		m.stats.ErrorCount++
		return
	}

	m.timedQueries++
	m.totalDuration += duration
	m.stats.AverageDuration = m.totalDuration / time.Duration(m.timedQueries)

	if duration > m.stats.MaxDuration {
		m.stats.MaxDuration = duration
	}
	if duration > slowThreshold {
		m.stats.SlowQueries++
	}
}

// LogSlowQuery increments SlowQueries. The query text, duration and
// threshold are accepted to satisfy DBMonitor but are not stored. LogQuery
// already counts its own slow queries, so reporting the same query through
// both methods counts it twice.
func (m *InMemoryDBMonitor) LogSlowQuery(query string, duration time.Duration, threshold time.Duration) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.stats.SlowQueries++
}

// GetStats returns a copy of the current statistics, taken under the read
// lock so it is internally consistent.
func (m *InMemoryDBMonitor) GetStats() DBStats {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.stats
}
// DBMonitoringMiddleware stores monitor and slowQueryThreshold in each
// request's context, where GetDBMonitorFromContext and
// GetSlowQueryThresholdFromContext can read them. After the wrapped handler
// returns, it calls monitor.LogSlowQuery with the request's URL path if the
// whole request took longer than the threshold.
//
// A zero or negative slowQueryThreshold falls back to 100ms. A negative value
// would otherwise flag every request as slow. When monitor is nil, the
// context values are still set but nothing is reported.
func DBMonitoringMiddleware(monitor DBMonitor, slowQueryThreshold time.Duration) func(http.Handler) http.Handler {
	if slowQueryThreshold <= 0 {
		slowQueryThreshold = 100 * time.Millisecond
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			ctx := context.WithValue(r.Context(), dbMonitorKey, monitor)
			ctx = context.WithValue(ctx, slowQueryThresholdKey, slowQueryThreshold)
			next.ServeHTTP(w, r.WithContext(ctx))

			duration := time.Since(start)
			// Guard against a nil monitor, which would otherwise panic here
			// only on slow requests.
			if monitor != nil && duration > slowQueryThreshold {
				monitor.LogSlowQuery(r.URL.Path, duration, slowQueryThreshold)
			}
		})
	}
}
// QueryLogger wraps a *sql.DB and reports the duration and outcome of each
// statement it runs to Monitor. A nil Monitor disables reporting.
type QueryLogger struct {
	DB      *sql.DB
	Monitor DBMonitor
}

// NewQueryLogger returns a QueryLogger that runs statements on db and reports
// them to monitor.
func NewQueryLogger(db *sql.DB, monitor DBMonitor) *QueryLogger {
	return &QueryLogger{
		DB:      db,
		Monitor: monitor,
	}
}

// QueryContext runs query through DB.QueryContext and reports it. The
// measured time covers only that call, not the caller's later iteration over
// the returned rows.
func (ql *QueryLogger) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	start := time.Now()
	rows, err := ql.DB.QueryContext(ctx, query, args...)
	ql.observe(query, start, err)
	return rows, err
}

// QueryRowContext runs query through DB.QueryRowContext and reports it.
// *sql.Row defers the query error until Scan. Row.Err exposes it without
// consuming the row, so failures reach the monitor instead of being logged as
// successes. sql.ErrNoRows is only detected by Scan, so it is not reported
// here.
func (ql *QueryLogger) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	start := time.Now()
	row := ql.DB.QueryRowContext(ctx, query, args...)
	ql.observe(query, start, row.Err())
	return row
}

// ExecContext runs query through DB.ExecContext and reports it.
func (ql *QueryLogger) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	start := time.Now()
	result, err := ql.DB.ExecContext(ctx, query, args...)
	ql.observe(query, start, err)
	return result, err
}

// observe sends a finished statement's elapsed time and error to Monitor. It
// does nothing when Monitor is nil.
func (ql *QueryLogger) observe(query string, start time.Time, err error) {
	if ql.Monitor == nil {
		return
	}
	ql.Monitor.LogQuery(query, time.Since(start), err)
}
// DatabaseHealthChecker reports database reachability together with the
// statistics gathered by Monitor.
type DatabaseHealthChecker struct {
	DB      *sql.DB
	Monitor DBMonitor
}

// NewDatabaseHealthChecker returns a checker that pings db and reads
// statistics from monitor.
func NewDatabaseHealthChecker(db *sql.DB, monitor DBMonitor) *DatabaseHealthChecker {
	return &DatabaseHealthChecker{
		DB:      db,
		Monitor: monitor,
	}
}

// CheckHealth calls CheckHealthContext with context.Background(). This
// matches the previous DB.Ping behavior, so the ping has no caller-imposed
// deadline.
func (dhc *DatabaseHealthChecker) CheckHealth() map[string]any {
	return dhc.CheckHealthContext(context.Background())
}

// CheckHealthContext pings the database with ctx, so callers can bound the
// check with a deadline. It returns a report with these keys:
//   - "status": "healthy" or "unhealthy"
//   - "timestamp": the current time, RFC 3339 in UTC
//   - "ping_time": how long the ping took
//   - "error": the ping error (unhealthy only)
//   - "database_stats": Monitor's statistics (healthy only, and only when
//     Monitor is non-nil)
func (dhc *DatabaseHealthChecker) CheckHealthContext(ctx context.Context) map[string]any {
	start := time.Now()
	err := dhc.DB.PingContext(ctx)
	duration := time.Since(start)

	health := map[string]any{
		"status":    "healthy",
		"timestamp": time.Now().UTC().Format(time.RFC3339),
		"ping_time": duration.String(),
	}
	if err != nil {
		health["status"] = "unhealthy"
		health["error"] = err.Error()
		return health
	}
	if dhc.Monitor == nil {
		return health
	}

	stats := dhc.Monitor.GetStats()
	health["database_stats"] = map[string]any{
		"total_queries":    stats.TotalQueries,
		"slow_queries":     stats.SlowQueries,
		"average_duration": stats.AverageDuration.String(),
		"max_duration":     stats.MaxDuration.String(),
		"error_count":      stats.ErrorCount,
		// Use UTC so this timestamp is formatted the same way as "timestamp".
		"last_query_time": stats.LastQueryTime.UTC().Format(time.RFC3339),
	}
	return health
}
// PerformanceMetrics is a snapshot of HTTP request metrics together with the
// database statistics reported by a DBMonitor.
type PerformanceMetrics struct {
	// RequestCount is the number of requests recorded.
	RequestCount int64 `json:"request_count"`
	// AverageResponse is the mean request duration.
	AverageResponse time.Duration `json:"average_response"`
	// MaxResponse is the longest request duration seen.
	MaxResponse time.Duration `json:"max_response"`
	// ErrorCount is the number of requests recorded as errors.
	ErrorCount int64 `json:"error_count"`
	// DBStats holds the monitor's statistics, serialized as "database_stats".
	DBStats DBStats `json:"database_stats"`
}
// MetricsCollector aggregates per-request latency and error counts and
// combines them with a DBMonitor's statistics. It is safe for concurrent use.
type MetricsCollector struct {
	monitor DBMonitor
	metrics PerformanceMetrics
	// totalResponse is the exact sum of recorded durations. AverageResponse
	// is derived from it to avoid the cumulative truncation error of updating
	// a rounded average.
	totalResponse time.Duration
	mu            sync.RWMutex
}

// NewMetricsCollector returns an empty collector that reads database
// statistics from monitor. A nil monitor leaves DBStats zero.
func NewMetricsCollector(monitor DBMonitor) *MetricsCollector {
	return &MetricsCollector{
		monitor: monitor,
		metrics: PerformanceMetrics{},
	}
}

// RecordRequest adds one request of the given duration. When hasError is
// true, the request also counts toward ErrorCount.
func (mc *MetricsCollector) RecordRequest(duration time.Duration, hasError bool) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	mc.metrics.RequestCount++
	if hasError {
		mc.metrics.ErrorCount++
	}

	mc.totalResponse += duration
	mc.metrics.AverageResponse = mc.totalResponse / time.Duration(mc.metrics.RequestCount)

	if duration > mc.metrics.MaxResponse {
		mc.metrics.MaxResponse = duration
	}
}

// GetMetrics returns a snapshot of the request metrics with current database
// statistics attached.
//
// The snapshot is copied under the read lock and DBStats is filled in on that
// copy. Writing into shared state while holding only a read lock is a data
// race between concurrent callers. It also means the monitor is not called
// while the collector's lock is held.
func (mc *MetricsCollector) GetMetrics() PerformanceMetrics {
	mc.mu.RLock()
	snapshot := mc.metrics
	mc.mu.RUnlock()

	if mc.monitor != nil {
		snapshot.DBStats = mc.monitor.GetStats()
	}
	return snapshot
}
// MetricsMiddleware records every request's latency in collector. A request
// counts as an error when its recorded status is 400 or higher. The status
// defaults to 200 when the handler never calls WriteHeader.
func MetricsMiddleware(collector *MetricsCollector) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			recorder := &metricsResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
			next.ServeHTTP(recorder, r)
			elapsed := time.Since(began)
			collector.RecordRequest(elapsed, recorder.statusCode >= 400)
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
// metricsResponseWriter wraps an http.ResponseWriter to capture the status
// code actually sent to the client.
type metricsResponseWriter struct {
	http.ResponseWriter
	// statusCode is the final status. The creator seeds it with the implicit
	// default (200).
	statusCode int
	// wroteHeader is set once a final status has been committed, either by
	// WriteHeader with a code of 200 or more, or implicitly by Write.
	wroteHeader bool
}

// WriteHeader records code if no final status has been committed yet, then
// forwards the call unchanged. Later calls are superfluous for net/http and do
// not change what the client receives, so they are not recorded. Informational
// 1xx codes are forwarded but do not commit the status.
func (rw *metricsResponseWriter) WriteHeader(code int) {
	if !rw.wroteHeader && code >= 200 {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write commits the current (default) status, because writing a body without
// WriteHeader implies 200, and forwards b to the wrapped writer.
func (rw *metricsResponseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional interfaces (Flusher, Hijacker, deadlines) it implements.
func (rw *metricsResponseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// GetDBMonitorFromContext returns the DBMonitor stored under dbMonitorKey. The
// boolean is false, and the monitor nil, when no non-nil DBMonitor is present.
func GetDBMonitorFromContext(ctx context.Context) (DBMonitor, bool) {
	if found, ok := ctx.Value(dbMonitorKey).(DBMonitor); ok {
		return found, true
	}
	return nil, false
}
// GetSlowQueryThresholdFromContext returns the time.Duration stored under
// slowQueryThresholdKey. It returns (0, false) when the key is absent or holds
// a value of another type.
func GetSlowQueryThresholdFromContext(ctx context.Context) (time.Duration, bool) {
	limit, ok := ctx.Value(slowQueryThresholdKey).(time.Duration)
	if !ok {
		return 0, false
	}
	return limit, true
}