feat: transaction rollback test

This commit is contained in:
2025-11-21 16:20:41 +01:00
parent 65576cc623
commit f4ab8bda45

View File

@@ -362,3 +362,86 @@ func findSeedUser(users []database.User) *database.User {
} }
return nil return nil
} }
func TestSeedCommandTransactionRollback(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
t.Run("transaction rolls back on failure", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
initialVoteCount, _ := voteRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
err := seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "2", "--posts", "3"})
if err != nil {
return err
}
return fmt.Errorf("simulated failure")
})
if err == nil {
t.Fatal("Expected transaction to fail")
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
finalVoteCount, _ := voteRepo.Count()
if finalUserCount != initialUserCount {
t.Errorf("Expected user count to remain %d after rollback, got %d", initialUserCount, finalUserCount)
}
if finalPostCount != initialPostCount {
t.Errorf("Expected post count to remain %d after rollback, got %d", initialPostCount, finalPostCount)
}
if finalVoteCount != initialVoteCount {
t.Errorf("Expected vote count to remain %d after rollback, got %d", initialVoteCount, finalVoteCount)
}
})
t.Run("transaction commits on success", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
return seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "1", "--posts", "1"})
})
if err != nil {
t.Fatalf("Expected transaction to succeed, got error: %v", err)
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
expectedUsers := initialUserCount + 2
expectedPosts := initialPostCount + 1
if finalUserCount < expectedUsers {
t.Errorf("Expected at least %d users after commit, got %d", expectedUsers, finalUserCount)
}
if finalPostCount < expectedPosts {
t.Errorf("Expected at least %d posts after commit, got %d", expectedPosts, finalPostCount)
}
})
}