To gitea and beyond, let's go(-yco)
This commit is contained in:
77
internal/database/connection.go
Normal file
77
internal/database/connection.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/middleware"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func connectDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
dsn := cfg.GetConnectionString()
|
||||
gormLogger := CreateSecureLogger(!cfg.App.Debug)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Connect(cfg *config.Config) (*gorm.DB, error) {
|
||||
return connectDB(cfg)
|
||||
}
|
||||
|
||||
func ConnectWithMonitoring(cfg *config.Config, monitor middleware.DBMonitor) (*gorm.DB, error) {
|
||||
db, err := connectDB(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if monitor != nil {
|
||||
monitoringPlugin := NewGormDBMonitor(monitor)
|
||||
if err := db.Use(monitoringPlugin); err != nil {
|
||||
return nil, fmt.Errorf("failed to add monitoring plugin: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Migrate(db *gorm.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
err := db.AutoMigrate(
|
||||
&User{},
|
||||
&Post{},
|
||||
&Vote{},
|
||||
&AccountDeletionRequest{},
|
||||
&RefreshToken{},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Close(db *gorm.DB) error {
|
||||
if db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
||||
}
|
||||
|
||||
return sqlDB.Close()
|
||||
}
|
||||
169
internal/database/connection_pool.go
Normal file
169
internal/database/connection_pool.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/config"
|
||||
)
|
||||
|
||||
type ConnectionPoolConfig struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
ConnTimeout time.Duration
|
||||
HealthCheckInterval time.Duration
|
||||
}
|
||||
|
||||
func DefaultConnectionPoolConfig() ConnectionPoolConfig {
|
||||
return ConnectionPoolConfig{
|
||||
MaxOpenConns: 25,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetime: 5 * time.Minute,
|
||||
ConnMaxIdleTime: 1 * time.Minute,
|
||||
ConnTimeout: 30 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func ProductionConnectionPoolConfig() ConnectionPoolConfig {
|
||||
return ConnectionPoolConfig{
|
||||
MaxOpenConns: 100,
|
||||
MaxIdleConns: 25,
|
||||
ConnMaxLifetime: 10 * time.Minute,
|
||||
ConnMaxIdleTime: 2 * time.Minute,
|
||||
ConnTimeout: 10 * time.Second,
|
||||
HealthCheckInterval: 15 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func HighTrafficConnectionPoolConfig() ConnectionPoolConfig {
|
||||
return ConnectionPoolConfig{
|
||||
MaxOpenConns: 200,
|
||||
MaxIdleConns: 50,
|
||||
ConnMaxLifetime: 15 * time.Minute,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
ConnTimeout: 5 * time.Second,
|
||||
HealthCheckInterval: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectionPoolManager struct {
|
||||
db *gorm.DB
|
||||
sqlDB *sql.DB
|
||||
config ConnectionPoolConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewConnectionPoolManager(cfg *config.Config, poolConfig ConnectionPoolConfig) (*ConnectionPoolManager, error) {
|
||||
dsn := cfg.GetConnectionString()
|
||||
|
||||
secureLogger := CreateSecureLogger(!cfg.App.Debug)
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: secureLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(poolConfig.MaxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(poolConfig.MaxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(poolConfig.ConnMaxLifetime)
|
||||
sqlDB.SetConnMaxIdleTime(poolConfig.ConnMaxIdleTime)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), poolConfig.ConnTimeout)
|
||||
if err := sqlDB.PingContext(ctx); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
cancel()
|
||||
|
||||
managerCtx, managerCancel := context.WithCancel(context.Background())
|
||||
|
||||
manager := &ConnectionPoolManager{
|
||||
db: db,
|
||||
sqlDB: sqlDB,
|
||||
config: poolConfig,
|
||||
ctx: managerCtx,
|
||||
cancel: managerCancel,
|
||||
}
|
||||
|
||||
go manager.startHealthCheck()
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func (m *ConnectionPoolManager) GetDB() *gorm.DB {
|
||||
return m.db
|
||||
}
|
||||
|
||||
func (m *ConnectionPoolManager) GetSQLDB() *sql.DB {
|
||||
return m.sqlDB
|
||||
}
|
||||
|
||||
func (m *ConnectionPoolManager) GetPoolStats() sql.DBStats {
|
||||
return m.sqlDB.Stats()
|
||||
}
|
||||
|
||||
func (m *ConnectionPoolManager) startHealthCheck() {
|
||||
ticker := time.NewTicker(m.config.HealthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ConnectionPoolManager) performHealthCheck() {
|
||||
ctx, cancel := context.WithTimeout(m.ctx, m.config.ConnTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := m.sqlDB.PingContext(ctx); err != nil {
|
||||
log.Printf("Database health check failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ConnectionPoolManager) Close() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.sqlDB != nil {
|
||||
return m.sqlDB.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ConnectWithPool(cfg *config.Config) (*ConnectionPoolManager, error) {
|
||||
var poolConfig ConnectionPoolConfig
|
||||
|
||||
if cfg.App.Debug {
|
||||
poolConfig = DefaultConnectionPoolConfig()
|
||||
} else {
|
||||
poolConfig = ProductionConnectionPoolConfig()
|
||||
}
|
||||
|
||||
if cfg.App.BaseURL != "" && !cfg.App.Debug {
|
||||
poolConfig = HighTrafficConnectionPoolConfig()
|
||||
}
|
||||
|
||||
return NewConnectionPoolManager(cfg, poolConfig)
|
||||
}
|
||||
253
internal/database/connection_pool_test.go
Normal file
253
internal/database/connection_pool_test.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"goyco/internal/config"
|
||||
)
|
||||
|
||||
func TestConnectionPoolConfig(t *testing.T) {
|
||||
t.Run("default_config", func(t *testing.T) {
|
||||
config := DefaultConnectionPoolConfig()
|
||||
|
||||
if config.MaxOpenConns <= 0 {
|
||||
t.Error("MaxOpenConns should be positive")
|
||||
}
|
||||
if config.MaxIdleConns <= 0 {
|
||||
t.Error("MaxIdleConns should be positive")
|
||||
}
|
||||
if config.ConnMaxLifetime <= 0 {
|
||||
t.Error("ConnMaxLifetime should be positive")
|
||||
}
|
||||
if config.ConnMaxIdleTime <= 0 {
|
||||
t.Error("ConnMaxIdleTime should be positive")
|
||||
}
|
||||
if config.ConnTimeout <= 0 {
|
||||
t.Error("ConnTimeout should be positive")
|
||||
}
|
||||
if config.HealthCheckInterval <= 0 {
|
||||
t.Error("HealthCheckInterval should be positive")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("production_config", func(t *testing.T) {
|
||||
config := ProductionConnectionPoolConfig()
|
||||
|
||||
if config.MaxOpenConns < 50 {
|
||||
t.Error("Production MaxOpenConns should be higher")
|
||||
}
|
||||
if config.MaxIdleConns < 10 {
|
||||
t.Error("Production MaxIdleConns should be higher")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("high_traffic_config", func(t *testing.T) {
|
||||
config := HighTrafficConnectionPoolConfig()
|
||||
|
||||
if config.MaxOpenConns < 100 {
|
||||
t.Error("High traffic MaxOpenConns should be very high")
|
||||
}
|
||||
if config.MaxIdleConns < 25 {
|
||||
t.Error("High traffic MaxIdleConns should be high")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectionPoolManager_Stats(t *testing.T) {
|
||||
|
||||
t.Run("config_validation", func(t *testing.T) {
|
||||
config := DefaultConnectionPoolConfig()
|
||||
|
||||
if config.MaxOpenConns < config.MaxIdleConns {
|
||||
t.Error("MaxOpenConns should be >= MaxIdleConns")
|
||||
}
|
||||
|
||||
if config.ConnMaxLifetime < config.ConnMaxIdleTime {
|
||||
t.Error("ConnMaxLifetime should be >= ConnMaxIdleTime")
|
||||
}
|
||||
|
||||
if config.ConnTimeout > 60*time.Second {
|
||||
t.Error("ConnTimeout should be reasonable")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectionPoolConfig_Values(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config ConnectionPoolConfig
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
config: DefaultConnectionPoolConfig(),
|
||||
},
|
||||
{
|
||||
name: "production",
|
||||
config: ProductionConnectionPoolConfig(),
|
||||
},
|
||||
{
|
||||
name: "high_traffic",
|
||||
config: HighTrafficConnectionPoolConfig(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := tt.config
|
||||
|
||||
if config.MaxOpenConns <= 0 {
|
||||
t.Errorf("MaxOpenConns should be positive, got %d", config.MaxOpenConns)
|
||||
}
|
||||
if config.MaxIdleConns <= 0 {
|
||||
t.Errorf("MaxIdleConns should be positive, got %d", config.MaxIdleConns)
|
||||
}
|
||||
if config.ConnMaxLifetime <= 0 {
|
||||
t.Errorf("ConnMaxLifetime should be positive, got %v", config.ConnMaxLifetime)
|
||||
}
|
||||
if config.ConnMaxIdleTime <= 0 {
|
||||
t.Errorf("ConnMaxIdleTime should be positive, got %v", config.ConnMaxIdleTime)
|
||||
}
|
||||
if config.ConnTimeout <= 0 {
|
||||
t.Errorf("ConnTimeout should be positive, got %v", config.ConnTimeout)
|
||||
}
|
||||
if config.HealthCheckInterval <= 0 {
|
||||
t.Errorf("HealthCheckInterval should be positive, got %v", config.HealthCheckInterval)
|
||||
}
|
||||
|
||||
if config.MaxOpenConns < config.MaxIdleConns {
|
||||
t.Errorf("MaxOpenConns (%d) should be >= MaxIdleConns (%d)", config.MaxOpenConns, config.MaxIdleConns)
|
||||
}
|
||||
|
||||
if config.ConnMaxLifetime < config.ConnMaxIdleTime {
|
||||
t.Errorf("ConnMaxLifetime (%v) should be >= ConnMaxIdleTime (%v)", config.ConnMaxLifetime, config.ConnMaxIdleTime)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionPoolManager(t *testing.T) {
|
||||
t.Run("invalid_database_config", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "invalid-host",
|
||||
Port: "9999",
|
||||
User: "invalid",
|
||||
Password: "invalid",
|
||||
Name: "invalid",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
Debug: true,
|
||||
},
|
||||
}
|
||||
|
||||
poolConfig := DefaultConnectionPoolConfig()
|
||||
manager, err := NewConnectionPoolManager(cfg, poolConfig)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error with invalid database config")
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("Expected nil manager with invalid database config")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to connect to database") {
|
||||
t.Errorf("Expected connection error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectionPoolManager_Methods(t *testing.T) {
|
||||
t.Run("get_db_methods", func(t *testing.T) {
|
||||
|
||||
manager := &ConnectionPoolManager{
|
||||
db: nil,
|
||||
sqlDB: nil,
|
||||
}
|
||||
|
||||
if manager.GetDB() != nil {
|
||||
t.Error("Expected nil DB from uninitialized manager")
|
||||
}
|
||||
|
||||
if manager.GetSQLDB() != nil {
|
||||
t.Error("Expected nil SQLDB from uninitialized manager")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectWithPool(t *testing.T) {
|
||||
t.Run("debug_mode_config", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "test",
|
||||
Password: "test",
|
||||
Name: "test_db",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
Debug: true,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := ConnectWithPool(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected error with invalid database config")
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("Expected nil manager with invalid database config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("production_mode_config", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "test",
|
||||
Password: "test",
|
||||
Name: "test_db",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
Debug: false,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := ConnectWithPool(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected error with invalid database config")
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("Expected nil manager with invalid database config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("high_traffic_config", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "test",
|
||||
Password: "test",
|
||||
Name: "test_db",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
Debug: false,
|
||||
BaseURL: "https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := ConnectWithPool(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected error with invalid database config")
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("Expected nil manager with invalid database config")
|
||||
}
|
||||
})
|
||||
}
|
||||
156
internal/database/connection_test.go
Normal file
156
internal/database/connection_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"goyco/internal/config"
|
||||
"goyco/internal/middleware"
|
||||
)
|
||||
|
||||
func TestConnectReturnsErrorWhenUnableToReachDatabase(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: "1",
|
||||
User: "postgres",
|
||||
Password: "password",
|
||||
Name: "goyco_test",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
}
|
||||
_, err := Connect(cfg)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
t.Fatalf("expected connection error but got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to connect to database") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("connection test timed out after 5 seconds")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateFailsWhenDBNil(t *testing.T) {
|
||||
err := Migrate(nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error when DB is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateCreatesTables(t *testing.T) {
|
||||
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=private"
|
||||
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite in-memory database: %v", err)
|
||||
}
|
||||
|
||||
if err := Migrate(db); err != nil {
|
||||
t.Fatalf("expected migrations to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
migrator := db.Migrator()
|
||||
|
||||
models := []any{&User{}, &Post{}, &Vote{}}
|
||||
for _, model := range models {
|
||||
if !migrator.HasTable(model) {
|
||||
t.Fatalf("expected table for %T to exist after migration", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseReturnsNilWhenDBNil(t *testing.T) {
|
||||
if err := Close(nil); err != nil {
|
||||
t.Fatalf("expected nil error when DB is nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseClosesUnderlyingConnection(t *testing.T) {
|
||||
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=private"
|
||||
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite in-memory database: %v", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if err := Close(db); err != nil {
|
||||
t.Fatalf("expected close to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if err := sqlDB.Ping(); err == nil {
|
||||
t.Fatalf("expected ping on closed connection to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithMonitoring(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "test",
|
||||
Password: "test",
|
||||
Name: "test_db",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
Debug: true,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ConnectWithMonitoring(cfg, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected connection error with invalid database config")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to connect to database") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithMonitoringWithValidMonitor(t *testing.T) {
|
||||
mockMonitor := middleware.NewInMemoryDBMonitor()
|
||||
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "5432",
|
||||
User: "test",
|
||||
Password: "test",
|
||||
Name: "test_db",
|
||||
SSLMode: "disable",
|
||||
},
|
||||
App: config.AppConfig{
|
||||
Debug: true,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ConnectWithMonitoring(cfg, mockMonitor)
|
||||
if err == nil {
|
||||
t.Fatalf("expected connection error with invalid database config")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to connect to database") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
88
internal/database/models.go
Normal file
88
internal/database/models.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Post struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Title string `gorm:"not null"`
|
||||
URL string `gorm:"uniqueIndex"`
|
||||
Content string
|
||||
AuthorID *uint
|
||||
AuthorName string
|
||||
Author User `gorm:"foreignKey:AuthorID;constraint:OnDelete:CASCADE"`
|
||||
UpVotes int `gorm:"default:0"`
|
||||
DownVotes int `gorm:"default:0"`
|
||||
Score int `gorm:"default:0"`
|
||||
Votes []Vote `gorm:"foreignKey:PostID;constraint:OnDelete:CASCADE"`
|
||||
CurrentVote VoteType `gorm:"-"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Username string `gorm:"uniqueIndex;not null"`
|
||||
Email string `gorm:"uniqueIndex;not null"`
|
||||
Password string `gorm:"not null"`
|
||||
EmailVerified bool `gorm:"default:false;not null"`
|
||||
EmailVerifiedAt *time.Time
|
||||
EmailVerificationToken string `gorm:"index"`
|
||||
EmailVerificationSentAt *time.Time
|
||||
PasswordResetToken string `gorm:"index"`
|
||||
PasswordResetSentAt *time.Time
|
||||
PasswordResetExpiresAt *time.Time
|
||||
Locked bool `gorm:"default:false"`
|
||||
SessionVersion uint `gorm:"default:1;not null"`
|
||||
RefreshTokens []RefreshToken `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
|
||||
Posts []Post `gorm:"foreignKey:AuthorID"`
|
||||
Votes []Vote `gorm:"foreignKey:UserID"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
type RefreshToken struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID uint `gorm:"not null;index"`
|
||||
User User `gorm:"constraint:OnDelete:CASCADE"`
|
||||
TokenHash string `gorm:"uniqueIndex;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null;index"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
type AccountDeletionRequest struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID uint `gorm:"uniqueIndex"`
|
||||
User User `gorm:"constraint:OnDelete:CASCADE"`
|
||||
TokenHash string `gorm:"uniqueIndex;not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Vote struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID *uint `gorm:"uniqueIndex:idx_user_post_vote,where:deleted_at IS NULL AND user_id IS NOT NULL"`
|
||||
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
|
||||
PostID uint `gorm:"not null;uniqueIndex:idx_user_post_vote,where:deleted_at IS NULL AND user_id IS NOT NULL;uniqueIndex:idx_hash_post_vote,where:deleted_at IS NULL AND vote_hash IS NOT NULL"`
|
||||
Post Post `gorm:"foreignKey:PostID;constraint:OnDelete:CASCADE"`
|
||||
Type VoteType `gorm:"not null"`
|
||||
VoteHash *string `gorm:"uniqueIndex:idx_hash_post_vote,where:deleted_at IS NULL AND vote_hash IS NOT NULL"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
type VoteType string
|
||||
|
||||
const (
|
||||
VoteUp VoteType = "up"
|
||||
VoteDown VoteType = "down"
|
||||
VoteNone VoteType = "none"
|
||||
)
|
||||
603
internal/database/models_test.go
Normal file
603
internal/database/models_test.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func newTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
|
||||
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to test database: %v", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&User{},
|
||||
&Post{},
|
||||
&Vote{},
|
||||
&AccountDeletionRequest{},
|
||||
&RefreshToken{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to migrate database: %v", err)
|
||||
}
|
||||
|
||||
if execErr := db.Exec("PRAGMA busy_timeout = 5000").Error; execErr != nil {
|
||||
t.Fatalf("Failed to configure busy timeout: %v", execErr)
|
||||
}
|
||||
if execErr := db.Exec("PRAGMA foreign_keys = ON").Error; execErr != nil {
|
||||
t.Fatalf("Failed to enable foreign keys: %v", execErr)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to access SQL DB: %v", err)
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
return db
|
||||
}
|
||||
|
||||
func createTestUser(t *testing.T, db *gorm.DB) *User {
|
||||
t.Helper()
|
||||
|
||||
uniqueID := time.Now().UnixNano()
|
||||
user := &User{
|
||||
Username: fmt.Sprintf("testuser%d", uniqueID),
|
||||
Email: fmt.Sprintf("test%d@example.com", uniqueID),
|
||||
Password: "hashedpassword123",
|
||||
EmailVerified: true,
|
||||
}
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func createTestPost(t *testing.T, db *gorm.DB, authorID uint) *Post {
|
||||
t.Helper()
|
||||
post := &Post{
|
||||
Title: "Test Post " + t.Name(),
|
||||
URL: "https://example.com/test" + t.Name(),
|
||||
Content: "Test content",
|
||||
AuthorID: &authorID,
|
||||
}
|
||||
if err := db.Create(post).Error; err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
return post
|
||||
}
|
||||
|
||||
func TestUser_Model(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("create_user", func(t *testing.T) {
|
||||
user := &User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
if user.ID == 0 {
|
||||
t.Error("Expected user ID to be set")
|
||||
}
|
||||
|
||||
if user.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set")
|
||||
}
|
||||
|
||||
if user.UpdatedAt.IsZero() {
|
||||
t.Error("Expected UpdatedAt to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("user_constraints", func(t *testing.T) {
|
||||
|
||||
user1 := &User{
|
||||
Username: "duplicate",
|
||||
Email: "user1@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
user2 := &User{
|
||||
Username: "duplicate",
|
||||
Email: "user2@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user1).Error; err != nil {
|
||||
t.Fatalf("Failed to create first user: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(user2).Error; err == nil {
|
||||
t.Error("Expected error when creating user with duplicate username")
|
||||
}
|
||||
|
||||
user3 := &User{
|
||||
Username: "unique",
|
||||
Email: "user1@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user3).Error; err == nil {
|
||||
t.Error("Expected error when creating user with duplicate email")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("user_relationships", func(t *testing.T) {
|
||||
user := &User{
|
||||
Username: "author",
|
||||
Email: "author@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
post1 := &Post{
|
||||
Title: "Post 1",
|
||||
URL: "https://example.com/1",
|
||||
Content: "Content 1",
|
||||
AuthorID: &user.ID,
|
||||
}
|
||||
|
||||
post2 := &Post{
|
||||
Title: "Post 2",
|
||||
URL: "https://example.com/2",
|
||||
Content: "Content 2",
|
||||
AuthorID: &user.ID,
|
||||
}
|
||||
|
||||
if err := db.Create(post1).Error; err != nil {
|
||||
t.Fatalf("Failed to create post 1: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(post2).Error; err != nil {
|
||||
t.Fatalf("Failed to create post 2: %v", err)
|
||||
}
|
||||
|
||||
var foundUser User
|
||||
if err := db.Preload("Posts").First(&foundUser, user.ID).Error; err != nil {
|
||||
t.Fatalf("Failed to load user with posts: %v", err)
|
||||
}
|
||||
|
||||
if len(foundUser.Posts) != 2 {
|
||||
t.Errorf("Expected 2 posts, got %d", len(foundUser.Posts))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPost_Model(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("create_post", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
post := &Post{
|
||||
Title: "Test Post",
|
||||
URL: "https://example.com/test",
|
||||
Content: "Test content",
|
||||
AuthorID: &user.ID,
|
||||
}
|
||||
|
||||
if err := db.Create(post).Error; err != nil {
|
||||
t.Fatalf("Failed to create post: %v", err)
|
||||
}
|
||||
|
||||
if post.ID == 0 {
|
||||
t.Error("Expected post ID to be set")
|
||||
}
|
||||
|
||||
if post.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set")
|
||||
}
|
||||
|
||||
if post.UpdatedAt.IsZero() {
|
||||
t.Error("Expected UpdatedAt to be set")
|
||||
}
|
||||
|
||||
if post.UpVotes != 0 {
|
||||
t.Error("Expected UpVotes to be 0 by default")
|
||||
}
|
||||
|
||||
if post.DownVotes != 0 {
|
||||
t.Error("Expected DownVotes to be 0 by default")
|
||||
}
|
||||
|
||||
if post.Score != 0 {
|
||||
t.Error("Expected Score to be 0 by default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("post_constraints", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
post1 := &Post{
|
||||
Title: "Post 1",
|
||||
URL: "https://example.com/unique",
|
||||
Content: "Content 1",
|
||||
AuthorID: &user.ID,
|
||||
}
|
||||
|
||||
post2 := &Post{
|
||||
Title: "Post 2",
|
||||
URL: "https://example.com/unique",
|
||||
Content: "Content 2",
|
||||
AuthorID: &user.ID,
|
||||
}
|
||||
|
||||
if err := db.Create(post1).Error; err != nil {
|
||||
t.Fatalf("Failed to create first post: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(post2).Error; err == nil {
|
||||
t.Error("Expected error when creating post with duplicate URL")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("post_relationships", func(t *testing.T) {
|
||||
user1 := createTestUser(t, db)
|
||||
user2 := createTestUser(t, db)
|
||||
|
||||
post := createTestPost(t, db, user1.ID)
|
||||
|
||||
vote1 := &Vote{
|
||||
UserID: &user1.ID,
|
||||
PostID: post.ID,
|
||||
Type: VoteUp,
|
||||
}
|
||||
|
||||
vote2 := &Vote{
|
||||
UserID: &user2.ID,
|
||||
PostID: post.ID,
|
||||
Type: VoteDown,
|
||||
}
|
||||
|
||||
if err := db.Create(vote1).Error; err != nil {
|
||||
t.Fatalf("Failed to create vote 1: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(vote2).Error; err != nil {
|
||||
t.Fatalf("Failed to create vote 2: %v", err)
|
||||
}
|
||||
|
||||
var foundPost Post
|
||||
if err := db.Preload("Votes").First(&foundPost, post.ID).Error; err != nil {
|
||||
t.Fatalf("Failed to load post with votes: %v", err)
|
||||
}
|
||||
|
||||
if len(foundPost.Votes) != 2 {
|
||||
t.Errorf("Expected 2 votes, got %d", len(foundPost.Votes))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestVote_Model(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("create_vote", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
post := createTestPost(t, db, user.ID)
|
||||
|
||||
vote := &Vote{
|
||||
UserID: &user.ID,
|
||||
PostID: post.ID,
|
||||
Type: VoteUp,
|
||||
}
|
||||
|
||||
if err := db.Create(vote).Error; err != nil {
|
||||
t.Fatalf("Failed to create vote: %v", err)
|
||||
}
|
||||
|
||||
if vote.ID == 0 {
|
||||
t.Error("Expected vote ID to be set")
|
||||
}
|
||||
|
||||
if vote.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set")
|
||||
}
|
||||
|
||||
if vote.UpdatedAt.IsZero() {
|
||||
t.Error("Expected UpdatedAt to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("vote_constraints", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
post := createTestPost(t, db, user.ID)
|
||||
|
||||
vote1 := &Vote{
|
||||
UserID: &user.ID,
|
||||
PostID: post.ID,
|
||||
Type: VoteUp,
|
||||
}
|
||||
|
||||
vote2 := &Vote{
|
||||
UserID: &user.ID,
|
||||
PostID: post.ID,
|
||||
Type: VoteDown,
|
||||
}
|
||||
|
||||
if err := db.Create(vote1).Error; err != nil {
|
||||
t.Fatalf("Failed to create first vote: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(vote2).Error; err == nil {
|
||||
t.Error("Expected error when creating vote with duplicate user-post combination")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("vote_types", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
voteTypes := []VoteType{VoteUp, VoteDown, VoteNone}
|
||||
|
||||
for i, voteType := range voteTypes {
|
||||
|
||||
post := &Post{
|
||||
Title: "Test Post " + string(rune(i)),
|
||||
URL: "https://example.com/test" + string(rune(i)),
|
||||
Content: "Test content",
|
||||
AuthorID: &user.ID,
|
||||
}
|
||||
|
||||
if err := db.Create(post).Error; err != nil {
|
||||
t.Fatalf("Failed to create post %d: %v", i, err)
|
||||
}
|
||||
|
||||
vote := &Vote{
|
||||
UserID: &user.ID,
|
||||
PostID: post.ID,
|
||||
Type: voteType,
|
||||
}
|
||||
|
||||
if err := db.Create(vote).Error; err != nil {
|
||||
t.Fatalf("Failed to create vote with type %s: %v", voteType, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("vote_relationships", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
post := createTestPost(t, db, user.ID)
|
||||
|
||||
vote := &Vote{
|
||||
UserID: &user.ID,
|
||||
PostID: post.ID,
|
||||
Type: VoteUp,
|
||||
}
|
||||
|
||||
if err := db.Create(vote).Error; err != nil {
|
||||
t.Fatalf("Failed to create vote: %v", err)
|
||||
}
|
||||
|
||||
var foundVote Vote
|
||||
if err := db.Preload("User").Preload("Post").First(&foundVote, vote.ID).Error; err != nil {
|
||||
t.Fatalf("Failed to load vote with relationships: %v", err)
|
||||
}
|
||||
|
||||
if foundVote.User.ID != user.ID {
|
||||
t.Error("Expected vote to be associated with correct user")
|
||||
}
|
||||
|
||||
if foundVote.Post.ID != post.ID {
|
||||
t.Error("Expected vote to be associated with correct post")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshToken_Model(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("create_refresh_token", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
token := &RefreshToken{
|
||||
UserID: user.ID,
|
||||
TokenHash: "hashedtoken123",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
if err := db.Create(token).Error; err != nil {
|
||||
t.Fatalf("Failed to create refresh token: %v", err)
|
||||
}
|
||||
|
||||
if token.ID == 0 {
|
||||
t.Error("Expected token ID to be set")
|
||||
}
|
||||
|
||||
if token.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set")
|
||||
}
|
||||
|
||||
if token.UpdatedAt.IsZero() {
|
||||
t.Error("Expected UpdatedAt to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("refresh_token_constraints", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
token1 := &RefreshToken{
|
||||
UserID: user.ID,
|
||||
TokenHash: "uniquehash",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
token2 := &RefreshToken{
|
||||
UserID: user.ID,
|
||||
TokenHash: "uniquehash",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
if err := db.Create(token1).Error; err != nil {
|
||||
t.Fatalf("Failed to create first token: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(token2).Error; err == nil {
|
||||
t.Error("Expected error when creating token with duplicate hash")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountDeletionRequest_Model(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("create_account_deletion_request", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
request := &AccountDeletionRequest{
|
||||
UserID: user.ID,
|
||||
TokenHash: "deletiontoken123",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
if err := db.Create(request).Error; err != nil {
|
||||
t.Fatalf("Failed to create account deletion request: %v", err)
|
||||
}
|
||||
|
||||
if request.ID == 0 {
|
||||
t.Error("Expected request ID to be set")
|
||||
}
|
||||
|
||||
if request.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("account_deletion_request_constraints", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
request1 := &AccountDeletionRequest{
|
||||
UserID: user.ID,
|
||||
TokenHash: "token1",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
request2 := &AccountDeletionRequest{
|
||||
UserID: user.ID,
|
||||
TokenHash: "token2",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
if err := db.Create(request1).Error; err != nil {
|
||||
t.Fatalf("Failed to create first request: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(request2).Error; err == nil {
|
||||
t.Error("Expected error when creating request with duplicate user")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestVoteType_Constants(t *testing.T) {
|
||||
t.Run("vote_type_constants", func(t *testing.T) {
|
||||
if VoteUp != "up" {
|
||||
t.Errorf("Expected VoteUp to be 'up', got '%s'", VoteUp)
|
||||
}
|
||||
|
||||
if VoteDown != "down" {
|
||||
t.Errorf("Expected VoteDown to be 'down', got '%s'", VoteDown)
|
||||
}
|
||||
|
||||
if VoteNone != "none" {
|
||||
t.Errorf("Expected VoteNone to be 'none', got '%s'", VoteNone)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModel_SoftDelete(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("user_soft_delete", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
|
||||
if err := db.Delete(user).Error; err != nil {
|
||||
t.Fatalf("Failed to soft delete user: %v", err)
|
||||
}
|
||||
|
||||
var foundUser User
|
||||
if err := db.First(&foundUser, user.ID).Error; err == nil {
|
||||
t.Error("Expected user to be soft deleted")
|
||||
}
|
||||
|
||||
if err := db.Unscoped().First(&foundUser, user.ID).Error; err != nil {
|
||||
t.Fatalf("Expected to find soft deleted user with Unscoped: %v", err)
|
||||
}
|
||||
|
||||
if foundUser.DeletedAt.Time.IsZero() {
|
||||
t.Error("Expected DeletedAt to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("post_soft_delete", func(t *testing.T) {
|
||||
user := createTestUser(t, db)
|
||||
post := createTestPost(t, db, user.ID)
|
||||
|
||||
if err := db.Delete(post).Error; err != nil {
|
||||
t.Fatalf("Failed to soft delete post: %v", err)
|
||||
}
|
||||
|
||||
var foundPost Post
|
||||
if err := db.First(&foundPost, post.ID).Error; err == nil {
|
||||
t.Error("Expected post to be soft deleted")
|
||||
}
|
||||
|
||||
if err := db.Unscoped().First(&foundPost, post.ID).Error; err != nil {
|
||||
t.Fatalf("Expected to find soft deleted post with Unscoped: %v", err)
|
||||
}
|
||||
|
||||
if foundPost.DeletedAt.Time.IsZero() {
|
||||
t.Error("Expected DeletedAt to be set")
|
||||
}
|
||||
})
|
||||
}
|
||||
190
internal/database/monitoring_plugin.go
Normal file
190
internal/database/monitoring_plugin.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"goyco/internal/middleware"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const gormOperationStartKey contextKey = "gorm_operation_start"
|
||||
|
||||
type GormDBMonitor struct {
|
||||
monitor middleware.DBMonitor
|
||||
}
|
||||
|
||||
func NewGormDBMonitor(monitor middleware.DBMonitor) *GormDBMonitor {
|
||||
return &GormDBMonitor{
|
||||
monitor: monitor,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) Name() string {
|
||||
return "db_monitor"
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) Initialize(db *gorm.DB) error {
|
||||
|
||||
db.Callback().Create().Before("gorm:create").Register("db_monitor:before_create", g.beforeCreate)
|
||||
db.Callback().Create().After("gorm:create").Register("db_monitor:after_create", g.afterCreate)
|
||||
|
||||
db.Callback().Query().Before("gorm:query").Register("db_monitor:before_query", g.beforeQuery)
|
||||
db.Callback().Query().After("gorm:query").Register("db_monitor:after_query", g.afterQuery)
|
||||
|
||||
db.Callback().Update().Before("gorm:update").Register("db_monitor:before_update", g.beforeUpdate)
|
||||
db.Callback().Update().After("gorm:update").Register("db_monitor:after_update", g.afterUpdate)
|
||||
|
||||
db.Callback().Delete().Before("gorm:delete").Register("db_monitor:before_delete", g.beforeDelete)
|
||||
db.Callback().Delete().After("gorm:delete").Register("db_monitor:after_delete", g.afterDelete)
|
||||
|
||||
db.Callback().Row().Before("gorm:row").Register("db_monitor:before_row", g.beforeRow)
|
||||
db.Callback().Row().After("gorm:row").Register("db_monitor:after_row", g.afterRow)
|
||||
|
||||
db.Callback().Raw().Before("gorm:raw").Register("db_monitor:before_raw", g.beforeRaw)
|
||||
db.Callback().Raw().After("gorm:raw").Register("db_monitor:after_raw", g.afterRaw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) beforeCreate(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
|
||||
db.Statement.Context = ctx
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) afterCreate(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.logOperation(db, "CREATE")
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) beforeQuery(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
|
||||
db.Statement.Context = ctx
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) afterQuery(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.logOperation(db, "SELECT")
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) beforeUpdate(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
|
||||
db.Statement.Context = ctx
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) afterUpdate(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.logOperation(db, "UPDATE")
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) beforeDelete(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
|
||||
db.Statement.Context = ctx
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) afterDelete(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.logOperation(db, "DELETE")
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) beforeRow(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
|
||||
db.Statement.Context = ctx
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) afterRow(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.logOperation(db, "ROW")
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) beforeRaw(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
|
||||
db.Statement.Context = ctx
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) afterRaw(db *gorm.DB) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.logOperation(db, "RAW")
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) logOperation(db *gorm.DB, operation string) {
|
||||
if g.monitor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
startTime, ok := db.Statement.Context.Value(gormOperationStartKey).(time.Time)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
|
||||
query := g.buildQueryString(db, operation)
|
||||
|
||||
g.monitor.LogQuery(query, duration, db.Error)
|
||||
}
|
||||
|
||||
func (g *GormDBMonitor) buildQueryString(db *gorm.DB, operation string) string {
|
||||
if db.Statement.SQL.String() != "" {
|
||||
return db.Statement.SQL.String()
|
||||
}
|
||||
|
||||
query := operation
|
||||
|
||||
if db.Statement.Table != "" {
|
||||
query += " FROM " + db.Statement.Table
|
||||
}
|
||||
|
||||
if db.Statement.Model != nil {
|
||||
|
||||
if stmt := db.Statement; stmt.Schema != nil {
|
||||
query = operation + " " + stmt.Schema.Table
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
325
internal/database/monitoring_plugin_test.go
Normal file
325
internal/database/monitoring_plugin_test.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"goyco/internal/middleware"
|
||||
)
|
||||
|
||||
func TestNewGormDBMonitor(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
if gormMonitor == nil {
|
||||
t.Fatal("Expected non-nil GormDBMonitor")
|
||||
}
|
||||
|
||||
if gormMonitor.monitor != monitor {
|
||||
t.Error("Expected monitor to be set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_Name(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
if gormMonitor.Name() != "db_monitor" {
|
||||
t.Errorf("Expected name 'db_monitor', got '%s'", gormMonitor.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_Initialize(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
err := gormMonitor.Initialize(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected Initialize to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_InitializeWithNilMonitor(t *testing.T) {
|
||||
gormMonitor := NewGormDBMonitor(nil)
|
||||
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
err := gormMonitor.Initialize(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected Initialize to succeed with nil monitor, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_Callbacks(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
err := gormMonitor.Initialize(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize plugin: %v", err)
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
var foundUser User
|
||||
if err := db.First(&foundUser, user.ID).Error; err != nil {
|
||||
t.Fatalf("Failed to find user: %v", err)
|
||||
}
|
||||
|
||||
foundUser.Username = "updateduser"
|
||||
if err := db.Save(&foundUser).Error; err != nil {
|
||||
t.Fatalf("Failed to update user: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Delete(&foundUser).Error; err != nil {
|
||||
t.Fatalf("Failed to delete user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_CallbacksWithNilMonitor(t *testing.T) {
|
||||
gormMonitor := NewGormDBMonitor(nil)
|
||||
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
err := gormMonitor.Initialize(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize plugin: %v", err)
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_BuildQueryString(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
db := newTestDB(t)
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
err := gormMonitor.Initialize(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize plugin: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
operation string
|
||||
table string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "create_operation",
|
||||
operation: "CREATE",
|
||||
table: "users",
|
||||
expected: "CREATE FROM users",
|
||||
},
|
||||
{
|
||||
name: "select_operation",
|
||||
operation: "SELECT",
|
||||
table: "posts",
|
||||
expected: "SELECT FROM posts",
|
||||
},
|
||||
{
|
||||
name: "update_operation",
|
||||
operation: "UPDATE",
|
||||
table: "votes",
|
||||
expected: "UPDATE FROM votes",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
stmt := &gorm.Statement{
|
||||
Table: tt.table,
|
||||
}
|
||||
|
||||
mockDB := &gorm.DB{
|
||||
Statement: stmt,
|
||||
}
|
||||
|
||||
result := gormMonitor.buildQueryString(mockDB, tt.operation)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_LogOperation(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
startTime := time.Now()
|
||||
ctx := context.WithValue(context.Background(), gormOperationStartKey, startTime)
|
||||
|
||||
stmt := &gorm.Statement{
|
||||
Context: ctx,
|
||||
Table: "users",
|
||||
}
|
||||
|
||||
mockDB := &gorm.DB{
|
||||
Statement: stmt,
|
||||
}
|
||||
|
||||
gormMonitor.logOperation(mockDB, "CREATE")
|
||||
|
||||
gormMonitor.monitor = nil
|
||||
gormMonitor.logOperation(mockDB, "CREATE")
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_LogOperationWithoutStartTime(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
stmt := &gorm.Statement{
|
||||
Context: ctx,
|
||||
Table: "users",
|
||||
}
|
||||
|
||||
mockDB := &gorm.DB{
|
||||
Statement: stmt,
|
||||
}
|
||||
|
||||
gormMonitor.logOperation(mockDB, "CREATE")
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_AllCallbackMethods(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
gormMonitor.monitor = nil
|
||||
|
||||
ctx := context.Background()
|
||||
stmt := &gorm.Statement{
|
||||
Context: ctx,
|
||||
Table: "users",
|
||||
}
|
||||
|
||||
mockDB := &gorm.DB{
|
||||
Statement: stmt,
|
||||
}
|
||||
|
||||
gormMonitor.beforeCreate(mockDB)
|
||||
gormMonitor.beforeQuery(mockDB)
|
||||
gormMonitor.beforeUpdate(mockDB)
|
||||
gormMonitor.beforeDelete(mockDB)
|
||||
gormMonitor.beforeRow(mockDB)
|
||||
gormMonitor.beforeRaw(mockDB)
|
||||
|
||||
gormMonitor.afterCreate(mockDB)
|
||||
gormMonitor.afterQuery(mockDB)
|
||||
gormMonitor.afterUpdate(mockDB)
|
||||
gormMonitor.afterDelete(mockDB)
|
||||
gormMonitor.afterRow(mockDB)
|
||||
gormMonitor.afterRaw(mockDB)
|
||||
}
|
||||
|
||||
func TestGormDBMonitor_WithRealDatabase(t *testing.T) {
|
||||
monitor := middleware.NewInMemoryDBMonitor()
|
||||
gormMonitor := NewGormDBMonitor(monitor)
|
||||
|
||||
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=private"
|
||||
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := db.AutoMigrate(&User{}); err != nil {
|
||||
t.Fatalf("Failed to migrate database: %v", err)
|
||||
}
|
||||
|
||||
err = gormMonitor.Initialize(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize plugin: %v", err)
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
EmailVerified: true,
|
||||
}
|
||||
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
var foundUser User
|
||||
if err := db.First(&foundUser, user.ID).Error; err != nil {
|
||||
t.Fatalf("Failed to find user: %v", err)
|
||||
}
|
||||
|
||||
foundUser.Username = "updateduser"
|
||||
if err := db.Save(&foundUser).Error; err != nil {
|
||||
t.Fatalf("Failed to update user: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Delete(&foundUser).Error; err != nil {
|
||||
t.Fatalf("Failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if stats.TotalQueries == 0 {
|
||||
t.Error("Expected monitor to have recorded some queries")
|
||||
}
|
||||
}
|
||||
175
internal/database/secure_logger.go
Normal file
175
internal/database/secure_logger.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type SecureLogger struct {
|
||||
writer logger.Writer
|
||||
config logger.Config
|
||||
sensitiveFields []string
|
||||
sensitivePattern *regexp.Regexp
|
||||
productionMode bool
|
||||
}
|
||||
|
||||
func NewSecureLogger(writer logger.Writer, config logger.Config, productionMode bool) *SecureLogger {
|
||||
sensitiveFields := []string{
|
||||
"password", "token", "secret", "key", "hash", "salt",
|
||||
"email_verification_token", "password_reset_token",
|
||||
"token_hash", "jwt_secret", "api_key", "access_token",
|
||||
"refresh_token", "session_id", "cookie", "auth",
|
||||
}
|
||||
|
||||
sensitivePattern := regexp.MustCompile(`(?i)(password|token|secret|key|hash|salt|email_verification_token|password_reset_token|token_hash|jwt_secret|api_key|access_token|refresh_token|session_id|cookie|auth)`)
|
||||
|
||||
return &SecureLogger{
|
||||
writer: writer,
|
||||
config: config,
|
||||
sensitiveFields: sensitiveFields,
|
||||
sensitivePattern: sensitivePattern,
|
||||
productionMode: productionMode,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *SecureLogger) LogMode(level logger.LogLevel) logger.Interface {
|
||||
newLogger := *l
|
||||
newLogger.config.LogLevel = level
|
||||
return &newLogger
|
||||
}
|
||||
|
||||
func (l *SecureLogger) Info(ctx context.Context, msg string, data ...any) {
|
||||
if l.config.LogLevel >= logger.Info {
|
||||
l.log(ctx, "info", msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *SecureLogger) Warn(ctx context.Context, msg string, data ...any) {
|
||||
if l.config.LogLevel >= logger.Warn {
|
||||
l.log(ctx, "warn", msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *SecureLogger) Error(ctx context.Context, msg string, data ...any) {
|
||||
if l.config.LogLevel >= logger.Error {
|
||||
l.log(ctx, "error", msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *SecureLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
if l.config.LogLevel <= logger.Silent {
|
||||
return
|
||||
}
|
||||
|
||||
elapsed := time.Since(begin)
|
||||
switch {
|
||||
case err != nil && l.config.LogLevel >= logger.Error && (!l.config.IgnoreRecordNotFoundError || !IsRecordNotFoundError(err)):
|
||||
sql, rows := fc()
|
||||
l.log(ctx, "error", fmt.Sprintf("[%.3fms] [rows:%v] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql))
|
||||
case elapsed > l.config.SlowThreshold && l.config.SlowThreshold != 0 && l.config.LogLevel >= logger.Warn:
|
||||
sql, rows := fc()
|
||||
l.log(ctx, "warn", fmt.Sprintf("[%.3fms] [rows:%v] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql))
|
||||
case l.config.LogLevel == logger.Info:
|
||||
sql, rows := fc()
|
||||
l.log(ctx, "info", fmt.Sprintf("[%.3fms] [rows:%v] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *SecureLogger) log(_ context.Context, level, msg string, data ...any) {
|
||||
if l.productionMode {
|
||||
msg = l.maskSensitiveData(msg)
|
||||
|
||||
maskedData := make([]any, len(data))
|
||||
for i, d := range data {
|
||||
maskedData[i] = l.maskSensitiveData(fmt.Sprintf("%v", d))
|
||||
}
|
||||
data = maskedData
|
||||
}
|
||||
|
||||
formattedMsg := fmt.Sprintf(msg, data...)
|
||||
|
||||
l.writer.Printf("[%s] %s", strings.ToUpper(level), formattedMsg)
|
||||
}
|
||||
|
||||
func (l *SecureLogger) maskSensitiveData(data string) string {
|
||||
if l.productionMode {
|
||||
data = regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`).ReplaceAllString(data, "[EMAIL_MASKED]")
|
||||
|
||||
data = regexp.MustCompile(`\b[A-Za-z0-9]{20,}\b`).ReplaceAllStringFunc(data, func(match string) string {
|
||||
if l.sensitivePattern.MatchString(match) {
|
||||
return "[TOKEN_MASKED]"
|
||||
}
|
||||
return match
|
||||
})
|
||||
|
||||
data = l.maskSQLValues(data)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func (l *SecureLogger) maskSQLValues(sql string) string {
|
||||
paramPattern := regexp.MustCompile(`'([^']*)'`)
|
||||
|
||||
return paramPattern.ReplaceAllStringFunc(sql, func(match string) string {
|
||||
value := strings.Trim(match, "'")
|
||||
|
||||
if l.isSensitiveValue(value) {
|
||||
return "'[MASKED]'"
|
||||
}
|
||||
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
func (l *SecureLogger) isSensitiveValue(value string) bool {
|
||||
if regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`).MatchString(value) {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(value) > 20 && regexp.MustCompile(`^[A-Za-z0-9+/]{20,}={0,2}$`).MatchString(value) {
|
||||
return true
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`).MatchString(value) {
|
||||
return true
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`^[A-Za-z0-9+/]+={0,2}$`).MatchString(value) && len(value) > 10 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func IsRecordNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "record not found") ||
|
||||
strings.Contains(strings.ToLower(err.Error()), "not found")
|
||||
}
|
||||
|
||||
func CreateSecureLogger(productionMode bool) logger.Interface {
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
if productionMode {
|
||||
config.LogLevel = logger.Error
|
||||
config.SlowThreshold = 2 * time.Second
|
||||
}
|
||||
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
return NewSecureLogger(writer, config, productionMode)
|
||||
}
|
||||
368
internal/database/secure_logger_test.go
Normal file
368
internal/database/secure_logger_test.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func TestSecureLogger_MaskSensitiveData(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
production bool
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "development_mode_no_masking",
|
||||
production: false,
|
||||
input: "SELECT * FROM users WHERE email = 'user@example.com'",
|
||||
expected: "SELECT * FROM users WHERE email = 'user@example.com'",
|
||||
},
|
||||
{
|
||||
name: "production_mode_mask_email",
|
||||
production: true,
|
||||
input: "SELECT * FROM users WHERE email = 'user@example.com'",
|
||||
expected: "SELECT * FROM users WHERE email = '[EMAIL_MASKED]'",
|
||||
},
|
||||
{
|
||||
name: "production_mode_mask_token",
|
||||
production: true,
|
||||
input: "SELECT * FROM users WHERE password_reset_token = 'abc123def456ghi789'",
|
||||
expected: "SELECT * FROM users WHERE password_reset_token = '[TOKEN_MASKED]'",
|
||||
},
|
||||
{
|
||||
name: "production_mode_mask_uuid",
|
||||
production: true,
|
||||
input: "SELECT * FROM users WHERE id = '550e8400-e29b-41d4-a716-446655440000'",
|
||||
expected: "SELECT * FROM users WHERE id = '[TOKEN_MASKED]'",
|
||||
},
|
||||
{
|
||||
name: "production_mode_no_masking_short_values",
|
||||
production: true,
|
||||
input: "SELECT * FROM users WHERE id = 123",
|
||||
expected: "SELECT * FROM users WHERE id = 123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
secureLogger := NewSecureLogger(writer, config, tt.production)
|
||||
result := secureLogger.maskSensitiveData(tt.input)
|
||||
|
||||
if tt.production {
|
||||
if strings.Contains(result, "user@example.com") {
|
||||
t.Errorf("Email should be masked in production mode")
|
||||
}
|
||||
if strings.Contains(result, "abc123def456ghi789") {
|
||||
t.Errorf("Token should be masked in production mode")
|
||||
}
|
||||
} else {
|
||||
if result != tt.input {
|
||||
t.Errorf("Expected %q, got %q", tt.input, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureLogger_IsSensitiveValue(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
secureLogger := NewSecureLogger(writer, config, true)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "email_address",
|
||||
value: "user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "long_token",
|
||||
value: "abc123def456ghi789jkl012mno345pqr678stu901vwx234yz",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "uuid",
|
||||
value: "550e8400-e29b-41d4-a716-446655440000",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "short_value",
|
||||
value: "123",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "normal_text",
|
||||
value: "golang programming",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "base64_like",
|
||||
value: "SGVsbG8gV29ybGQ=",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := secureLogger.isSensitiveValue(tt.value)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v for value %q, got %v", tt.expected, tt.value, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureLogger_LogLevels(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
secureLogger := NewSecureLogger(writer, config, false)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
secureLogger.Info(ctx, "Test info message")
|
||||
secureLogger.Warn(ctx, "Test warn message")
|
||||
secureLogger.Error(ctx, "Test error message")
|
||||
}
|
||||
|
||||
func TestCreateSecureLogger(t *testing.T) {
|
||||
prodLogger := CreateSecureLogger(true)
|
||||
if prodLogger == nil {
|
||||
t.Error("Expected non-nil logger for production mode")
|
||||
}
|
||||
|
||||
devLogger := CreateSecureLogger(false)
|
||||
if devLogger == nil {
|
||||
t.Error("Expected non-nil logger for development mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureLogger_LogMode(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
secureLogger := NewSecureLogger(writer, config, false)
|
||||
|
||||
newLogger := secureLogger.LogMode(logger.Error)
|
||||
if newLogger == nil {
|
||||
t.Error("Expected non-nil logger from LogMode")
|
||||
}
|
||||
|
||||
if secureLogger.config.LogLevel != logger.Info {
|
||||
t.Error("Original logger should be unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureLogger_Trace(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
secureLogger := NewSecureLogger(writer, config, false)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("silent_level", func(t *testing.T) {
|
||||
silentLogger := secureLogger.LogMode(logger.Silent)
|
||||
silentLogger.Trace(ctx, time.Now(), func() (string, int64) {
|
||||
return "SELECT * FROM users", 1
|
||||
}, nil)
|
||||
})
|
||||
|
||||
t.Run("error_level_with_error", func(t *testing.T) {
|
||||
errorLogger := secureLogger.LogMode(logger.Error)
|
||||
errorLogger.Trace(ctx, time.Now(), func() (string, int64) {
|
||||
return "SELECT * FROM users", 1
|
||||
}, errors.New("test error"))
|
||||
})
|
||||
|
||||
t.Run("warn_level_slow_query", func(t *testing.T) {
|
||||
warnLogger := secureLogger.LogMode(logger.Warn)
|
||||
|
||||
startTime := time.Now().Add(-2 * time.Second)
|
||||
warnLogger.Trace(ctx, startTime, func() (string, int64) {
|
||||
return "SELECT * FROM users", 1
|
||||
}, nil)
|
||||
})
|
||||
|
||||
t.Run("info_level", func(t *testing.T) {
|
||||
infoLogger := secureLogger.LogMode(logger.Info)
|
||||
infoLogger.Trace(ctx, time.Now(), func() (string, int64) {
|
||||
return "SELECT * FROM users", 1
|
||||
}, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSecureLogger_MaskSQLValues(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
secureLogger := NewSecureLogger(writer, config, true)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "email_in_sql",
|
||||
sql: "SELECT * FROM users WHERE email = 'user@example.com'",
|
||||
expected: "SELECT * FROM users WHERE email = '[MASKED]'",
|
||||
},
|
||||
{
|
||||
name: "token_in_sql",
|
||||
sql: "SELECT * FROM users WHERE token = 'abc123def456ghi789'",
|
||||
expected: "SELECT * FROM users WHERE token = '[MASKED]'",
|
||||
},
|
||||
{
|
||||
name: "uuid_in_sql",
|
||||
sql: "SELECT * FROM users WHERE id = '550e8400-e29b-41d4-a716-446655440000'",
|
||||
expected: "SELECT * FROM users WHERE id = '[MASKED]'",
|
||||
},
|
||||
{
|
||||
name: "normal_value",
|
||||
sql: "SELECT * FROM users WHERE id = 123",
|
||||
expected: "SELECT * FROM users WHERE id = 123",
|
||||
},
|
||||
{
|
||||
name: "multiple_values",
|
||||
sql: "SELECT * FROM users WHERE email = 'user@example.com' AND id = 123",
|
||||
expected: "SELECT * FROM users WHERE email = '[MASKED]' AND id = 123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := secureLogger.maskSQLValues(tt.sql)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureLogger_IsRecordNotFoundError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "record_not_found",
|
||||
err: errors.New("record not found"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
err: errors.New("not found"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "RECORD NOT FOUND",
|
||||
err: errors.New("RECORD NOT FOUND"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "NOT FOUND",
|
||||
err: errors.New("NOT FOUND"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other_error",
|
||||
err: errors.New("connection failed"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsRecordNotFoundError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v for error '%v', got %v", tt.expected, tt.err, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureLogger_ProductionMode(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Error,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
secureLogger := NewSecureLogger(writer, config, true)
|
||||
ctx := context.Background()
|
||||
|
||||
secureLogger.Info(ctx, "User login: %s", "user@example.com")
|
||||
secureLogger.Warn(ctx, "Token validation: %s", "abc123def456ghi789")
|
||||
secureLogger.Error(ctx, "Database error: %s", "connection failed")
|
||||
}
|
||||
|
||||
func TestSecureLogger_DevelopmentMode(t *testing.T) {
|
||||
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
|
||||
config := logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: false,
|
||||
}
|
||||
|
||||
secureLogger := NewSecureLogger(writer, config, false)
|
||||
ctx := context.Background()
|
||||
|
||||
secureLogger.Info(ctx, "User login: %s", "user@example.com")
|
||||
secureLogger.Warn(ctx, "Token validation: %s", "abc123def456ghi789")
|
||||
secureLogger.Error(ctx, "Database error: %s", "connection failed")
|
||||
}
|
||||
Reference in New Issue
Block a user