package middleware
import (
"bytes"
"context"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestNewSecurityLogger checks that the constructor returns a non-nil
// SecurityLogger whose internal logger field has been populated.
func TestNewSecurityLogger(t *testing.T) {
	sl := NewSecurityLogger()
	switch {
	case sl == nil:
		t.Fatal("NewSecurityLogger should not return nil")
	case sl.logger == nil:
		t.Fatal("SecurityLogger should have a logger instance")
	}
}
// TestSecurityLogger_LogSecurityEvent logs a fully populated SecurityEvent
// and verifies that the severity tag, IP, method, path, user agent, user ID,
// event type and details all appear in the captured log output.
func TestSecurityLogger_LogSecurityEvent(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{
		logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile),
	}

	sl.LogSecurityEvent(SecurityEvent{
		Type:      "Test Event",
		IP:        "192.168.1.1",
		UserAgent: "Test Agent",
		Path:      "/test",
		Method:    "GET",
		UserID:    123,
		Details:   "Test details",
		Timestamp: time.Now(),
		Severity:  "INFO",
	})

	got := out.String()
	wants := [...]string{
		"[INFO]",
		"192.168.1.1",
		"GET",
		"/test",
		"Test Agent",
		"UserID: 123",
		"Test Event",
		"Test details",
	}
	for _, want := range wants {
		if strings.Contains(got, want) {
			continue
		}
		t.Errorf("Expected log output to contain %q, got %q", want, got)
	}
}
// TestSecurityLoggingMiddleware_ClientError verifies that when the wrapped
// handler answers 400 Bad Request, the middleware writes a "[WARN]"
// "Client Error" entry describing a client error response.
func TestSecurityLoggingMiddleware_ClientError(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile)}

	respondBadRequest := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	}
	wrapped := SecurityLoggingMiddleware(sl)(http.HandlerFunc(respondBadRequest))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	wrapped.ServeHTTP(httptest.NewRecorder(), req)

	logged := out.String()
	for _, want := range []string{"[WARN]", "Client Error", "Client error response"} {
		if !strings.Contains(logged, want) {
			t.Errorf("Expected log output to contain %q, got %q", want, logged)
		}
	}
}
// TestSecurityLoggingMiddleware_ServerError verifies that when the wrapped
// handler answers 500 Internal Server Error, the middleware writes an
// "[ERROR]" "Server Error" entry describing a server error response.
func TestSecurityLoggingMiddleware_ServerError(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile)}

	respondInternalError := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}
	wrapped := SecurityLoggingMiddleware(sl)(http.HandlerFunc(respondInternalError))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	wrapped.ServeHTTP(httptest.NewRecorder(), req)

	logged := out.String()
	for _, want := range []string{"[ERROR]", "Server Error", "Server error response"} {
		if !strings.Contains(logged, want) {
			t.Errorf("Expected log output to contain %q, got %q", want, logged)
		}
	}
}
// TestSecurityLoggingMiddleware_Authentication verifies that a successful
// POST to /api/auth/login is recorded as an "[INFO]" "Authentication" entry
// noting that the authentication endpoint was accessed.
func TestSecurityLoggingMiddleware_Authentication(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile)}

	respondOK := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	wrapped := SecurityLoggingMiddleware(sl)(http.HandlerFunc(respondOK))

	req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
	wrapped.ServeHTTP(httptest.NewRecorder(), req)

	logged := out.String()
	for _, want := range []string{"[INFO]", "Authentication", "Authentication endpoint accessed"} {
		if !strings.Contains(logged, want) {
			t.Errorf("Expected log output to contain %q, got %q", want, logged)
		}
	}
}
// TestSecurityLoggingMiddleware_PostCreation verifies that a POST to the
// /api/posts/ collection is recorded as an "[INFO]" "Post Creation" entry
// describing a post creation attempt.
func TestSecurityLoggingMiddleware_PostCreation(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile)}

	respondOK := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	wrapped := SecurityLoggingMiddleware(sl)(http.HandlerFunc(respondOK))

	req := httptest.NewRequest(http.MethodPost, "/api/posts/", nil)
	wrapped.ServeHTTP(httptest.NewRecorder(), req)

	logged := out.String()
	for _, want := range []string{"[INFO]", "Post Creation", "Post creation attempt"} {
		if !strings.Contains(logged, want) {
			t.Errorf("Expected log output to contain %q, got %q", want, logged)
		}
	}
}
// TestSecurityLoggingMiddleware_PostModification verifies that both PUT and
// DELETE requests against an individual post (/api/posts/1) are recorded as
// an "[INFO]" "Post Modification" entry describing a modification attempt.
func TestSecurityLoggingMiddleware_PostModification(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile)}
	wrapped := SecurityLoggingMiddleware(sl)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	wants := []string{"[INFO]", "Post Modification", "Post modification attempt"}

	for _, method := range []string{http.MethodPut, http.MethodDelete} {
		// Clear the buffer before each request so one method's log line
		// cannot satisfy the assertions for the other.
		out.Reset()
		req := httptest.NewRequest(method, "/api/posts/1", nil)
		wrapped.ServeHTTP(httptest.NewRecorder(), req)

		logged := out.String()
		for _, want := range wants {
			if !strings.Contains(logged, want) {
				t.Errorf("Expected log output to contain %q, got %q", want, logged)
			}
		}
	}
}
// TestSecurityLoggingMiddleware_APIAccess verifies that a GET to /api/users
// is recorded as an "[INFO]" "API Access" entry noting that an API endpoint
// was accessed.
func TestSecurityLoggingMiddleware_APIAccess(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile)}

	respondOK := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	wrapped := SecurityLoggingMiddleware(sl)(http.HandlerFunc(respondOK))

	req := httptest.NewRequest(http.MethodGet, "/api/users", nil)
	wrapped.ServeHTTP(httptest.NewRecorder(), req)

	logged := out.String()
	for _, want := range []string{"[INFO]", "API Access", "API endpoint accessed"} {
		if !strings.Contains(logged, want) {
			t.Errorf("Expected log output to contain %q, got %q", want, logged)
		}
	}
}
// TestSecurityLoggingMiddleware_WithUserID verifies that a user ID stored in
// the request context under UserIDKey (as a uint) is included in the log
// entry produced by the middleware.
func TestSecurityLoggingMiddleware_WithUserID(t *testing.T) {
	var out bytes.Buffer
	sl := &SecurityLogger{logger: log.New(&out, "[SECURITY] ", log.LstdFlags|log.Lshortfile)}

	respondOK := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	wrapped := SecurityLoggingMiddleware(sl)(http.HandlerFunc(respondOK))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	authed := req.WithContext(context.WithValue(req.Context(), UserIDKey, uint(456)))
	wrapped.ServeHTTP(httptest.NewRecorder(), authed)

	if logged := out.String(); !strings.Contains(logged, "UserID: 456") {
		t.Errorf("Expected log output to contain UserID: 456, got %q", logged)
	}
}
// TestGetClientIP covers how getClientIP chooses between the connection's
// RemoteAddr and the X-Forwarded-For / X-Real-IP proxy headers depending on
// the package-level TrustProxyHeaders flag, including IPv6 RemoteAddr values.
func TestGetClientIP(t *testing.T) {
	// TrustProxyHeaders is shared package state. t.Cleanup (unlike a defer
	// plus a manual restore) runs only after every subtest has finished, so
	// the original value is restored exactly once and stays correct even if
	// subtests are later marked parallel.
	originalTrust := TrustProxyHeaders
	t.Cleanup(func() {
		TrustProxyHeaders = originalTrust
	})

	tests := []struct {
		name              string
		headers           map[string]string
		remoteAddr        string
		trustProxyHeaders bool
		expectedIP        string
	}{
		{
			name:              "Default: RemoteAddr when TrustProxyHeaders is false",
			headers:           map[string]string{"X-Forwarded-For": "192.168.1.100"},
			remoteAddr:        "10.0.0.1:8080",
			trustProxyHeaders: false,
			expectedIP:        "10.0.0.1",
		},
		{
			// A spoofed X-Real-IP must not override RemoteAddr when proxy
			// headers are untrusted, just like X-Forwarded-For above.
			name:              "X-Real-IP ignored when TrustProxyHeaders is false",
			headers:           map[string]string{"X-Real-IP": "192.168.1.200"},
			remoteAddr:        "10.0.0.1:8080",
			trustProxyHeaders: false,
			expectedIP:        "10.0.0.1",
		},
		{
			name: "X-Forwarded-For single IP when TrustProxyHeaders is true",
			headers: map[string]string{
				"X-Forwarded-For": "192.168.1.100",
			},
			remoteAddr:        "10.0.0.1:8080",
			trustProxyHeaders: true,
			expectedIP:        "192.168.1.100",
		},
		{
			name: "X-Forwarded-For multiple IPs when TrustProxyHeaders is true",
			headers: map[string]string{
				"X-Forwarded-For": "192.168.1.100, 10.0.0.1, 172.16.0.1",
			},
			remoteAddr:        "10.0.0.1:8080",
			trustProxyHeaders: true,
			expectedIP:        "192.168.1.100",
		},
		{
			name: "X-Real-IP when TrustProxyHeaders is true",
			headers: map[string]string{
				"X-Real-IP": "192.168.1.200",
			},
			remoteAddr:        "10.0.0.1:8080",
			trustProxyHeaders: true,
			expectedIP:        "192.168.1.200",
		},
		{
			name: "X-Forwarded-For takes precedence over X-Real-IP when TrustProxyHeaders is true",
			headers: map[string]string{
				"X-Forwarded-For": "192.168.1.100",
				"X-Real-IP":       "192.168.1.200",
			},
			remoteAddr:        "10.0.0.1:8080",
			trustProxyHeaders: true,
			expectedIP:        "192.168.1.100",
		},
		{
			name:              "RemoteAddr only",
			headers:           map[string]string{},
			remoteAddr:        "192.168.1.50:8080",
			trustProxyHeaders: false,
			expectedIP:        "192.168.1.50",
		},
		{
			name:              "RemoteAddr with IPv6",
			headers:           map[string]string{},
			remoteAddr:        "[::1]:8080",
			trustProxyHeaders: false,
			expectedIP:        "::1",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Subtests write the shared global, so they must stay sequential.
			TrustProxyHeaders = tt.trustProxyHeaders

			request := httptest.NewRequest(http.MethodGet, "/test", nil)
			request.RemoteAddr = tt.remoteAddr
			for header, value := range tt.headers {
				request.Header.Set(header, value)
			}

			if ip := getClientIP(request); ip != tt.expectedIP {
				t.Errorf("Expected IP %q, got %q", tt.expectedIP, ip)
			}
		})
	}
}
// TestContainsSQLInjection checks containsSQLInjection against classic
// injection payloads, which must be flagged (including lower-case
// "union select"). It also checks benign inputs, which must not be flagged:
// a plain phrase, a simple SELECT and the empty string.
func TestContainsSQLInjection(t *testing.T) {
	cases := []struct {
		payload string
		flagged bool
	}{
		{payload: "' OR '1'='1", flagged: true},
		{payload: "'; DROP TABLE users; --", flagged: true},
		{payload: "UNION SELECT * FROM users", flagged: true},
		{payload: "INSERT INTO users VALUES", flagged: true},
		{payload: "DELETE FROM users", flagged: true},
		{payload: "UPDATE SET", flagged: true},
		{payload: "normal query", flagged: false},
		{payload: "SELECT * FROM posts", flagged: false},
		{payload: "' OR '1'='1'", flagged: true},
		{payload: "union select", flagged: true},
		{payload: "", flagged: false},
	}

	for _, c := range cases {
		t.Run(c.payload, func(t *testing.T) {
			if got := containsSQLInjection(c.payload); got != c.flagged {
				t.Errorf("containsSQLInjection(%q) = %v, expected %v", c.payload, got, c.flagged)
			}
		})
	}
}
func TestContainsXSS(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"", true},
{"javascript:alert('xss')", true},
{"onload=alert('xss')", true},
{"onerror=alert('xss')", true},
{"onclick=alert('xss')", true},
{"