diff --git a/cmd/goyco/commands/seed.go b/cmd/goyco/commands/seed.go index 780fa3c..e47f4a5 100644 --- a/cmd/goyco/commands/seed.go +++ b/cmd/goyco/commands/seed.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" "os" + "strings" "goyco/internal/config" "goyco/internal/database" @@ -169,6 +170,10 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po progress.Complete() } + if err := validateSeedConsistency(postRepo, voteRepo, allUsers, posts); err != nil { + return fmt.Errorf("seed consistency validation failed: %w", err) + } + if IsJSONOutput() { outputJSON(map[string]any{ "action": "seed_completed", @@ -188,7 +193,32 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po return nil } +func findExistingSeedUser(userRepo repositories.UserRepository) (*database.User, error) { + users, err := userRepo.GetAll(100, 0) + if err != nil { + return nil, err + } + + for _, user := range users { + if len(user.Username) >= 11 && user.Username[:11] == "seed_admin_" { + if len(user.Email) >= 13 && strings.HasSuffix(user.Email, "@goyco.local") { + emailPrefix := user.Email[:len(user.Email)-13] + if len(emailPrefix) >= 11 && emailPrefix[:11] == "seed_admin_" { + return &user, nil + } + } + } + } + + return nil, fmt.Errorf("no existing seed user found") +} + func ensureSeedUser(userRepo repositories.UserRepository) (*database.User, error) { + existingUser, err := findExistingSeedUser(userRepo) + if err == nil && existingUser != nil { + return existingUser, nil + } + seedPassword := "seed-password" randomID := generateRandomIdentifier() seedUsername := fmt.Sprintf("seed_admin_%s", randomID) @@ -427,3 +457,35 @@ func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, return upVotes, downVotes, nil } + +func validateSeedConsistency(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error { + userIDs := make(map[uint]bool) + for _, user := range users { + userIDs[user.ID] = true + } + + for _, post := range posts { + if post.AuthorID == nil { + return fmt.Errorf("post %d has no author", post.ID) + } + if !userIDs[*post.AuthorID] { + return fmt.Errorf("post %d has invalid author ID %d", post.ID, *post.AuthorID) + } + + votes, err := voteRepo.GetByPostID(post.ID) + if err != nil { + return fmt.Errorf("get votes for post %d: %w", post.ID, err) + } + + for _, vote := range votes { + if vote.UserID != nil && !userIDs[*vote.UserID] { + return fmt.Errorf("vote %d has invalid user ID %d", vote.ID, *vote.UserID) + } + if vote.PostID != post.ID { + return fmt.Errorf("vote %d has invalid post ID %d (expected %d)", vote.ID, vote.PostID, post.ID) + } + } + } + + return nil +}