To gitea and beyond, let's go(-yco)

This commit is contained in:
2025-11-10 19:12:09 +01:00
parent 8f6133392d
commit 71a031342b
245 changed files with 83994 additions and 0 deletions

View File

@@ -0,0 +1,77 @@
package database
import (
"fmt"
"goyco/internal/config"
"goyco/internal/middleware"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
func connectDB(cfg *config.Config) (*gorm.DB, error) {
dsn := cfg.GetConnectionString()
gormLogger := CreateSecureLogger(!cfg.App.Debug)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: gormLogger,
})
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
return db, nil
}
func Connect(cfg *config.Config) (*gorm.DB, error) {
return connectDB(cfg)
}
func ConnectWithMonitoring(cfg *config.Config, monitor middleware.DBMonitor) (*gorm.DB, error) {
db, err := connectDB(cfg)
if err != nil {
return nil, err
}
if monitor != nil {
monitoringPlugin := NewGormDBMonitor(monitor)
if err := db.Use(monitoringPlugin); err != nil {
return nil, fmt.Errorf("failed to add monitoring plugin: %w", err)
}
}
return db, nil
}
func Migrate(db *gorm.DB) error {
if db == nil {
return fmt.Errorf("database connection is nil")
}
err := db.AutoMigrate(
&User{},
&Post{},
&Vote{},
&AccountDeletionRequest{},
&RefreshToken{},
)
if err != nil {
return fmt.Errorf("failed to migrate database: %w", err)
}
return nil
}
func Close(db *gorm.DB) error {
if db == nil {
return nil
}
sqlDB, err := db.DB()
if err != nil {
return fmt.Errorf("failed to get underlying sql.DB: %w", err)
}
return sqlDB.Close()
}

View File

@@ -0,0 +1,169 @@
package database
import (
"context"
"database/sql"
"fmt"
"log"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"goyco/internal/config"
)
type ConnectionPoolConfig struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
ConnTimeout time.Duration
HealthCheckInterval time.Duration
}
func DefaultConnectionPoolConfig() ConnectionPoolConfig {
return ConnectionPoolConfig{
MaxOpenConns: 25,
MaxIdleConns: 10,
ConnMaxLifetime: 5 * time.Minute,
ConnMaxIdleTime: 1 * time.Minute,
ConnTimeout: 30 * time.Second,
HealthCheckInterval: 30 * time.Second,
}
}
func ProductionConnectionPoolConfig() ConnectionPoolConfig {
return ConnectionPoolConfig{
MaxOpenConns: 100,
MaxIdleConns: 25,
ConnMaxLifetime: 10 * time.Minute,
ConnMaxIdleTime: 2 * time.Minute,
ConnTimeout: 10 * time.Second,
HealthCheckInterval: 15 * time.Second,
}
}
func HighTrafficConnectionPoolConfig() ConnectionPoolConfig {
return ConnectionPoolConfig{
MaxOpenConns: 200,
MaxIdleConns: 50,
ConnMaxLifetime: 15 * time.Minute,
ConnMaxIdleTime: 5 * time.Minute,
ConnTimeout: 5 * time.Second,
HealthCheckInterval: 10 * time.Second,
}
}
type ConnectionPoolManager struct {
db *gorm.DB
sqlDB *sql.DB
config ConnectionPoolConfig
ctx context.Context
cancel context.CancelFunc
}
func NewConnectionPoolManager(cfg *config.Config, poolConfig ConnectionPoolConfig) (*ConnectionPoolManager, error) {
dsn := cfg.GetConnectionString()
secureLogger := CreateSecureLogger(!cfg.App.Debug)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: secureLogger,
})
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
}
sqlDB.SetMaxOpenConns(poolConfig.MaxOpenConns)
sqlDB.SetMaxIdleConns(poolConfig.MaxIdleConns)
sqlDB.SetConnMaxLifetime(poolConfig.ConnMaxLifetime)
sqlDB.SetConnMaxIdleTime(poolConfig.ConnMaxIdleTime)
ctx, cancel := context.WithTimeout(context.Background(), poolConfig.ConnTimeout)
if err := sqlDB.PingContext(ctx); err != nil {
cancel()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
cancel()
managerCtx, managerCancel := context.WithCancel(context.Background())
manager := &ConnectionPoolManager{
db: db,
sqlDB: sqlDB,
config: poolConfig,
ctx: managerCtx,
cancel: managerCancel,
}
go manager.startHealthCheck()
return manager, nil
}
func (m *ConnectionPoolManager) GetDB() *gorm.DB {
return m.db
}
func (m *ConnectionPoolManager) GetSQLDB() *sql.DB {
return m.sqlDB
}
func (m *ConnectionPoolManager) GetPoolStats() sql.DBStats {
return m.sqlDB.Stats()
}
func (m *ConnectionPoolManager) startHealthCheck() {
ticker := time.NewTicker(m.config.HealthCheckInterval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.performHealthCheck()
}
}
}
func (m *ConnectionPoolManager) performHealthCheck() {
ctx, cancel := context.WithTimeout(m.ctx, m.config.ConnTimeout)
defer cancel()
if err := m.sqlDB.PingContext(ctx); err != nil {
log.Printf("Database health check failed: %v", err)
}
}
func (m *ConnectionPoolManager) Close() error {
if m.cancel != nil {
m.cancel()
}
if m.sqlDB != nil {
return m.sqlDB.Close()
}
return nil
}
func ConnectWithPool(cfg *config.Config) (*ConnectionPoolManager, error) {
var poolConfig ConnectionPoolConfig
if cfg.App.Debug {
poolConfig = DefaultConnectionPoolConfig()
} else {
poolConfig = ProductionConnectionPoolConfig()
}
if cfg.App.BaseURL != "" && !cfg.App.Debug {
poolConfig = HighTrafficConnectionPoolConfig()
}
return NewConnectionPoolManager(cfg, poolConfig)
}

View File

@@ -0,0 +1,253 @@
package database
import (
"strings"
"testing"
"time"
"goyco/internal/config"
)
func TestConnectionPoolConfig(t *testing.T) {
t.Run("default_config", func(t *testing.T) {
config := DefaultConnectionPoolConfig()
if config.MaxOpenConns <= 0 {
t.Error("MaxOpenConns should be positive")
}
if config.MaxIdleConns <= 0 {
t.Error("MaxIdleConns should be positive")
}
if config.ConnMaxLifetime <= 0 {
t.Error("ConnMaxLifetime should be positive")
}
if config.ConnMaxIdleTime <= 0 {
t.Error("ConnMaxIdleTime should be positive")
}
if config.ConnTimeout <= 0 {
t.Error("ConnTimeout should be positive")
}
if config.HealthCheckInterval <= 0 {
t.Error("HealthCheckInterval should be positive")
}
})
t.Run("production_config", func(t *testing.T) {
config := ProductionConnectionPoolConfig()
if config.MaxOpenConns < 50 {
t.Error("Production MaxOpenConns should be higher")
}
if config.MaxIdleConns < 10 {
t.Error("Production MaxIdleConns should be higher")
}
})
t.Run("high_traffic_config", func(t *testing.T) {
config := HighTrafficConnectionPoolConfig()
if config.MaxOpenConns < 100 {
t.Error("High traffic MaxOpenConns should be very high")
}
if config.MaxIdleConns < 25 {
t.Error("High traffic MaxIdleConns should be high")
}
})
}
func TestConnectionPoolManager_Stats(t *testing.T) {
t.Run("config_validation", func(t *testing.T) {
config := DefaultConnectionPoolConfig()
if config.MaxOpenConns < config.MaxIdleConns {
t.Error("MaxOpenConns should be >= MaxIdleConns")
}
if config.ConnMaxLifetime < config.ConnMaxIdleTime {
t.Error("ConnMaxLifetime should be >= ConnMaxIdleTime")
}
if config.ConnTimeout > 60*time.Second {
t.Error("ConnTimeout should be reasonable")
}
})
}
func TestConnectionPoolConfig_Values(t *testing.T) {
tests := []struct {
name string
config ConnectionPoolConfig
}{
{
name: "default",
config: DefaultConnectionPoolConfig(),
},
{
name: "production",
config: ProductionConnectionPoolConfig(),
},
{
name: "high_traffic",
config: HighTrafficConnectionPoolConfig(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := tt.config
if config.MaxOpenConns <= 0 {
t.Errorf("MaxOpenConns should be positive, got %d", config.MaxOpenConns)
}
if config.MaxIdleConns <= 0 {
t.Errorf("MaxIdleConns should be positive, got %d", config.MaxIdleConns)
}
if config.ConnMaxLifetime <= 0 {
t.Errorf("ConnMaxLifetime should be positive, got %v", config.ConnMaxLifetime)
}
if config.ConnMaxIdleTime <= 0 {
t.Errorf("ConnMaxIdleTime should be positive, got %v", config.ConnMaxIdleTime)
}
if config.ConnTimeout <= 0 {
t.Errorf("ConnTimeout should be positive, got %v", config.ConnTimeout)
}
if config.HealthCheckInterval <= 0 {
t.Errorf("HealthCheckInterval should be positive, got %v", config.HealthCheckInterval)
}
if config.MaxOpenConns < config.MaxIdleConns {
t.Errorf("MaxOpenConns (%d) should be >= MaxIdleConns (%d)", config.MaxOpenConns, config.MaxIdleConns)
}
if config.ConnMaxLifetime < config.ConnMaxIdleTime {
t.Errorf("ConnMaxLifetime (%v) should be >= ConnMaxIdleTime (%v)", config.ConnMaxLifetime, config.ConnMaxIdleTime)
}
})
}
}
func TestNewConnectionPoolManager(t *testing.T) {
t.Run("invalid_database_config", func(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "invalid-host",
Port: "9999",
User: "invalid",
Password: "invalid",
Name: "invalid",
SSLMode: "disable",
},
App: config.AppConfig{
Debug: true,
},
}
poolConfig := DefaultConnectionPoolConfig()
manager, err := NewConnectionPoolManager(cfg, poolConfig)
if err == nil {
t.Error("Expected error with invalid database config")
}
if manager != nil {
t.Error("Expected nil manager with invalid database config")
}
if !strings.Contains(err.Error(), "failed to connect to database") {
t.Errorf("Expected connection error, got: %v", err)
}
})
}
func TestConnectionPoolManager_Methods(t *testing.T) {
t.Run("get_db_methods", func(t *testing.T) {
manager := &ConnectionPoolManager{
db: nil,
sqlDB: nil,
}
if manager.GetDB() != nil {
t.Error("Expected nil DB from uninitialized manager")
}
if manager.GetSQLDB() != nil {
t.Error("Expected nil SQLDB from uninitialized manager")
}
})
}
func TestConnectWithPool(t *testing.T) {
t.Run("debug_mode_config", func(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: "5432",
User: "test",
Password: "test",
Name: "test_db",
SSLMode: "disable",
},
App: config.AppConfig{
Debug: true,
},
}
manager, err := ConnectWithPool(cfg)
if err == nil {
t.Error("Expected error with invalid database config")
}
if manager != nil {
t.Error("Expected nil manager with invalid database config")
}
})
t.Run("production_mode_config", func(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: "5432",
User: "test",
Password: "test",
Name: "test_db",
SSLMode: "disable",
},
App: config.AppConfig{
Debug: false,
},
}
manager, err := ConnectWithPool(cfg)
if err == nil {
t.Error("Expected error with invalid database config")
}
if manager != nil {
t.Error("Expected nil manager with invalid database config")
}
})
t.Run("high_traffic_config", func(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: "5432",
User: "test",
Password: "test",
Name: "test_db",
SSLMode: "disable",
},
App: config.AppConfig{
Debug: false,
BaseURL: "https://example.com",
},
}
manager, err := ConnectWithPool(cfg)
if err == nil {
t.Error("Expected error with invalid database config")
}
if manager != nil {
t.Error("Expected nil manager with invalid database config")
}
})
}

View File

@@ -0,0 +1,156 @@
package database
import (
"context"
"strings"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"goyco/internal/config"
"goyco/internal/middleware"
)
func TestConnectReturnsErrorWhenUnableToReachDatabase(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "127.0.0.1",
Port: "1",
User: "postgres",
Password: "password",
Name: "goyco_test",
SSLMode: "disable",
},
}
_, err := Connect(cfg)
done <- err
}()
select {
case err := <-done:
if err == nil {
t.Fatalf("expected connection error but got nil")
}
if !strings.Contains(err.Error(), "failed to connect to database") {
t.Fatalf("unexpected error: %v", err)
}
case <-ctx.Done():
t.Fatalf("connection test timed out after 5 seconds")
}
}
func TestMigrateFailsWhenDBNil(t *testing.T) {
err := Migrate(nil)
if err == nil {
t.Fatalf("expected error when DB is nil")
}
}
func TestMigrateCreatesTables(t *testing.T) {
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=private"
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("failed to open sqlite in-memory database: %v", err)
}
if err := Migrate(db); err != nil {
t.Fatalf("expected migrations to succeed, got error: %v", err)
}
migrator := db.Migrator()
models := []any{&User{}, &Post{}, &Vote{}}
for _, model := range models {
if !migrator.HasTable(model) {
t.Fatalf("expected table for %T to exist after migration", model)
}
}
}
func TestCloseReturnsNilWhenDBNil(t *testing.T) {
if err := Close(nil); err != nil {
t.Fatalf("expected nil error when DB is nil, got %v", err)
}
}
func TestCloseClosesUnderlyingConnection(t *testing.T) {
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=private"
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("failed to open sqlite in-memory database: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("failed to get sql.DB: %v", err)
}
if err := Close(db); err != nil {
t.Fatalf("expected close to succeed, got %v", err)
}
if err := sqlDB.Ping(); err == nil {
t.Fatalf("expected ping on closed connection to fail")
}
}
func TestConnectWithMonitoring(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: "5432",
User: "test",
Password: "test",
Name: "test_db",
SSLMode: "disable",
},
App: config.AppConfig{
Debug: true,
},
}
_, err := ConnectWithMonitoring(cfg, nil)
if err == nil {
t.Fatalf("expected connection error with invalid database config")
}
if !strings.Contains(err.Error(), "failed to connect to database") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestConnectWithMonitoringWithValidMonitor(t *testing.T) {
mockMonitor := middleware.NewInMemoryDBMonitor()
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: "5432",
User: "test",
Password: "test",
Name: "test_db",
SSLMode: "disable",
},
App: config.AppConfig{
Debug: true,
},
}
_, err := ConnectWithMonitoring(cfg, mockMonitor)
if err == nil {
t.Fatalf("expected connection error with invalid database config")
}
if !strings.Contains(err.Error(), "failed to connect to database") {
t.Fatalf("unexpected error: %v", err)
}
}

View File

@@ -0,0 +1,88 @@
package database
import (
"time"
"gorm.io/gorm"
)
type Post struct {
ID uint `gorm:"primaryKey"`
Title string `gorm:"not null"`
URL string `gorm:"uniqueIndex"`
Content string
AuthorID *uint
AuthorName string
Author User `gorm:"foreignKey:AuthorID;constraint:OnDelete:CASCADE"`
UpVotes int `gorm:"default:0"`
DownVotes int `gorm:"default:0"`
Score int `gorm:"default:0"`
Votes []Vote `gorm:"foreignKey:PostID;constraint:OnDelete:CASCADE"`
CurrentVote VoteType `gorm:"-"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
}
type User struct {
ID uint `gorm:"primaryKey"`
Username string `gorm:"uniqueIndex;not null"`
Email string `gorm:"uniqueIndex;not null"`
Password string `gorm:"not null"`
EmailVerified bool `gorm:"default:false;not null"`
EmailVerifiedAt *time.Time
EmailVerificationToken string `gorm:"index"`
EmailVerificationSentAt *time.Time
PasswordResetToken string `gorm:"index"`
PasswordResetSentAt *time.Time
PasswordResetExpiresAt *time.Time
Locked bool `gorm:"default:false"`
SessionVersion uint `gorm:"default:1;not null"`
RefreshTokens []RefreshToken `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
Posts []Post `gorm:"foreignKey:AuthorID"`
Votes []Vote `gorm:"foreignKey:UserID"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
}
type RefreshToken struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"not null;index"`
User User `gorm:"constraint:OnDelete:CASCADE"`
TokenHash string `gorm:"uniqueIndex;not null"`
ExpiresAt time.Time `gorm:"not null;index"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
}
type AccountDeletionRequest struct {
ID uint `gorm:"primaryKey"`
UserID uint `gorm:"uniqueIndex"`
User User `gorm:"constraint:OnDelete:CASCADE"`
TokenHash string `gorm:"uniqueIndex;not null"`
ExpiresAt time.Time `gorm:"not null"`
CreatedAt time.Time
}
type Vote struct {
ID uint `gorm:"primaryKey"`
UserID *uint `gorm:"uniqueIndex:idx_user_post_vote,where:deleted_at IS NULL AND user_id IS NOT NULL"`
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
PostID uint `gorm:"not null;uniqueIndex:idx_user_post_vote,where:deleted_at IS NULL AND user_id IS NOT NULL;uniqueIndex:idx_hash_post_vote,where:deleted_at IS NULL AND vote_hash IS NOT NULL"`
Post Post `gorm:"foreignKey:PostID;constraint:OnDelete:CASCADE"`
Type VoteType `gorm:"not null"`
VoteHash *string `gorm:"uniqueIndex:idx_hash_post_vote,where:deleted_at IS NULL AND vote_hash IS NOT NULL"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
}
type VoteType string
const (
VoteUp VoteType = "up"
VoteDown VoteType = "down"
VoteNone VoteType = "none"
)

View File

@@ -0,0 +1,603 @@
package database
import (
"fmt"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func newTestDB(t *testing.T) *gorm.DB {
t.Helper()
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to connect to test database: %v", err)
}
err = db.AutoMigrate(
&User{},
&Post{},
&Vote{},
&AccountDeletionRequest{},
&RefreshToken{},
)
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
if execErr := db.Exec("PRAGMA busy_timeout = 5000").Error; execErr != nil {
t.Fatalf("Failed to configure busy timeout: %v", execErr)
}
if execErr := db.Exec("PRAGMA foreign_keys = ON").Error; execErr != nil {
t.Fatalf("Failed to enable foreign keys: %v", execErr)
}
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("Failed to access SQL DB: %v", err)
}
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(1)
sqlDB.SetConnMaxLifetime(5 * time.Minute)
return db
}
func createTestUser(t *testing.T, db *gorm.DB) *User {
t.Helper()
uniqueID := time.Now().UnixNano()
user := &User{
Username: fmt.Sprintf("testuser%d", uniqueID),
Email: fmt.Sprintf("test%d@example.com", uniqueID),
Password: "hashedpassword123",
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
return user
}
func createTestPost(t *testing.T, db *gorm.DB, authorID uint) *Post {
t.Helper()
post := &Post{
Title: "Test Post " + t.Name(),
URL: "https://example.com/test" + t.Name(),
Content: "Test content",
AuthorID: &authorID,
}
if err := db.Create(post).Error; err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
return post
}
func TestUser_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
t.Run("create_user", func(t *testing.T) {
user := &User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
if user.ID == 0 {
t.Error("Expected user ID to be set")
}
if user.CreatedAt.IsZero() {
t.Error("Expected CreatedAt to be set")
}
if user.UpdatedAt.IsZero() {
t.Error("Expected UpdatedAt to be set")
}
})
t.Run("user_constraints", func(t *testing.T) {
user1 := &User{
Username: "duplicate",
Email: "user1@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
user2 := &User{
Username: "duplicate",
Email: "user2@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
if err := db.Create(user1).Error; err != nil {
t.Fatalf("Failed to create first user: %v", err)
}
if err := db.Create(user2).Error; err == nil {
t.Error("Expected error when creating user with duplicate username")
}
user3 := &User{
Username: "unique",
Email: "user1@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
if err := db.Create(user3).Error; err == nil {
t.Error("Expected error when creating user with duplicate email")
}
})
t.Run("user_relationships", func(t *testing.T) {
user := &User{
Username: "author",
Email: "author@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
post1 := &Post{
Title: "Post 1",
URL: "https://example.com/1",
Content: "Content 1",
AuthorID: &user.ID,
}
post2 := &Post{
Title: "Post 2",
URL: "https://example.com/2",
Content: "Content 2",
AuthorID: &user.ID,
}
if err := db.Create(post1).Error; err != nil {
t.Fatalf("Failed to create post 1: %v", err)
}
if err := db.Create(post2).Error; err != nil {
t.Fatalf("Failed to create post 2: %v", err)
}
var foundUser User
if err := db.Preload("Posts").First(&foundUser, user.ID).Error; err != nil {
t.Fatalf("Failed to load user with posts: %v", err)
}
if len(foundUser.Posts) != 2 {
t.Errorf("Expected 2 posts, got %d", len(foundUser.Posts))
}
})
}
func TestPost_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
t.Run("create_post", func(t *testing.T) {
user := createTestUser(t, db)
post := &Post{
Title: "Test Post",
URL: "https://example.com/test",
Content: "Test content",
AuthorID: &user.ID,
}
if err := db.Create(post).Error; err != nil {
t.Fatalf("Failed to create post: %v", err)
}
if post.ID == 0 {
t.Error("Expected post ID to be set")
}
if post.CreatedAt.IsZero() {
t.Error("Expected CreatedAt to be set")
}
if post.UpdatedAt.IsZero() {
t.Error("Expected UpdatedAt to be set")
}
if post.UpVotes != 0 {
t.Error("Expected UpVotes to be 0 by default")
}
if post.DownVotes != 0 {
t.Error("Expected DownVotes to be 0 by default")
}
if post.Score != 0 {
t.Error("Expected Score to be 0 by default")
}
})
t.Run("post_constraints", func(t *testing.T) {
user := createTestUser(t, db)
post1 := &Post{
Title: "Post 1",
URL: "https://example.com/unique",
Content: "Content 1",
AuthorID: &user.ID,
}
post2 := &Post{
Title: "Post 2",
URL: "https://example.com/unique",
Content: "Content 2",
AuthorID: &user.ID,
}
if err := db.Create(post1).Error; err != nil {
t.Fatalf("Failed to create first post: %v", err)
}
if err := db.Create(post2).Error; err == nil {
t.Error("Expected error when creating post with duplicate URL")
}
})
t.Run("post_relationships", func(t *testing.T) {
user1 := createTestUser(t, db)
user2 := createTestUser(t, db)
post := createTestPost(t, db, user1.ID)
vote1 := &Vote{
UserID: &user1.ID,
PostID: post.ID,
Type: VoteUp,
}
vote2 := &Vote{
UserID: &user2.ID,
PostID: post.ID,
Type: VoteDown,
}
if err := db.Create(vote1).Error; err != nil {
t.Fatalf("Failed to create vote 1: %v", err)
}
if err := db.Create(vote2).Error; err != nil {
t.Fatalf("Failed to create vote 2: %v", err)
}
var foundPost Post
if err := db.Preload("Votes").First(&foundPost, post.ID).Error; err != nil {
t.Fatalf("Failed to load post with votes: %v", err)
}
if len(foundPost.Votes) != 2 {
t.Errorf("Expected 2 votes, got %d", len(foundPost.Votes))
}
})
}
func TestVote_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
t.Run("create_vote", func(t *testing.T) {
user := createTestUser(t, db)
post := createTestPost(t, db, user.ID)
vote := &Vote{
UserID: &user.ID,
PostID: post.ID,
Type: VoteUp,
}
if err := db.Create(vote).Error; err != nil {
t.Fatalf("Failed to create vote: %v", err)
}
if vote.ID == 0 {
t.Error("Expected vote ID to be set")
}
if vote.CreatedAt.IsZero() {
t.Error("Expected CreatedAt to be set")
}
if vote.UpdatedAt.IsZero() {
t.Error("Expected UpdatedAt to be set")
}
})
t.Run("vote_constraints", func(t *testing.T) {
user := createTestUser(t, db)
post := createTestPost(t, db, user.ID)
vote1 := &Vote{
UserID: &user.ID,
PostID: post.ID,
Type: VoteUp,
}
vote2 := &Vote{
UserID: &user.ID,
PostID: post.ID,
Type: VoteDown,
}
if err := db.Create(vote1).Error; err != nil {
t.Fatalf("Failed to create first vote: %v", err)
}
if err := db.Create(vote2).Error; err == nil {
t.Error("Expected error when creating vote with duplicate user-post combination")
}
})
t.Run("vote_types", func(t *testing.T) {
user := createTestUser(t, db)
voteTypes := []VoteType{VoteUp, VoteDown, VoteNone}
for i, voteType := range voteTypes {
post := &Post{
Title: "Test Post " + string(rune(i)),
URL: "https://example.com/test" + string(rune(i)),
Content: "Test content",
AuthorID: &user.ID,
}
if err := db.Create(post).Error; err != nil {
t.Fatalf("Failed to create post %d: %v", i, err)
}
vote := &Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := db.Create(vote).Error; err != nil {
t.Fatalf("Failed to create vote with type %s: %v", voteType, err)
}
}
})
t.Run("vote_relationships", func(t *testing.T) {
user := createTestUser(t, db)
post := createTestPost(t, db, user.ID)
vote := &Vote{
UserID: &user.ID,
PostID: post.ID,
Type: VoteUp,
}
if err := db.Create(vote).Error; err != nil {
t.Fatalf("Failed to create vote: %v", err)
}
var foundVote Vote
if err := db.Preload("User").Preload("Post").First(&foundVote, vote.ID).Error; err != nil {
t.Fatalf("Failed to load vote with relationships: %v", err)
}
if foundVote.User.ID != user.ID {
t.Error("Expected vote to be associated with correct user")
}
if foundVote.Post.ID != post.ID {
t.Error("Expected vote to be associated with correct post")
}
})
}
func TestRefreshToken_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
t.Run("create_refresh_token", func(t *testing.T) {
user := createTestUser(t, db)
token := &RefreshToken{
UserID: user.ID,
TokenHash: "hashedtoken123",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
if err := db.Create(token).Error; err != nil {
t.Fatalf("Failed to create refresh token: %v", err)
}
if token.ID == 0 {
t.Error("Expected token ID to be set")
}
if token.CreatedAt.IsZero() {
t.Error("Expected CreatedAt to be set")
}
if token.UpdatedAt.IsZero() {
t.Error("Expected UpdatedAt to be set")
}
})
t.Run("refresh_token_constraints", func(t *testing.T) {
user := createTestUser(t, db)
token1 := &RefreshToken{
UserID: user.ID,
TokenHash: "uniquehash",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
token2 := &RefreshToken{
UserID: user.ID,
TokenHash: "uniquehash",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
if err := db.Create(token1).Error; err != nil {
t.Fatalf("Failed to create first token: %v", err)
}
if err := db.Create(token2).Error; err == nil {
t.Error("Expected error when creating token with duplicate hash")
}
})
}
func TestAccountDeletionRequest_Model(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
t.Run("create_account_deletion_request", func(t *testing.T) {
user := createTestUser(t, db)
request := &AccountDeletionRequest{
UserID: user.ID,
TokenHash: "deletiontoken123",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
if err := db.Create(request).Error; err != nil {
t.Fatalf("Failed to create account deletion request: %v", err)
}
if request.ID == 0 {
t.Error("Expected request ID to be set")
}
if request.CreatedAt.IsZero() {
t.Error("Expected CreatedAt to be set")
}
})
t.Run("account_deletion_request_constraints", func(t *testing.T) {
user := createTestUser(t, db)
request1 := &AccountDeletionRequest{
UserID: user.ID,
TokenHash: "token1",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
request2 := &AccountDeletionRequest{
UserID: user.ID,
TokenHash: "token2",
ExpiresAt: time.Now().Add(24 * time.Hour),
}
if err := db.Create(request1).Error; err != nil {
t.Fatalf("Failed to create first request: %v", err)
}
if err := db.Create(request2).Error; err == nil {
t.Error("Expected error when creating request with duplicate user")
}
})
}
func TestVoteType_Constants(t *testing.T) {
t.Run("vote_type_constants", func(t *testing.T) {
if VoteUp != "up" {
t.Errorf("Expected VoteUp to be 'up', got '%s'", VoteUp)
}
if VoteDown != "down" {
t.Errorf("Expected VoteDown to be 'down', got '%s'", VoteDown)
}
if VoteNone != "none" {
t.Errorf("Expected VoteNone to be 'none', got '%s'", VoteNone)
}
})
}
func TestModel_SoftDelete(t *testing.T) {
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
t.Run("user_soft_delete", func(t *testing.T) {
user := createTestUser(t, db)
if err := db.Delete(user).Error; err != nil {
t.Fatalf("Failed to soft delete user: %v", err)
}
var foundUser User
if err := db.First(&foundUser, user.ID).Error; err == nil {
t.Error("Expected user to be soft deleted")
}
if err := db.Unscoped().First(&foundUser, user.ID).Error; err != nil {
t.Fatalf("Expected to find soft deleted user with Unscoped: %v", err)
}
if foundUser.DeletedAt.Time.IsZero() {
t.Error("Expected DeletedAt to be set")
}
})
t.Run("post_soft_delete", func(t *testing.T) {
user := createTestUser(t, db)
post := createTestPost(t, db, user.ID)
if err := db.Delete(post).Error; err != nil {
t.Fatalf("Failed to soft delete post: %v", err)
}
var foundPost Post
if err := db.First(&foundPost, post.ID).Error; err == nil {
t.Error("Expected post to be soft deleted")
}
if err := db.Unscoped().First(&foundPost, post.ID).Error; err != nil {
t.Fatalf("Expected to find soft deleted post with Unscoped: %v", err)
}
if foundPost.DeletedAt.Time.IsZero() {
t.Error("Expected DeletedAt to be set")
}
})
}

View File

@@ -0,0 +1,190 @@
package database
import (
"context"
"time"
"gorm.io/gorm"
"goyco/internal/middleware"
)
type contextKey string
const gormOperationStartKey contextKey = "gorm_operation_start"
type GormDBMonitor struct {
monitor middleware.DBMonitor
}
func NewGormDBMonitor(monitor middleware.DBMonitor) *GormDBMonitor {
return &GormDBMonitor{
monitor: monitor,
}
}
func (g *GormDBMonitor) Name() string {
return "db_monitor"
}
func (g *GormDBMonitor) Initialize(db *gorm.DB) error {
db.Callback().Create().Before("gorm:create").Register("db_monitor:before_create", g.beforeCreate)
db.Callback().Create().After("gorm:create").Register("db_monitor:after_create", g.afterCreate)
db.Callback().Query().Before("gorm:query").Register("db_monitor:before_query", g.beforeQuery)
db.Callback().Query().After("gorm:query").Register("db_monitor:after_query", g.afterQuery)
db.Callback().Update().Before("gorm:update").Register("db_monitor:before_update", g.beforeUpdate)
db.Callback().Update().After("gorm:update").Register("db_monitor:after_update", g.afterUpdate)
db.Callback().Delete().Before("gorm:delete").Register("db_monitor:before_delete", g.beforeDelete)
db.Callback().Delete().After("gorm:delete").Register("db_monitor:after_delete", g.afterDelete)
db.Callback().Row().Before("gorm:row").Register("db_monitor:before_row", g.beforeRow)
db.Callback().Row().After("gorm:row").Register("db_monitor:after_row", g.afterRow)
db.Callback().Raw().Before("gorm:raw").Register("db_monitor:before_raw", g.beforeRaw)
db.Callback().Raw().After("gorm:raw").Register("db_monitor:after_raw", g.afterRaw)
return nil
}
func (g *GormDBMonitor) beforeCreate(db *gorm.DB) {
if g.monitor == nil {
return
}
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
db.Statement.Context = ctx
}
func (g *GormDBMonitor) afterCreate(db *gorm.DB) {
if g.monitor == nil {
return
}
g.logOperation(db, "CREATE")
}
func (g *GormDBMonitor) beforeQuery(db *gorm.DB) {
if g.monitor == nil {
return
}
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
db.Statement.Context = ctx
}
func (g *GormDBMonitor) afterQuery(db *gorm.DB) {
if g.monitor == nil {
return
}
g.logOperation(db, "SELECT")
}
func (g *GormDBMonitor) beforeUpdate(db *gorm.DB) {
if g.monitor == nil {
return
}
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
db.Statement.Context = ctx
}
func (g *GormDBMonitor) afterUpdate(db *gorm.DB) {
if g.monitor == nil {
return
}
g.logOperation(db, "UPDATE")
}
func (g *GormDBMonitor) beforeDelete(db *gorm.DB) {
if g.monitor == nil {
return
}
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
db.Statement.Context = ctx
}
func (g *GormDBMonitor) afterDelete(db *gorm.DB) {
if g.monitor == nil {
return
}
g.logOperation(db, "DELETE")
}
func (g *GormDBMonitor) beforeRow(db *gorm.DB) {
if g.monitor == nil {
return
}
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
db.Statement.Context = ctx
}
func (g *GormDBMonitor) afterRow(db *gorm.DB) {
if g.monitor == nil {
return
}
g.logOperation(db, "ROW")
}
func (g *GormDBMonitor) beforeRaw(db *gorm.DB) {
if g.monitor == nil {
return
}
ctx := context.WithValue(db.Statement.Context, gormOperationStartKey, time.Now())
db.Statement.Context = ctx
}
func (g *GormDBMonitor) afterRaw(db *gorm.DB) {
if g.monitor == nil {
return
}
g.logOperation(db, "RAW")
}
func (g *GormDBMonitor) logOperation(db *gorm.DB, operation string) {
if g.monitor == nil {
return
}
startTime, ok := db.Statement.Context.Value(gormOperationStartKey).(time.Time)
if !ok {
return
}
duration := time.Since(startTime)
query := g.buildQueryString(db, operation)
g.monitor.LogQuery(query, duration, db.Error)
}
func (g *GormDBMonitor) buildQueryString(db *gorm.DB, operation string) string {
if db.Statement.SQL.String() != "" {
return db.Statement.SQL.String()
}
query := operation
if db.Statement.Table != "" {
query += " FROM " + db.Statement.Table
}
if db.Statement.Model != nil {
if stmt := db.Statement; stmt.Schema != nil {
query = operation + " " + stmt.Schema.Table
}
}
return query
}

View File

@@ -0,0 +1,325 @@
package database
import (
"context"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"goyco/internal/middleware"
)
func TestNewGormDBMonitor(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
if gormMonitor == nil {
t.Fatal("Expected non-nil GormDBMonitor")
}
if gormMonitor.monitor != monitor {
t.Error("Expected monitor to be set correctly")
}
}
func TestGormDBMonitor_Name(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
if gormMonitor.Name() != "db_monitor" {
t.Errorf("Expected name 'db_monitor', got '%s'", gormMonitor.Name())
}
}
func TestGormDBMonitor_Initialize(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
err := gormMonitor.Initialize(db)
if err != nil {
t.Fatalf("Expected Initialize to succeed, got error: %v", err)
}
}
func TestGormDBMonitor_InitializeWithNilMonitor(t *testing.T) {
gormMonitor := NewGormDBMonitor(nil)
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
err := gormMonitor.Initialize(db)
if err != nil {
t.Fatalf("Expected Initialize to succeed with nil monitor, got error: %v", err)
}
}
func TestGormDBMonitor_Callbacks(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
err := gormMonitor.Initialize(db)
if err != nil {
t.Fatalf("Failed to initialize plugin: %v", err)
}
user := &User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
var foundUser User
if err := db.First(&foundUser, user.ID).Error; err != nil {
t.Fatalf("Failed to find user: %v", err)
}
foundUser.Username = "updateduser"
if err := db.Save(&foundUser).Error; err != nil {
t.Fatalf("Failed to update user: %v", err)
}
if err := db.Delete(&foundUser).Error; err != nil {
t.Fatalf("Failed to delete user: %v", err)
}
}
func TestGormDBMonitor_CallbacksWithNilMonitor(t *testing.T) {
gormMonitor := NewGormDBMonitor(nil)
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
err := gormMonitor.Initialize(db)
if err != nil {
t.Fatalf("Failed to initialize plugin: %v", err)
}
user := &User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
}
func TestGormDBMonitor_BuildQueryString(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
db := newTestDB(t)
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
err := gormMonitor.Initialize(db)
if err != nil {
t.Fatalf("Failed to initialize plugin: %v", err)
}
tests := []struct {
name string
operation string
table string
expected string
}{
{
name: "create_operation",
operation: "CREATE",
table: "users",
expected: "CREATE FROM users",
},
{
name: "select_operation",
operation: "SELECT",
table: "posts",
expected: "SELECT FROM posts",
},
{
name: "update_operation",
operation: "UPDATE",
table: "votes",
expected: "UPDATE FROM votes",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stmt := &gorm.Statement{
Table: tt.table,
}
mockDB := &gorm.DB{
Statement: stmt,
}
result := gormMonitor.buildQueryString(mockDB, tt.operation)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestGormDBMonitor_LogOperation(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
startTime := time.Now()
ctx := context.WithValue(context.Background(), gormOperationStartKey, startTime)
stmt := &gorm.Statement{
Context: ctx,
Table: "users",
}
mockDB := &gorm.DB{
Statement: stmt,
}
gormMonitor.logOperation(mockDB, "CREATE")
gormMonitor.monitor = nil
gormMonitor.logOperation(mockDB, "CREATE")
}
func TestGormDBMonitor_LogOperationWithoutStartTime(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
ctx := context.Background()
stmt := &gorm.Statement{
Context: ctx,
Table: "users",
}
mockDB := &gorm.DB{
Statement: stmt,
}
gormMonitor.logOperation(mockDB, "CREATE")
}
func TestGormDBMonitor_AllCallbackMethods(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
gormMonitor.monitor = nil
ctx := context.Background()
stmt := &gorm.Statement{
Context: ctx,
Table: "users",
}
mockDB := &gorm.DB{
Statement: stmt,
}
gormMonitor.beforeCreate(mockDB)
gormMonitor.beforeQuery(mockDB)
gormMonitor.beforeUpdate(mockDB)
gormMonitor.beforeDelete(mockDB)
gormMonitor.beforeRow(mockDB)
gormMonitor.beforeRaw(mockDB)
gormMonitor.afterCreate(mockDB)
gormMonitor.afterQuery(mockDB)
gormMonitor.afterUpdate(mockDB)
gormMonitor.afterDelete(mockDB)
gormMonitor.afterRow(mockDB)
gormMonitor.afterRaw(mockDB)
}
func TestGormDBMonitor_WithRealDatabase(t *testing.T) {
monitor := middleware.NewInMemoryDBMonitor()
gormMonitor := NewGormDBMonitor(monitor)
dbName := "file:memdb_" + t.Name() + "?mode=memory&cache=private"
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer func() {
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}()
if err := db.AutoMigrate(&User{}); err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
err = gormMonitor.Initialize(db)
if err != nil {
t.Fatalf("Failed to initialize plugin: %v", err)
}
user := &User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
EmailVerified: true,
}
if err := db.Create(user).Error; err != nil {
t.Fatalf("Failed to create user: %v", err)
}
var foundUser User
if err := db.First(&foundUser, user.ID).Error; err != nil {
t.Fatalf("Failed to find user: %v", err)
}
foundUser.Username = "updateduser"
if err := db.Save(&foundUser).Error; err != nil {
t.Fatalf("Failed to update user: %v", err)
}
if err := db.Delete(&foundUser).Error; err != nil {
t.Fatalf("Failed to delete user: %v", err)
}
stats := monitor.GetStats()
if stats.TotalQueries == 0 {
t.Error("Expected monitor to have recorded some queries")
}
}

View File

@@ -0,0 +1,175 @@
package database
import (
"context"
"fmt"
"log"
"os"
"regexp"
"strings"
"time"
"gorm.io/gorm/logger"
)
type SecureLogger struct {
writer logger.Writer
config logger.Config
sensitiveFields []string
sensitivePattern *regexp.Regexp
productionMode bool
}
func NewSecureLogger(writer logger.Writer, config logger.Config, productionMode bool) *SecureLogger {
sensitiveFields := []string{
"password", "token", "secret", "key", "hash", "salt",
"email_verification_token", "password_reset_token",
"token_hash", "jwt_secret", "api_key", "access_token",
"refresh_token", "session_id", "cookie", "auth",
}
sensitivePattern := regexp.MustCompile(`(?i)(password|token|secret|key|hash|salt|email_verification_token|password_reset_token|token_hash|jwt_secret|api_key|access_token|refresh_token|session_id|cookie|auth)`)
return &SecureLogger{
writer: writer,
config: config,
sensitiveFields: sensitiveFields,
sensitivePattern: sensitivePattern,
productionMode: productionMode,
}
}
func (l *SecureLogger) LogMode(level logger.LogLevel) logger.Interface {
newLogger := *l
newLogger.config.LogLevel = level
return &newLogger
}
func (l *SecureLogger) Info(ctx context.Context, msg string, data ...any) {
if l.config.LogLevel >= logger.Info {
l.log(ctx, "info", msg, data...)
}
}
func (l *SecureLogger) Warn(ctx context.Context, msg string, data ...any) {
if l.config.LogLevel >= logger.Warn {
l.log(ctx, "warn", msg, data...)
}
}
func (l *SecureLogger) Error(ctx context.Context, msg string, data ...any) {
if l.config.LogLevel >= logger.Error {
l.log(ctx, "error", msg, data...)
}
}
func (l *SecureLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.config.LogLevel <= logger.Silent {
return
}
elapsed := time.Since(begin)
switch {
case err != nil && l.config.LogLevel >= logger.Error && (!l.config.IgnoreRecordNotFoundError || !IsRecordNotFoundError(err)):
sql, rows := fc()
l.log(ctx, "error", fmt.Sprintf("[%.3fms] [rows:%v] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql))
case elapsed > l.config.SlowThreshold && l.config.SlowThreshold != 0 && l.config.LogLevel >= logger.Warn:
sql, rows := fc()
l.log(ctx, "warn", fmt.Sprintf("[%.3fms] [rows:%v] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql))
case l.config.LogLevel == logger.Info:
sql, rows := fc()
l.log(ctx, "info", fmt.Sprintf("[%.3fms] [rows:%v] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql))
}
}
func (l *SecureLogger) log(_ context.Context, level, msg string, data ...any) {
if l.productionMode {
msg = l.maskSensitiveData(msg)
maskedData := make([]any, len(data))
for i, d := range data {
maskedData[i] = l.maskSensitiveData(fmt.Sprintf("%v", d))
}
data = maskedData
}
formattedMsg := fmt.Sprintf(msg, data...)
l.writer.Printf("[%s] %s", strings.ToUpper(level), formattedMsg)
}
func (l *SecureLogger) maskSensitiveData(data string) string {
if l.productionMode {
data = regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`).ReplaceAllString(data, "[EMAIL_MASKED]")
data = regexp.MustCompile(`\b[A-Za-z0-9]{20,}\b`).ReplaceAllStringFunc(data, func(match string) string {
if l.sensitivePattern.MatchString(match) {
return "[TOKEN_MASKED]"
}
return match
})
data = l.maskSQLValues(data)
}
return data
}
func (l *SecureLogger) maskSQLValues(sql string) string {
paramPattern := regexp.MustCompile(`'([^']*)'`)
return paramPattern.ReplaceAllStringFunc(sql, func(match string) string {
value := strings.Trim(match, "'")
if l.isSensitiveValue(value) {
return "'[MASKED]'"
}
return match
})
}
func (l *SecureLogger) isSensitiveValue(value string) bool {
if regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`).MatchString(value) {
return true
}
if len(value) > 20 && regexp.MustCompile(`^[A-Za-z0-9+/]{20,}={0,2}$`).MatchString(value) {
return true
}
if regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`).MatchString(value) {
return true
}
if regexp.MustCompile(`^[A-Za-z0-9+/]+={0,2}$`).MatchString(value) && len(value) > 10 {
return true
}
return false
}
func IsRecordNotFoundError(err error) bool {
if err == nil {
return false
}
return strings.Contains(strings.ToLower(err.Error()), "record not found") ||
strings.Contains(strings.ToLower(err.Error()), "not found")
}
func CreateSecureLogger(productionMode bool) logger.Interface {
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
if productionMode {
config.LogLevel = logger.Error
config.SlowThreshold = 2 * time.Second
}
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
return NewSecureLogger(writer, config, productionMode)
}

View File

@@ -0,0 +1,368 @@
package database
import (
"context"
"errors"
"log"
"os"
"strings"
"testing"
"time"
"gorm.io/gorm/logger"
)
func TestSecureLogger_MaskSensitiveData(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
tests := []struct {
name string
production bool
input string
expected string
}{
{
name: "development_mode_no_masking",
production: false,
input: "SELECT * FROM users WHERE email = 'user@example.com'",
expected: "SELECT * FROM users WHERE email = 'user@example.com'",
},
{
name: "production_mode_mask_email",
production: true,
input: "SELECT * FROM users WHERE email = 'user@example.com'",
expected: "SELECT * FROM users WHERE email = '[EMAIL_MASKED]'",
},
{
name: "production_mode_mask_token",
production: true,
input: "SELECT * FROM users WHERE password_reset_token = 'abc123def456ghi789'",
expected: "SELECT * FROM users WHERE password_reset_token = '[TOKEN_MASKED]'",
},
{
name: "production_mode_mask_uuid",
production: true,
input: "SELECT * FROM users WHERE id = '550e8400-e29b-41d4-a716-446655440000'",
expected: "SELECT * FROM users WHERE id = '[TOKEN_MASKED]'",
},
{
name: "production_mode_no_masking_short_values",
production: true,
input: "SELECT * FROM users WHERE id = 123",
expected: "SELECT * FROM users WHERE id = 123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
secureLogger := NewSecureLogger(writer, config, tt.production)
result := secureLogger.maskSensitiveData(tt.input)
if tt.production {
if strings.Contains(result, "user@example.com") {
t.Errorf("Email should be masked in production mode")
}
if strings.Contains(result, "abc123def456ghi789") {
t.Errorf("Token should be masked in production mode")
}
} else {
if result != tt.input {
t.Errorf("Expected %q, got %q", tt.input, result)
}
}
})
}
}
func TestSecureLogger_IsSensitiveValue(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
secureLogger := NewSecureLogger(writer, config, true)
tests := []struct {
name string
value string
expected bool
}{
{
name: "email_address",
value: "user@example.com",
expected: true,
},
{
name: "long_token",
value: "abc123def456ghi789jkl012mno345pqr678stu901vwx234yz",
expected: true,
},
{
name: "uuid",
value: "550e8400-e29b-41d4-a716-446655440000",
expected: true,
},
{
name: "short_value",
value: "123",
expected: false,
},
{
name: "normal_text",
value: "golang programming",
expected: false,
},
{
name: "base64_like",
value: "SGVsbG8gV29ybGQ=",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := secureLogger.isSensitiveValue(tt.value)
if result != tt.expected {
t.Errorf("Expected %v for value %q, got %v", tt.expected, tt.value, result)
}
})
}
}
func TestSecureLogger_LogLevels(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
secureLogger := NewSecureLogger(writer, config, false)
ctx := context.Background()
secureLogger.Info(ctx, "Test info message")
secureLogger.Warn(ctx, "Test warn message")
secureLogger.Error(ctx, "Test error message")
}
func TestCreateSecureLogger(t *testing.T) {
prodLogger := CreateSecureLogger(true)
if prodLogger == nil {
t.Error("Expected non-nil logger for production mode")
}
devLogger := CreateSecureLogger(false)
if devLogger == nil {
t.Error("Expected non-nil logger for development mode")
}
}
func TestSecureLogger_LogMode(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
secureLogger := NewSecureLogger(writer, config, false)
newLogger := secureLogger.LogMode(logger.Error)
if newLogger == nil {
t.Error("Expected non-nil logger from LogMode")
}
if secureLogger.config.LogLevel != logger.Info {
t.Error("Original logger should be unchanged")
}
}
func TestSecureLogger_Trace(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
secureLogger := NewSecureLogger(writer, config, false)
ctx := context.Background()
t.Run("silent_level", func(t *testing.T) {
silentLogger := secureLogger.LogMode(logger.Silent)
silentLogger.Trace(ctx, time.Now(), func() (string, int64) {
return "SELECT * FROM users", 1
}, nil)
})
t.Run("error_level_with_error", func(t *testing.T) {
errorLogger := secureLogger.LogMode(logger.Error)
errorLogger.Trace(ctx, time.Now(), func() (string, int64) {
return "SELECT * FROM users", 1
}, errors.New("test error"))
})
t.Run("warn_level_slow_query", func(t *testing.T) {
warnLogger := secureLogger.LogMode(logger.Warn)
startTime := time.Now().Add(-2 * time.Second)
warnLogger.Trace(ctx, startTime, func() (string, int64) {
return "SELECT * FROM users", 1
}, nil)
})
t.Run("info_level", func(t *testing.T) {
infoLogger := secureLogger.LogMode(logger.Info)
infoLogger.Trace(ctx, time.Now(), func() (string, int64) {
return "SELECT * FROM users", 1
}, nil)
})
}
func TestSecureLogger_MaskSQLValues(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
secureLogger := NewSecureLogger(writer, config, true)
tests := []struct {
name string
sql string
expected string
}{
{
name: "email_in_sql",
sql: "SELECT * FROM users WHERE email = 'user@example.com'",
expected: "SELECT * FROM users WHERE email = '[MASKED]'",
},
{
name: "token_in_sql",
sql: "SELECT * FROM users WHERE token = 'abc123def456ghi789'",
expected: "SELECT * FROM users WHERE token = '[MASKED]'",
},
{
name: "uuid_in_sql",
sql: "SELECT * FROM users WHERE id = '550e8400-e29b-41d4-a716-446655440000'",
expected: "SELECT * FROM users WHERE id = '[MASKED]'",
},
{
name: "normal_value",
sql: "SELECT * FROM users WHERE id = 123",
expected: "SELECT * FROM users WHERE id = 123",
},
{
name: "multiple_values",
sql: "SELECT * FROM users WHERE email = 'user@example.com' AND id = 123",
expected: "SELECT * FROM users WHERE email = '[MASKED]' AND id = 123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := secureLogger.maskSQLValues(tt.sql)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestSecureLogger_IsRecordNotFoundError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "record_not_found",
err: errors.New("record not found"),
expected: true,
},
{
name: "not_found",
err: errors.New("not found"),
expected: true,
},
{
name: "RECORD NOT FOUND",
err: errors.New("RECORD NOT FOUND"),
expected: true,
},
{
name: "NOT FOUND",
err: errors.New("NOT FOUND"),
expected: true,
},
{
name: "other_error",
err: errors.New("connection failed"),
expected: false,
},
{
name: "nil_error",
err: nil,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsRecordNotFoundError(tt.err)
if result != tt.expected {
t.Errorf("Expected %v for error '%v', got %v", tt.expected, tt.err, result)
}
})
}
}
func TestSecureLogger_ProductionMode(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Error,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
secureLogger := NewSecureLogger(writer, config, true)
ctx := context.Background()
secureLogger.Info(ctx, "User login: %s", "user@example.com")
secureLogger.Warn(ctx, "Token validation: %s", "abc123def456ghi789")
secureLogger.Error(ctx, "Database error: %s", "connection failed")
}
func TestSecureLogger_DevelopmentMode(t *testing.T) {
writer := log.New(os.Stdout, "\r\n", log.LstdFlags)
config := logger.Config{
SlowThreshold: time.Second,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: true,
Colorful: false,
}
secureLogger := NewSecureLogger(writer, config, false)
ctx := context.Background()
secureLogger.Info(ctx, "User login: %s", "user@example.com")
secureLogger.Warn(ctx, "Token validation: %s", "abc123def456ghi789")
secureLogger.Error(ctx, "Database error: %s", "connection failed")
}