package middleware import ( "bytes" "context" "log" "net/http" "net/http/httptest" "strings" "testing" "time" ) func TestNewSecurityLogger(t *testing.T) { logger := NewSecurityLogger() if logger == nil { t.Fatal("NewSecurityLogger should not return nil") } if logger.logger == nil { t.Fatal("SecurityLogger should have a logger instance") } } func TestSecurityLogger_LogSecurityEvent(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } event := SecurityEvent{ Type: "Test Event", IP: "192.168.1.1", UserAgent: "Test Agent", Path: "/test", Method: "GET", UserID: 123, Details: "Test details", Timestamp: time.Now(), Severity: "INFO", } logger.LogSecurityEvent(event) logOutput := buf.String() expectedParts := []string{ "[INFO]", "192.168.1.1", "GET", "/test", "Test Agent", "UserID: 123", "Test Event", "Test details", } for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } } func TestSecurityLoggingMiddleware_ClientError(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } handler := SecurityLoggingMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) })) request := httptest.NewRequest("GET", "/test", nil) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput := buf.String() expectedParts := []string{ "[WARN]", "Client Error", "Client error response", } for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } } func TestSecurityLoggingMiddleware_ServerError(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } handler := SecurityLoggingMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) request := httptest.NewRequest("GET", "/test", nil) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput := buf.String() expectedParts := []string{ "[ERROR]", "Server Error", "Server error response", } for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } } func TestSecurityLoggingMiddleware_Authentication(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } handler := SecurityLoggingMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) request := httptest.NewRequest("POST", "/api/auth/login", nil) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput := buf.String() expectedParts := []string{ "[INFO]", "Authentication", "Authentication endpoint accessed", } for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } } func TestSecurityLoggingMiddleware_PostCreation(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } handler := SecurityLoggingMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) request := httptest.NewRequest("POST", "/api/posts/", nil) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput := buf.String() expectedParts := []string{ "[INFO]", "Post Creation", "Post creation attempt", } for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } } func TestSecurityLoggingMiddleware_PostModification(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } handler := SecurityLoggingMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) request := httptest.NewRequest("PUT", "/api/posts/1", nil) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput := buf.String() expectedParts := []string{ "[INFO]", "Post Modification", "Post modification attempt", } for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } buf.Reset() request = httptest.NewRequest("DELETE", "/api/posts/1", nil) recorder = httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput = buf.String() for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } } func TestSecurityLoggingMiddleware_APIAccess(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } handler := SecurityLoggingMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) request := httptest.NewRequest("GET", "/api/users", nil) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput := buf.String() expectedParts := []string{ "[INFO]", "API Access", "API endpoint accessed", } for _, part := range expectedParts { if !strings.Contains(logOutput, part) { t.Errorf("Expected log output to contain %q, got %q", part, logOutput) } } } func TestSecurityLoggingMiddleware_WithUserID(t *testing.T) { var buf bytes.Buffer logger := &SecurityLogger{ logger: log.New(&buf, "[SECURITY] ", log.LstdFlags|log.Lshortfile), } handler := SecurityLoggingMiddleware(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) request := httptest.NewRequest("GET", "/test", nil) request = request.WithContext(context.WithValue(request.Context(), UserIDKey, uint(456))) recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, request) logOutput := buf.String() if !strings.Contains(logOutput, "UserID: 456") { t.Errorf("Expected log output to contain UserID: 456, got %q", logOutput) } } func TestGetClientIP(t *testing.T) { originalTrust := TrustProxyHeaders defer func() { TrustProxyHeaders = originalTrust }() tests := []struct { name string headers map[string]string remoteAddr string trustProxyHeaders bool expectedIP string }{ { name: "Default: RemoteAddr when TrustProxyHeaders is false", headers: map[string]string{"X-Forwarded-For": "192.168.1.100"}, remoteAddr: "10.0.0.1:8080", trustProxyHeaders: false, expectedIP: "10.0.0.1", }, { name: "X-Forwarded-For single IP when TrustProxyHeaders is true", headers: map[string]string{ "X-Forwarded-For": "192.168.1.100", }, remoteAddr: "10.0.0.1:8080", trustProxyHeaders: true, expectedIP: "192.168.1.100", }, { name: "X-Forwarded-For multiple IPs when TrustProxyHeaders is true", headers: map[string]string{ "X-Forwarded-For": "192.168.1.100, 10.0.0.1, 172.16.0.1", }, remoteAddr: "10.0.0.1:8080", trustProxyHeaders: true, expectedIP: "192.168.1.100", }, { name: "X-Real-IP when TrustProxyHeaders is true", headers: map[string]string{ "X-Real-IP": "192.168.1.200", }, remoteAddr: "10.0.0.1:8080", trustProxyHeaders: true, expectedIP: "192.168.1.200", }, { name: "X-Forwarded-For takes precedence over X-Real-IP when TrustProxyHeaders is true", headers: map[string]string{ "X-Forwarded-For": "192.168.1.100", "X-Real-IP": "192.168.1.200", }, remoteAddr: "10.0.0.1:8080", trustProxyHeaders: true, expectedIP: "192.168.1.100", }, { name: "RemoteAddr only", headers: map[string]string{}, remoteAddr: "192.168.1.50:8080", trustProxyHeaders: false, expectedIP: "192.168.1.50", }, { name: "RemoteAddr with IPv6", headers: map[string]string{}, remoteAddr: "[::1]:8080", trustProxyHeaders: false, expectedIP: "::1", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { TrustProxyHeaders = tt.trustProxyHeaders request := httptest.NewRequest("GET", "/test", nil) request.RemoteAddr = tt.remoteAddr for header, value := range tt.headers { request.Header.Set(header, value) } ip := getClientIP(request) if ip != tt.expectedIP { t.Errorf("Expected IP %q, got %q", tt.expectedIP, ip) } }) } TrustProxyHeaders = originalTrust } func TestContainsSQLInjection(t *testing.T) { tests := []struct { input string expected bool }{ {"' OR '1'='1", true}, {"'; DROP TABLE users; --", true}, {"UNION SELECT * FROM users", true}, {"INSERT INTO users VALUES", true}, {"DELETE FROM users", true}, {"UPDATE SET", true}, {"normal query", false}, {"SELECT * FROM posts", false}, {"' OR '1'='1'", true}, {"union select", true}, {"", false}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { result := containsSQLInjection(tt.input) if result != tt.expected { t.Errorf("containsSQLInjection(%q) = %v, expected %v", tt.input, result, tt.expected) } }) } } func TestContainsXSS(t *testing.T) { tests := []struct { input string expected bool }{ {"", true}, {"javascript:alert('xss')", true}, {"onload=alert('xss')", true}, {"onerror=alert('xss')", true}, {"onclick=alert('xss')", true}, {"