Compare commits

...

2 Commits

Author SHA1 Message Date
697f201d60 feat: use database transactions to ensure atomicity 2025-11-21 16:21:04 +01:00
f4ab8bda45 feat: transaction rollback test 2025-11-21 16:20:41 +01:00
2 changed files with 92 additions and 4 deletions

View File

@@ -57,10 +57,15 @@ func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
}
return withDatabase(cfg, func(db *gorm.DB) error {
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
return db.Transaction(func(tx *gorm.DB) error {
userRepo := repositories.NewUserRepository(db).WithTx(tx)
postRepo := repositories.NewPostRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
if err := runSeedCommand(userRepo, postRepo, voteRepo, fs.Args()); err != nil {
return err
}
return nil
})
})
}

View File

@@ -362,3 +362,86 @@ func findSeedUser(users []database.User) *database.User {
}
return nil
}
func TestSeedCommandTransactionRollback(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
t.Run("transaction rolls back on failure", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
initialVoteCount, _ := voteRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
err := seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "2", "--posts", "3"})
if err != nil {
return err
}
return fmt.Errorf("simulated failure")
})
if err == nil {
t.Fatal("Expected transaction to fail")
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
finalVoteCount, _ := voteRepo.Count()
if finalUserCount != initialUserCount {
t.Errorf("Expected user count to remain %d after rollback, got %d", initialUserCount, finalUserCount)
}
if finalPostCount != initialPostCount {
t.Errorf("Expected post count to remain %d after rollback, got %d", initialPostCount, finalPostCount)
}
if finalVoteCount != initialVoteCount {
t.Errorf("Expected vote count to remain %d after rollback, got %d", initialVoteCount, finalVoteCount)
}
})
t.Run("transaction commits on success", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
return seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "1", "--posts", "1"})
})
if err != nil {
t.Fatalf("Expected transaction to succeed, got error: %v", err)
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
expectedUsers := initialUserCount + 2
expectedPosts := initialPostCount + 1
if finalUserCount < expectedUsers {
t.Errorf("Expected at least %d users after commit, got %d", expectedUsers, finalUserCount)
}
if finalPostCount < expectedPosts {
t.Errorf("Expected at least %d posts after commit, got %d", expectedPosts, finalPostCount)
}
})
}