To gitea and beyond, let's go(-yco)
This commit is contained in:
422
internal/middleware/db_monitoring_test.go
Normal file
422
internal/middleware/db_monitoring_test.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestInMemoryDBMonitor(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if stats.TotalQueries != 0 {
|
||||
t.Errorf("Expected 0 total queries, got %d", stats.TotalQueries)
|
||||
}
|
||||
|
||||
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
||||
stats = monitor.GetStats()
|
||||
if stats.TotalQueries != 1 {
|
||||
t.Errorf("Expected 1 total query, got %d", stats.TotalQueries)
|
||||
}
|
||||
if stats.AverageDuration != 50*time.Millisecond {
|
||||
t.Errorf("Expected average duration 50ms, got %v", stats.AverageDuration)
|
||||
}
|
||||
if stats.MaxDuration != 50*time.Millisecond {
|
||||
t.Errorf("Expected max duration 50ms, got %v", stats.MaxDuration)
|
||||
}
|
||||
|
||||
monitor.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
||||
stats = monitor.GetStats()
|
||||
if stats.TotalQueries != 2 {
|
||||
t.Errorf("Expected 2 total queries, got %d", stats.TotalQueries)
|
||||
}
|
||||
if stats.SlowQueries != 1 {
|
||||
t.Errorf("Expected 1 slow query, got %d", stats.SlowQueries)
|
||||
}
|
||||
|
||||
monitor.LogQuery("SELECT * FROM invalid", 10*time.Millisecond, sql.ErrNoRows)
|
||||
stats = monitor.GetStats()
|
||||
if stats.TotalQueries != 3 {
|
||||
t.Errorf("Expected 3 total queries, got %d", stats.TotalQueries)
|
||||
}
|
||||
if stats.ErrorCount != 1 {
|
||||
t.Errorf("Expected 1 error, got %d", stats.ErrorCount)
|
||||
}
|
||||
|
||||
expectedAvg := time.Duration((int64(50*time.Millisecond) + int64(150*time.Millisecond)) / 2)
|
||||
if stats.AverageDuration != expectedAvg {
|
||||
t.Errorf("Expected average duration %v, got %v", expectedAvg, stats.AverageDuration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryLogger(t *testing.T) {
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
logger := NewQueryLogger(db, monitor)
|
||||
|
||||
ctx := context.Background()
|
||||
rows, err := logger.QueryContext(ctx, "SELECT * FROM users")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if rows == nil {
|
||||
t.Fatal("Expected rows, got nil")
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if stats.TotalQueries != 1 {
|
||||
t.Errorf("Expected 1 total query, got %d", stats.TotalQueries)
|
||||
}
|
||||
|
||||
row := logger.QueryRowContext(ctx, "SELECT * FROM users WHERE id = ?", 1)
|
||||
if row == nil {
|
||||
t.Fatal("Expected row, got nil")
|
||||
}
|
||||
|
||||
stats = monitor.GetStats()
|
||||
if stats.TotalQueries != 2 {
|
||||
t.Errorf("Expected 2 total queries, got %d", stats.TotalQueries)
|
||||
}
|
||||
|
||||
_, err = logger.ExecContext(ctx, "INSERT INTO users (name) VALUES (?)", "test")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for INSERT into non-existent table")
|
||||
}
|
||||
|
||||
stats = monitor.GetStats()
|
||||
if stats.TotalQueries != 3 {
|
||||
t.Errorf("Expected 3 total queries, got %d", stats.TotalQueries)
|
||||
}
|
||||
if stats.ErrorCount != 1 {
|
||||
t.Errorf("Expected 1 error, got %d", stats.ErrorCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseHealthChecker(t *testing.T) {
|
||||
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
checker := NewDatabaseHealthChecker(db, monitor)
|
||||
|
||||
health := checker.CheckHealth()
|
||||
if health["status"] != "healthy" {
|
||||
t.Errorf("Expected healthy status, got %v", health["status"])
|
||||
}
|
||||
if health["ping_time"] == nil {
|
||||
t.Error("Expected ping_time to be present")
|
||||
}
|
||||
|
||||
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
||||
monitor.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
||||
|
||||
health = checker.CheckHealth()
|
||||
if health["database_stats"] == nil {
|
||||
t.Error("Expected database_stats to be present")
|
||||
}
|
||||
|
||||
stats, ok := health["database_stats"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected database_stats to be a map")
|
||||
}
|
||||
|
||||
if stats["total_queries"] != int64(2) {
|
||||
t.Errorf("Expected 2 total queries, got %v", stats["total_queries"])
|
||||
}
|
||||
if stats["slow_queries"] != int64(1) {
|
||||
t.Errorf("Expected 1 slow query, got %v", stats["slow_queries"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsCollector(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
collector := NewMetricsCollector(monitor)
|
||||
|
||||
metrics := collector.GetMetrics()
|
||||
if metrics.RequestCount != 0 {
|
||||
t.Errorf("Expected 0 requests, got %d", metrics.RequestCount)
|
||||
}
|
||||
|
||||
collector.RecordRequest(100*time.Millisecond, false)
|
||||
collector.RecordRequest(200*time.Millisecond, false)
|
||||
collector.RecordRequest(50*time.Millisecond, true)
|
||||
|
||||
metrics = collector.GetMetrics()
|
||||
if metrics.RequestCount != 3 {
|
||||
t.Errorf("Expected 3 requests, got %d", metrics.RequestCount)
|
||||
}
|
||||
if metrics.ErrorCount != 1 {
|
||||
t.Errorf("Expected 1 error, got %d", metrics.ErrorCount)
|
||||
}
|
||||
if metrics.MaxResponse != 200*time.Millisecond {
|
||||
t.Errorf("Expected max response 200ms, got %v", metrics.MaxResponse)
|
||||
}
|
||||
|
||||
expectedAvg := time.Duration((int64(100*time.Millisecond) + int64(200*time.Millisecond) + int64(50*time.Millisecond)) / 3)
|
||||
if metrics.AverageResponse != expectedAvg {
|
||||
t.Errorf("Expected average response %v, got %v", expectedAvg, metrics.AverageResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsMiddleware(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
collector := NewMetricsCollector(monitor)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
middleware := MetricsMiddleware(collector)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
metrics := collector.GetMetrics()
|
||||
if metrics.RequestCount != 1 {
|
||||
t.Errorf("Expected 1 request, got %d", metrics.RequestCount)
|
||||
}
|
||||
if metrics.ErrorCount != 0 {
|
||||
t.Errorf("Expected 0 errors, got %d", metrics.ErrorCount)
|
||||
}
|
||||
|
||||
errorHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("error"))
|
||||
})
|
||||
|
||||
errorMiddleware := MetricsMiddleware(collector)
|
||||
errorWrappedHandler := errorMiddleware(errorHandler)
|
||||
|
||||
req = httptest.NewRequest("GET", "/error", nil)
|
||||
w = httptest.NewRecorder()
|
||||
errorWrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
metrics = collector.GetMetrics()
|
||||
if metrics.RequestCount != 2 {
|
||||
t.Errorf("Expected 2 requests, got %d", metrics.RequestCount)
|
||||
}
|
||||
if metrics.ErrorCount != 1 {
|
||||
t.Errorf("Expected 1 error, got %d", metrics.ErrorCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBMonitoringMiddleware(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
threshold := 50 * time.Millisecond
|
||||
|
||||
var capturedCtx context.Context
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCtx = r.Context()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
middleware := DBMonitoringMiddleware(monitor, threshold)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if capturedCtx == nil {
|
||||
t.Fatal("Expected context to be captured")
|
||||
}
|
||||
if capturedCtx.Value(dbMonitorKey) == nil {
|
||||
t.Error("Expected dbMonitorKey to be set in context")
|
||||
}
|
||||
if capturedCtx.Value(slowQueryThresholdKey) == nil {
|
||||
t.Error("Expected slowQueryThresholdKey to be set in context")
|
||||
}
|
||||
|
||||
actualThreshold := capturedCtx.Value(slowQueryThresholdKey).(time.Duration)
|
||||
if actualThreshold != threshold {
|
||||
t.Errorf("Expected threshold %v, got %v", threshold, actualThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsResponseWriter(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
writer := &metricsResponseWriter{
|
||||
ResponseWriter: recorder,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
writer.WriteHeader(http.StatusNotFound)
|
||||
if writer.statusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusNotFound, writer.statusCode)
|
||||
}
|
||||
|
||||
if recorder.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected underlying writer to receive status %d, got %d", http.StatusNotFound, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlowQueryThreshold(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
|
||||
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
||||
monitor.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
||||
monitor.LogQuery("SELECT * FROM comments", 200*time.Millisecond, nil)
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if stats.SlowQueries != 2 {
|
||||
t.Errorf("Expected 2 slow queries with default 100ms threshold, got %d", stats.SlowQueries)
|
||||
}
|
||||
|
||||
monitor2 := NewInMemoryDBMonitor()
|
||||
monitor2.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
||||
monitor2.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
||||
|
||||
stats2 := monitor2.GetStats()
|
||||
if stats2.SlowQueries != 1 {
|
||||
t.Errorf("Expected 1 slow query with default 100ms threshold, got %d", stats2.SlowQueries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
collector := NewMetricsCollector(monitor)
|
||||
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
||||
collector.RecordRequest(100*time.Millisecond, false)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if stats.TotalQueries != 10 {
|
||||
t.Errorf("Expected 10 total queries, got %d", stats.TotalQueries)
|
||||
}
|
||||
|
||||
metrics := collector.GetMetrics()
|
||||
if metrics.RequestCount != 10 {
|
||||
t.Errorf("Expected 10 requests, got %d", metrics.RequestCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextHelpers(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
threshold := 200 * time.Millisecond
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, dbMonitorKey, monitor)
|
||||
ctx = context.WithValue(ctx, slowQueryThresholdKey, threshold)
|
||||
|
||||
retrievedMonitor, ok := GetDBMonitorFromContext(ctx)
|
||||
if !ok {
|
||||
t.Error("Expected to retrieve monitor from context")
|
||||
}
|
||||
if retrievedMonitor != monitor {
|
||||
t.Error("Expected retrieved monitor to match original")
|
||||
}
|
||||
|
||||
retrievedThreshold, ok := GetSlowQueryThresholdFromContext(ctx)
|
||||
if !ok {
|
||||
t.Error("Expected to retrieve threshold from context")
|
||||
}
|
||||
if retrievedThreshold != threshold {
|
||||
t.Errorf("Expected threshold %v, got %v", threshold, retrievedThreshold)
|
||||
}
|
||||
|
||||
emptyCtx := context.Background()
|
||||
_, ok = GetDBMonitorFromContext(emptyCtx)
|
||||
if ok {
|
||||
t.Error("Expected not to retrieve monitor from empty context")
|
||||
}
|
||||
|
||||
_, ok = GetSlowQueryThresholdFromContext(emptyCtx)
|
||||
if ok {
|
||||
t.Error("Expected not to retrieve threshold from empty context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThreadSafety(t *testing.T) {
|
||||
monitor := NewInMemoryDBMonitor()
|
||||
collector := NewMetricsCollector(monitor)
|
||||
|
||||
numGoroutines := 100
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
|
||||
if id%2 == 0 {
|
||||
monitor.LogQuery("SELECT * FROM users", time.Duration(id)*time.Millisecond, nil)
|
||||
collector.RecordRequest(time.Duration(id)*time.Millisecond, false)
|
||||
} else {
|
||||
monitor.LogQuery("SELECT * FROM users", time.Duration(id)*time.Millisecond, sql.ErrNoRows)
|
||||
collector.RecordRequest(time.Duration(id)*time.Millisecond, true)
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if stats.TotalQueries != int64(numGoroutines) {
|
||||
t.Errorf("Expected %d total queries, got %d", numGoroutines, stats.TotalQueries)
|
||||
}
|
||||
|
||||
metrics := collector.GetMetrics()
|
||||
if metrics.RequestCount != int64(numGoroutines) {
|
||||
t.Errorf("Expected %d requests, got %d", numGoroutines, metrics.RequestCount)
|
||||
}
|
||||
|
||||
expectedErrors := int64(numGoroutines / 2)
|
||||
if stats.ErrorCount != expectedErrors {
|
||||
t.Errorf("Expected %d errors, got %d", expectedErrors, stats.ErrorCount)
|
||||
}
|
||||
if metrics.ErrorCount != expectedErrors {
|
||||
t.Errorf("Expected %d request errors, got %d", expectedErrors, metrics.ErrorCount)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user