Compare commits
2 Commits
65576cc623
...
697f201d60
| Author | SHA1 | Date | |
|---|---|---|---|
| 697f201d60 | |||
| f4ab8bda45 |
@@ -57,10 +57,15 @@ func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return withDatabase(cfg, func(db *gorm.DB) error {
|
return withDatabase(cfg, func(db *gorm.DB) error {
|
||||||
userRepo := repositories.NewUserRepository(db)
|
return db.Transaction(func(tx *gorm.DB) error {
|
||||||
postRepo := repositories.NewPostRepository(db)
|
userRepo := repositories.NewUserRepository(db).WithTx(tx)
|
||||||
voteRepo := repositories.NewVoteRepository(db)
|
postRepo := repositories.NewPostRepository(db).WithTx(tx)
|
||||||
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
|
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
|
||||||
|
if err := runSeedCommand(userRepo, postRepo, voteRepo, fs.Args()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -362,3 +362,86 @@ func findSeedUser(users []database.User) *database.User {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSeedCommandTransactionRollback(t *testing.T) {
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect to database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to migrate database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userRepo := repositories.NewUserRepository(db)
|
||||||
|
postRepo := repositories.NewPostRepository(db)
|
||||||
|
voteRepo := repositories.NewVoteRepository(db)
|
||||||
|
|
||||||
|
t.Run("transaction rolls back on failure", func(t *testing.T) {
|
||||||
|
initialUserCount, _ := userRepo.Count()
|
||||||
|
initialPostCount, _ := postRepo.Count()
|
||||||
|
initialVoteCount, _ := voteRepo.Count()
|
||||||
|
|
||||||
|
err := db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
txUserRepo := userRepo.WithTx(tx)
|
||||||
|
txPostRepo := postRepo.WithTx(tx)
|
||||||
|
txVoteRepo := voteRepo.WithTx(tx)
|
||||||
|
|
||||||
|
err := seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "2", "--posts", "3"})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("simulated failure")
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected transaction to fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
finalUserCount, _ := userRepo.Count()
|
||||||
|
finalPostCount, _ := postRepo.Count()
|
||||||
|
finalVoteCount, _ := voteRepo.Count()
|
||||||
|
|
||||||
|
if finalUserCount != initialUserCount {
|
||||||
|
t.Errorf("Expected user count to remain %d after rollback, got %d", initialUserCount, finalUserCount)
|
||||||
|
}
|
||||||
|
if finalPostCount != initialPostCount {
|
||||||
|
t.Errorf("Expected post count to remain %d after rollback, got %d", initialPostCount, finalPostCount)
|
||||||
|
}
|
||||||
|
if finalVoteCount != initialVoteCount {
|
||||||
|
t.Errorf("Expected vote count to remain %d after rollback, got %d", initialVoteCount, finalVoteCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("transaction commits on success", func(t *testing.T) {
|
||||||
|
initialUserCount, _ := userRepo.Count()
|
||||||
|
initialPostCount, _ := postRepo.Count()
|
||||||
|
|
||||||
|
err := db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
txUserRepo := userRepo.WithTx(tx)
|
||||||
|
txPostRepo := postRepo.WithTx(tx)
|
||||||
|
txVoteRepo := voteRepo.WithTx(tx)
|
||||||
|
|
||||||
|
return seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "1", "--posts", "1"})
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected transaction to succeed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalUserCount, _ := userRepo.Count()
|
||||||
|
finalPostCount, _ := postRepo.Count()
|
||||||
|
|
||||||
|
expectedUsers := initialUserCount + 2
|
||||||
|
expectedPosts := initialPostCount + 1
|
||||||
|
|
||||||
|
if finalUserCount < expectedUsers {
|
||||||
|
t.Errorf("Expected at least %d users after commit, got %d", expectedUsers, finalUserCount)
|
||||||
|
}
|
||||||
|
if finalPostCount < expectedPosts {
|
||||||
|
t.Errorf("Expected at least %d posts after commit, got %d", expectedPosts, finalPostCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user