369 lines
9.5 KiB
Go
369 lines
9.5 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
func TestSecureLogger_MaskSensitiveData(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Info,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
production bool
|
|
input string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "development_mode_no_masking",
|
|
production: false,
|
|
input: "SELECT * FROM users WHERE email = 'user@example.com'",
|
|
expected: "SELECT * FROM users WHERE email = 'user@example.com'",
|
|
},
|
|
{
|
|
name: "production_mode_mask_email",
|
|
production: true,
|
|
input: "SELECT * FROM users WHERE email = 'user@example.com'",
|
|
expected: "SELECT * FROM users WHERE email = '[EMAIL_MASKED]'",
|
|
},
|
|
{
|
|
name: "production_mode_mask_token",
|
|
production: true,
|
|
input: "SELECT * FROM users WHERE password_reset_token = 'abc123def456ghi789'",
|
|
expected: "SELECT * FROM users WHERE password_reset_token = '[TOKEN_MASKED]'",
|
|
},
|
|
{
|
|
name: "production_mode_mask_uuid",
|
|
production: true,
|
|
input: "SELECT * FROM users WHERE id = '550e8400-e29b-41d4-a716-446655440000'",
|
|
expected: "SELECT * FROM users WHERE id = '[TOKEN_MASKED]'",
|
|
},
|
|
{
|
|
name: "production_mode_no_masking_short_values",
|
|
production: true,
|
|
input: "SELECT * FROM users WHERE id = 123",
|
|
expected: "SELECT * FROM users WHERE id = 123",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
secureLogger := NewSecureLogger(writer, config, tt.production)
|
|
result := secureLogger.maskSensitiveData(tt.input)
|
|
|
|
if tt.production {
|
|
if strings.Contains(result, "user@example.com") {
|
|
t.Errorf("Email should be masked in production mode")
|
|
}
|
|
if strings.Contains(result, "abc123def456ghi789") {
|
|
t.Errorf("Token should be masked in production mode")
|
|
}
|
|
} else {
|
|
if result != tt.input {
|
|
t.Errorf("Expected %q, got %q", tt.input, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSecureLogger_IsSensitiveValue(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Info,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
secureLogger := NewSecureLogger(writer, config, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
value string
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "email_address",
|
|
value: "user@example.com",
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "long_token",
|
|
value: "abc123def456ghi789jkl012mno345pqr678stu901vwx234yz",
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "uuid",
|
|
value: "550e8400-e29b-41d4-a716-446655440000",
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "short_value",
|
|
value: "123",
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "normal_text",
|
|
value: "golang programming",
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "base64_like",
|
|
value: "SGVsbG8gV29ybGQ=",
|
|
expected: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := secureLogger.isSensitiveValue(tt.value)
|
|
if result != tt.expected {
|
|
t.Errorf("Expected %v for value %q, got %v", tt.expected, tt.value, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSecureLogger_LogLevels(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Info,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
secureLogger := NewSecureLogger(writer, config, false)
|
|
|
|
ctx := context.Background()
|
|
|
|
secureLogger.Info(ctx, "Test info message")
|
|
secureLogger.Warn(ctx, "Test warn message")
|
|
secureLogger.Error(ctx, "Test error message")
|
|
}
|
|
|
|
func TestCreateSecureLogger(t *testing.T) {
|
|
prodLogger := CreateSecureLogger(true)
|
|
if prodLogger == nil {
|
|
t.Error("Expected non-nil logger for production mode")
|
|
}
|
|
|
|
devLogger := CreateSecureLogger(false)
|
|
if devLogger == nil {
|
|
t.Error("Expected non-nil logger for development mode")
|
|
}
|
|
}
|
|
|
|
func TestSecureLogger_LogMode(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Info,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
secureLogger := NewSecureLogger(writer, config, false)
|
|
|
|
newLogger := secureLogger.LogMode(logger.Error)
|
|
if newLogger == nil {
|
|
t.Error("Expected non-nil logger from LogMode")
|
|
}
|
|
|
|
if secureLogger.config.LogLevel != logger.Info {
|
|
t.Error("Original logger should be unchanged")
|
|
}
|
|
}
|
|
|
|
func TestSecureLogger_Trace(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Info,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
secureLogger := NewSecureLogger(writer, config, false)
|
|
ctx := context.Background()
|
|
|
|
t.Run("silent_level", func(t *testing.T) {
|
|
silentLogger := secureLogger.LogMode(logger.Silent)
|
|
silentLogger.Trace(ctx, time.Now(), func() (string, int64) {
|
|
return "SELECT * FROM users", 1
|
|
}, nil)
|
|
})
|
|
|
|
t.Run("error_level_with_error", func(t *testing.T) {
|
|
errorLogger := secureLogger.LogMode(logger.Error)
|
|
errorLogger.Trace(ctx, time.Now(), func() (string, int64) {
|
|
return "SELECT * FROM users", 1
|
|
}, errors.New("test error"))
|
|
})
|
|
|
|
t.Run("warn_level_slow_query", func(t *testing.T) {
|
|
warnLogger := secureLogger.LogMode(logger.Warn)
|
|
|
|
startTime := time.Now().Add(-2 * time.Second)
|
|
warnLogger.Trace(ctx, startTime, func() (string, int64) {
|
|
return "SELECT * FROM users", 1
|
|
}, nil)
|
|
})
|
|
|
|
t.Run("info_level", func(t *testing.T) {
|
|
infoLogger := secureLogger.LogMode(logger.Info)
|
|
infoLogger.Trace(ctx, time.Now(), func() (string, int64) {
|
|
return "SELECT * FROM users", 1
|
|
}, nil)
|
|
})
|
|
}
|
|
|
|
func TestSecureLogger_MaskSQLValues(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Info,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
secureLogger := NewSecureLogger(writer, config, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
sql string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "email_in_sql",
|
|
sql: "SELECT * FROM users WHERE email = 'user@example.com'",
|
|
expected: "SELECT * FROM users WHERE email = '[MASKED]'",
|
|
},
|
|
{
|
|
name: "token_in_sql",
|
|
sql: "SELECT * FROM users WHERE token = 'abc123def456ghi789'",
|
|
expected: "SELECT * FROM users WHERE token = '[MASKED]'",
|
|
},
|
|
{
|
|
name: "uuid_in_sql",
|
|
sql: "SELECT * FROM users WHERE id = '550e8400-e29b-41d4-a716-446655440000'",
|
|
expected: "SELECT * FROM users WHERE id = '[MASKED]'",
|
|
},
|
|
{
|
|
name: "normal_value",
|
|
sql: "SELECT * FROM users WHERE id = 123",
|
|
expected: "SELECT * FROM users WHERE id = 123",
|
|
},
|
|
{
|
|
name: "multiple_values",
|
|
sql: "SELECT * FROM users WHERE email = 'user@example.com' AND id = 123",
|
|
expected: "SELECT * FROM users WHERE email = '[MASKED]' AND id = 123",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := secureLogger.maskSQLValues(tt.sql)
|
|
if result != tt.expected {
|
|
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSecureLogger_IsRecordNotFoundError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "record_not_found",
|
|
err: errors.New("record not found"),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "not_found",
|
|
err: errors.New("not found"),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "RECORD NOT FOUND",
|
|
err: errors.New("RECORD NOT FOUND"),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "NOT FOUND",
|
|
err: errors.New("NOT FOUND"),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "other_error",
|
|
err: errors.New("connection failed"),
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "nil_error",
|
|
err: nil,
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := IsRecordNotFoundError(tt.err)
|
|
if result != tt.expected {
|
|
t.Errorf("Expected %v for error '%v', got %v", tt.expected, tt.err, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSecureLogger_ProductionMode(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Error,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
secureLogger := NewSecureLogger(writer, config, true)
|
|
ctx := context.Background()
|
|
|
|
secureLogger.Info(ctx, "User login: %s", "user@example.com")
|
|
secureLogger.Warn(ctx, "Token validation: %s", "abc123def456ghi789")
|
|
secureLogger.Error(ctx, "Database error: %s", "connection failed")
|
|
}
|
|
|
|
func TestSecureLogger_DevelopmentMode(t *testing.T) {
|
|
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
|
config := logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logger.Info,
|
|
IgnoreRecordNotFoundError: true,
|
|
Colorful: false,
|
|
}
|
|
|
|
secureLogger := NewSecureLogger(writer, config, false)
|
|
ctx := context.Background()
|
|
|
|
secureLogger.Info(ctx, "User login: %s", "user@example.com")
|
|
secureLogger.Warn(ctx, "Token validation: %s", "abc123def456ghi789")
|
|
secureLogger.Error(ctx, "Database error: %s", "connection failed")
|
|
}
|