diff --git a/cmd/goyco/commands/seed.go b/cmd/goyco/commands/seed.go index 5bbb6c1..80374f9 100644 --- a/cmd/goyco/commands/seed.go +++ b/cmd/goyco/commands/seed.go @@ -213,7 +213,7 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po progress.Complete() } - if err := validateSeedConsistency(postRepo, voteRepo, allUsers, posts); err != nil { + if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil { return fmt.Errorf("seed consistency validation failed: %w", err) } @@ -280,32 +280,65 @@ func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, return voteRepo.GetVoteCountsByPostID(postID) } -func validateSeedConsistency(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error { - userIDs := make(map[uint]bool) +func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error { + userIDSet := make(map[uint]struct{}, len(users)) for _, user := range users { - userIDs[user.ID] = true + userIDSet[user.ID] = struct{}{} + } + + postIDSet := make(map[uint]struct{}, len(posts)) + for _, post := range posts { + postIDSet[post.ID] = struct{}{} } for _, post := range posts { - if post.AuthorID == nil { - return fmt.Errorf("post %d has no author", post.ID) - } - if !userIDs[*post.AuthorID] { - return fmt.Errorf("post %d has invalid author ID %d", post.ID, *post.AuthorID) + if err := validatePost(post, userIDSet); err != nil { + return err } votes, err := voteRepo.GetByPostID(post.ID) if err != nil { - return fmt.Errorf("get votes for post %d: %w", post.ID, err) + return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err) } - for _, vote := range votes { - if vote.UserID != nil && !userIDs[*vote.UserID] { - return fmt.Errorf("vote %d has invalid user ID %d", vote.ID, *vote.UserID) - } - if vote.PostID != post.ID { - return fmt.Errorf("vote %d has invalid post ID %d (expected %d)", vote.ID, vote.PostID, post.ID) - } + if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil { + return err + } + } + + return nil +} + +func validatePost(post database.Post, userIDSet map[uint]struct{}) error { + if post.AuthorID == nil { + return fmt.Errorf("post %d has no author ID", post.ID) + } + + if _, exists := userIDSet[*post.AuthorID]; !exists { + return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID) + } + + return nil +} + +func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error { + for _, vote := range votes { + if vote.PostID != postID { + return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID) + } + + if _, exists := postIDSet[vote.PostID]; !exists { + return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID) + } + + if vote.UserID != nil { + if _, exists := userIDSet[*vote.UserID]; !exists { + return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID) + } + } + + if vote.Type != database.VoteUp && vote.Type != database.VoteDown { + return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type) } }