423 lines
12 KiB
Go
423 lines
12 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
|
|
|
|
func TestInMemoryDBMonitor(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
|
|
stats := monitor.GetStats()
|
|
if stats.TotalQueries != 0 {
|
|
t.Errorf("Expected 0 total queries, got %d", stats.TotalQueries)
|
|
}
|
|
|
|
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
|
stats = monitor.GetStats()
|
|
if stats.TotalQueries != 1 {
|
|
t.Errorf("Expected 1 total query, got %d", stats.TotalQueries)
|
|
}
|
|
if stats.AverageDuration != 50*time.Millisecond {
|
|
t.Errorf("Expected average duration 50ms, got %v", stats.AverageDuration)
|
|
}
|
|
if stats.MaxDuration != 50*time.Millisecond {
|
|
t.Errorf("Expected max duration 50ms, got %v", stats.MaxDuration)
|
|
}
|
|
|
|
monitor.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
|
stats = monitor.GetStats()
|
|
if stats.TotalQueries != 2 {
|
|
t.Errorf("Expected 2 total queries, got %d", stats.TotalQueries)
|
|
}
|
|
if stats.SlowQueries != 1 {
|
|
t.Errorf("Expected 1 slow query, got %d", stats.SlowQueries)
|
|
}
|
|
|
|
monitor.LogQuery("SELECT * FROM invalid", 10*time.Millisecond, sql.ErrNoRows)
|
|
stats = monitor.GetStats()
|
|
if stats.TotalQueries != 3 {
|
|
t.Errorf("Expected 3 total queries, got %d", stats.TotalQueries)
|
|
}
|
|
if stats.ErrorCount != 1 {
|
|
t.Errorf("Expected 1 error, got %d", stats.ErrorCount)
|
|
}
|
|
|
|
expectedAvg := time.Duration((int64(50*time.Millisecond) + int64(150*time.Millisecond)) / 2)
|
|
if stats.AverageDuration != expectedAvg {
|
|
t.Errorf("Expected average duration %v, got %v", expectedAvg, stats.AverageDuration)
|
|
}
|
|
}
|
|
|
|
func TestQueryLogger(t *testing.T) {
|
|
|
|
db, err := sql.Open("sqlite3", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
_, err = db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test table: %v", err)
|
|
}
|
|
|
|
monitor := NewInMemoryDBMonitor()
|
|
logger := NewQueryLogger(db, monitor)
|
|
|
|
ctx := context.Background()
|
|
rows, err := logger.QueryContext(ctx, "SELECT * FROM users")
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got %v", err)
|
|
}
|
|
if rows == nil {
|
|
t.Fatal("Expected rows, got nil")
|
|
}
|
|
rows.Close()
|
|
|
|
stats := monitor.GetStats()
|
|
if stats.TotalQueries != 1 {
|
|
t.Errorf("Expected 1 total query, got %d", stats.TotalQueries)
|
|
}
|
|
|
|
row := logger.QueryRowContext(ctx, "SELECT * FROM users WHERE id = ?", 1)
|
|
if row == nil {
|
|
t.Fatal("Expected row, got nil")
|
|
}
|
|
|
|
stats = monitor.GetStats()
|
|
if stats.TotalQueries != 2 {
|
|
t.Errorf("Expected 2 total queries, got %d", stats.TotalQueries)
|
|
}
|
|
|
|
_, err = logger.ExecContext(ctx, "INSERT INTO users (name) VALUES (?)", "test")
|
|
|
|
if err == nil {
|
|
t.Fatal("Expected error for INSERT into non-existent table")
|
|
}
|
|
|
|
stats = monitor.GetStats()
|
|
if stats.TotalQueries != 3 {
|
|
t.Errorf("Expected 3 total queries, got %d", stats.TotalQueries)
|
|
}
|
|
if stats.ErrorCount != 1 {
|
|
t.Errorf("Expected 1 error, got %d", stats.ErrorCount)
|
|
}
|
|
}
|
|
|
|
func TestDatabaseHealthChecker(t *testing.T) {
|
|
|
|
db, err := sql.Open("sqlite3", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
monitor := NewInMemoryDBMonitor()
|
|
checker := NewDatabaseHealthChecker(db, monitor)
|
|
|
|
health := checker.CheckHealth()
|
|
if health["status"] != "healthy" {
|
|
t.Errorf("Expected healthy status, got %v", health["status"])
|
|
}
|
|
if health["ping_time"] == nil {
|
|
t.Error("Expected ping_time to be present")
|
|
}
|
|
|
|
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
|
monitor.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
|
|
|
health = checker.CheckHealth()
|
|
if health["database_stats"] == nil {
|
|
t.Error("Expected database_stats to be present")
|
|
}
|
|
|
|
stats, ok := health["database_stats"].(map[string]any)
|
|
if !ok {
|
|
t.Fatal("Expected database_stats to be a map")
|
|
}
|
|
|
|
if stats["total_queries"] != int64(2) {
|
|
t.Errorf("Expected 2 total queries, got %v", stats["total_queries"])
|
|
}
|
|
if stats["slow_queries"] != int64(1) {
|
|
t.Errorf("Expected 1 slow query, got %v", stats["slow_queries"])
|
|
}
|
|
}
|
|
|
|
func TestMetricsCollector(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
collector := NewMetricsCollector(monitor)
|
|
|
|
metrics := collector.GetMetrics()
|
|
if metrics.RequestCount != 0 {
|
|
t.Errorf("Expected 0 requests, got %d", metrics.RequestCount)
|
|
}
|
|
|
|
collector.RecordRequest(100*time.Millisecond, false)
|
|
collector.RecordRequest(200*time.Millisecond, false)
|
|
collector.RecordRequest(50*time.Millisecond, true)
|
|
|
|
metrics = collector.GetMetrics()
|
|
if metrics.RequestCount != 3 {
|
|
t.Errorf("Expected 3 requests, got %d", metrics.RequestCount)
|
|
}
|
|
if metrics.ErrorCount != 1 {
|
|
t.Errorf("Expected 1 error, got %d", metrics.ErrorCount)
|
|
}
|
|
if metrics.MaxResponse != 200*time.Millisecond {
|
|
t.Errorf("Expected max response 200ms, got %v", metrics.MaxResponse)
|
|
}
|
|
|
|
expectedAvg := time.Duration((int64(100*time.Millisecond) + int64(200*time.Millisecond) + int64(50*time.Millisecond)) / 3)
|
|
if metrics.AverageResponse != expectedAvg {
|
|
t.Errorf("Expected average response %v, got %v", expectedAvg, metrics.AverageResponse)
|
|
}
|
|
}
|
|
|
|
func TestMetricsMiddleware(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
collector := NewMetricsCollector(monitor)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("test response"))
|
|
})
|
|
|
|
middleware := MetricsMiddleware(collector)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
metrics := collector.GetMetrics()
|
|
if metrics.RequestCount != 1 {
|
|
t.Errorf("Expected 1 request, got %d", metrics.RequestCount)
|
|
}
|
|
if metrics.ErrorCount != 0 {
|
|
t.Errorf("Expected 0 errors, got %d", metrics.ErrorCount)
|
|
}
|
|
|
|
errorHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
w.Write([]byte("error"))
|
|
})
|
|
|
|
errorMiddleware := MetricsMiddleware(collector)
|
|
errorWrappedHandler := errorMiddleware(errorHandler)
|
|
|
|
req = httptest.NewRequest("GET", "/error", nil)
|
|
w = httptest.NewRecorder()
|
|
errorWrappedHandler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status 500, got %d", w.Code)
|
|
}
|
|
|
|
metrics = collector.GetMetrics()
|
|
if metrics.RequestCount != 2 {
|
|
t.Errorf("Expected 2 requests, got %d", metrics.RequestCount)
|
|
}
|
|
if metrics.ErrorCount != 1 {
|
|
t.Errorf("Expected 1 error, got %d", metrics.ErrorCount)
|
|
}
|
|
}
|
|
|
|
func TestDBMonitoringMiddleware(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
threshold := 50 * time.Millisecond
|
|
|
|
var capturedCtx context.Context
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedCtx = r.Context()
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("test response"))
|
|
})
|
|
|
|
middleware := DBMonitoringMiddleware(monitor, threshold)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if capturedCtx == nil {
|
|
t.Fatal("Expected context to be captured")
|
|
}
|
|
if capturedCtx.Value(dbMonitorKey) == nil {
|
|
t.Error("Expected dbMonitorKey to be set in context")
|
|
}
|
|
if capturedCtx.Value(slowQueryThresholdKey) == nil {
|
|
t.Error("Expected slowQueryThresholdKey to be set in context")
|
|
}
|
|
|
|
actualThreshold := capturedCtx.Value(slowQueryThresholdKey).(time.Duration)
|
|
if actualThreshold != threshold {
|
|
t.Errorf("Expected threshold %v, got %v", threshold, actualThreshold)
|
|
}
|
|
}
|
|
|
|
func TestMetricsResponseWriter(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
writer := &metricsResponseWriter{
|
|
ResponseWriter: recorder,
|
|
statusCode: http.StatusOK,
|
|
}
|
|
|
|
writer.WriteHeader(http.StatusNotFound)
|
|
if writer.statusCode != http.StatusNotFound {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusNotFound, writer.statusCode)
|
|
}
|
|
|
|
if recorder.Code != http.StatusNotFound {
|
|
t.Errorf("Expected underlying writer to receive status %d, got %d", http.StatusNotFound, recorder.Code)
|
|
}
|
|
}
|
|
|
|
func TestSlowQueryThreshold(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
|
|
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
|
monitor.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
|
monitor.LogQuery("SELECT * FROM comments", 200*time.Millisecond, nil)
|
|
|
|
stats := monitor.GetStats()
|
|
if stats.SlowQueries != 2 {
|
|
t.Errorf("Expected 2 slow queries with default 100ms threshold, got %d", stats.SlowQueries)
|
|
}
|
|
|
|
monitor2 := NewInMemoryDBMonitor()
|
|
monitor2.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
|
monitor2.LogQuery("SELECT * FROM posts", 150*time.Millisecond, nil)
|
|
|
|
stats2 := monitor2.GetStats()
|
|
if stats2.SlowQueries != 1 {
|
|
t.Errorf("Expected 1 slow query with default 100ms threshold, got %d", stats2.SlowQueries)
|
|
}
|
|
}
|
|
|
|
func TestConcurrentAccess(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
collector := NewMetricsCollector(monitor)
|
|
|
|
done := make(chan bool, 10)
|
|
for i := 0; i < 10; i++ {
|
|
go func() {
|
|
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
|
collector.RecordRequest(100*time.Millisecond, false)
|
|
done <- true
|
|
}()
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
<-done
|
|
}
|
|
|
|
stats := monitor.GetStats()
|
|
if stats.TotalQueries != 10 {
|
|
t.Errorf("Expected 10 total queries, got %d", stats.TotalQueries)
|
|
}
|
|
|
|
metrics := collector.GetMetrics()
|
|
if metrics.RequestCount != 10 {
|
|
t.Errorf("Expected 10 requests, got %d", metrics.RequestCount)
|
|
}
|
|
}
|
|
|
|
func TestContextHelpers(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
threshold := 200 * time.Millisecond
|
|
|
|
ctx := context.Background()
|
|
ctx = context.WithValue(ctx, dbMonitorKey, monitor)
|
|
ctx = context.WithValue(ctx, slowQueryThresholdKey, threshold)
|
|
|
|
retrievedMonitor, ok := GetDBMonitorFromContext(ctx)
|
|
if !ok {
|
|
t.Error("Expected to retrieve monitor from context")
|
|
}
|
|
if retrievedMonitor != monitor {
|
|
t.Error("Expected retrieved monitor to match original")
|
|
}
|
|
|
|
retrievedThreshold, ok := GetSlowQueryThresholdFromContext(ctx)
|
|
if !ok {
|
|
t.Error("Expected to retrieve threshold from context")
|
|
}
|
|
if retrievedThreshold != threshold {
|
|
t.Errorf("Expected threshold %v, got %v", threshold, retrievedThreshold)
|
|
}
|
|
|
|
emptyCtx := context.Background()
|
|
_, ok = GetDBMonitorFromContext(emptyCtx)
|
|
if ok {
|
|
t.Error("Expected not to retrieve monitor from empty context")
|
|
}
|
|
|
|
_, ok = GetSlowQueryThresholdFromContext(emptyCtx)
|
|
if ok {
|
|
t.Error("Expected not to retrieve threshold from empty context")
|
|
}
|
|
}
|
|
|
|
func TestThreadSafety(t *testing.T) {
|
|
monitor := NewInMemoryDBMonitor()
|
|
collector := NewMetricsCollector(monitor)
|
|
|
|
numGoroutines := 100
|
|
done := make(chan bool, numGoroutines)
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
go func(id int) {
|
|
|
|
if id%2 == 0 {
|
|
monitor.LogQuery("SELECT * FROM users", time.Duration(id)*time.Millisecond, nil)
|
|
collector.RecordRequest(time.Duration(id)*time.Millisecond, false)
|
|
} else {
|
|
monitor.LogQuery("SELECT * FROM users", time.Duration(id)*time.Millisecond, sql.ErrNoRows)
|
|
collector.RecordRequest(time.Duration(id)*time.Millisecond, true)
|
|
}
|
|
done <- true
|
|
}(i)
|
|
}
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
<-done
|
|
}
|
|
|
|
stats := monitor.GetStats()
|
|
if stats.TotalQueries != int64(numGoroutines) {
|
|
t.Errorf("Expected %d total queries, got %d", numGoroutines, stats.TotalQueries)
|
|
}
|
|
|
|
metrics := collector.GetMetrics()
|
|
if metrics.RequestCount != int64(numGoroutines) {
|
|
t.Errorf("Expected %d requests, got %d", numGoroutines, metrics.RequestCount)
|
|
}
|
|
|
|
expectedErrors := int64(numGoroutines / 2)
|
|
if stats.ErrorCount != expectedErrors {
|
|
t.Errorf("Expected %d errors, got %d", expectedErrors, stats.ErrorCount)
|
|
}
|
|
if metrics.ErrorCount != expectedErrors {
|
|
t.Errorf("Expected %d request errors, got %d", expectedErrors, metrics.ErrorCount)
|
|
}
|
|
}
|