Compare commits

...

2 Commits

2 changed files with 389 additions and 140 deletions

View File

@@ -6,29 +6,17 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"os" "os"
"sync" "strings"
"time" "time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
// seedRandSource is the package-level random source; it stays nil until
// initSeedRand has run. seedRandOnce guards that one-time setup.
var (
	seedRandSource *rand.Rand
	seedRandOnce   sync.Once
)

// initSeedRand lazily creates seedRandSource, seeding it with the current
// time in nanoseconds. Repeated calls are no-ops because the setup runs
// through seedRandOnce.
func initSeedRand() {
	setup := func() {
		seedRandSource = rand.New(rand.NewSource(time.Now().UnixNano()))
	}
	seedRandOnce.Do(setup)
}
func HandleSeedCommand(cfg *config.Config, name string, args []string) error { func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage) fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil { if err := parseCommand(fs, args, name); err != nil {
@@ -39,10 +27,12 @@ func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
} }
return withDatabase(cfg, func(db *gorm.DB) error { return withDatabase(cfg, func(db *gorm.DB) error {
userRepo := repositories.NewUserRepository(db) return db.Transaction(func(tx *gorm.DB) error {
postRepo := repositories.NewPostRepository(db) userRepo := repositories.NewUserRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db) postRepo := repositories.NewPostRepository(db).WithTx(tx)
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args()) voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
return runSeedCommand(userRepo, postRepo, voteRepo, fs.Args())
})
}) })
} }
@@ -72,45 +62,37 @@ func printSeedUsage() {
fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)") fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)")
} }
// clampFlagValue raises *value to min when it is below min and leaves it
// untouched otherwise. When a clamp happens and JSON output is not active,
// a warning naming the flag (without its leading dashes in name) is written
// to stderr before the value is replaced.
func clampFlagValue(value *int, min int, name string) {
	if *value >= min {
		return
	}
	if !IsJSONOutput() {
		fmt.Fprintf(os.Stderr, "Warning: --%s value %d is too low, clamping to %d\n", name, *value, min)
	}
	*value = min
}
func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error { func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, args []string) error {
fs := flag.NewFlagSet("seed database", flag.ContinueOnError) fs := flag.NewFlagSet("seed database", flag.ContinueOnError)
numPosts := fs.Int("posts", 40, "number of posts to create") numPosts := fs.Int("posts", 40, "number of posts to create")
numUsers := fs.Int("users", 5, "number of additional users to create") numUsers := fs.Int("users", 5, "number of additional users to create")
votesPerPost := fs.Int("votes-per-post", 15, "average votes per post") votesPerPost := fs.Int("votes-per-post", 15, "average votes per post")
fs.SetOutput(os.Stderr) fs.SetOutput(os.Stderr)
fs.Usage = func() {
fmt.Fprintln(os.Stderr, "Usage: goyco seed database [--posts <n>] [--users <n>] [--votes-per-post <n>]")
fmt.Fprintln(os.Stderr, "\nOptions:")
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil { if err := fs.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return err return err
} }
originalUsers := *numUsers clampFlagValue(numUsers, 0, "users")
originalPosts := *numPosts clampFlagValue(numPosts, 1, "posts")
originalVotesPerPost := *votesPerPost clampFlagValue(votesPerPost, 0, "votes-per-post")
if *numUsers < 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --users value %d is negative, clamping to 0\n", *numUsers)
}
*numUsers = 0
}
if *numPosts <= 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --posts value %d is too low, clamping to 1\n", *numPosts)
}
*numPosts = 1
}
if *votesPerPost < 0 {
if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --votes-per-post value %d is negative, clamping to 0\n", *votesPerPost)
}
*votesPerPost = 0
}
if !IsJSONOutput() && (originalUsers != *numUsers || originalPosts != *numPosts || originalVotesPerPost != *votesPerPost) {
fmt.Fprintf(os.Stderr, "Using clamped values: --users=%d --posts=%d --votes-per-post=%d\n", *numUsers, *numPosts, *votesPerPost)
}
if !IsJSONOutput() { if !IsJSONOutput() {
fmt.Println("Starting database seeding...") fmt.Println("Starting database seeding...")
@@ -129,71 +111,35 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return fmt.Errorf("precompute user password hash: %w", err) return fmt.Errorf("precompute user password hash: %w", err)
} }
spinner := NewSpinner("Creating seed user")
if !IsJSONOutput() {
spinner.Spin()
}
seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash)) seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash))
if err != nil { if err != nil {
if !IsJSONOutput() {
spinner.Complete()
}
return fmt.Errorf("ensure seed user: %w", err) return fmt.Errorf("ensure seed user: %w", err)
} }
if !IsJSONOutput() { if !IsJSONOutput() {
spinner.Complete()
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username) fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
} }
processor := NewParallelProcessor() generator := newSeedGenerator(string(userPasswordHash))
processor.SetPasswordHash(string(userPasswordHash)) allUsers := []database.User{*seedUser}
var progress *ProgressIndicator users, err := createUsers(generator, userRepo, *numUsers, "Creating users")
if !IsJSONOutput() && *numUsers > 0 {
progress = NewProgressIndicator(*numUsers, "Creating users (parallel)")
}
users, err := processor.CreateUsersInParallel(userRepo, *numUsers, progress)
if err != nil { if err != nil {
return fmt.Errorf("create random users: %w", err) return err
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
} }
allUsers = append(allUsers, users...)
allUsers := append([]database.User{*seedUser}, users...) posts, err := createPosts(generator, postRepo, seedUser.ID, *numPosts, "Creating posts")
if !IsJSONOutput() && *numPosts > 0 {
progress = NewProgressIndicator(*numPosts, "Creating posts (parallel)")
}
posts, err := processor.CreatePostsInParallel(postRepo, seedUser.ID, *numPosts, progress)
if err != nil { if err != nil {
return fmt.Errorf("create random posts: %w", err) return err
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
} }
if !IsJSONOutput() && len(posts) > 0 { votes, err := createVotes(generator, voteRepo, allUsers, posts, *votesPerPost, "Creating votes")
progress = NewProgressIndicator(len(posts), "Creating votes (parallel)")
}
votes, err := processor.CreateVotesInParallel(voteRepo, allUsers, posts, *votesPerPost, progress)
if err != nil { if err != nil {
return fmt.Errorf("create random votes: %w", err) return err
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
} }
if !IsJSONOutput() && len(posts) > 0 { if err := updateScores(generator, postRepo, voteRepo, posts, "Updating scores"); err != nil {
progress = NewProgressIndicator(len(posts), "Updating scores (parallel)") return err
}
err = processor.UpdatePostScoresInParallel(postRepo, voteRepo, posts, progress)
if err != nil {
return fmt.Errorf("update post scores: %w", err)
}
if !IsJSONOutput() && progress != nil {
progress.Complete()
} }
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil { if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
@@ -225,11 +171,15 @@ const (
) )
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) { func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
if user, err := userRepo.GetByUsername(seedUsername); err == nil { user, err := userRepo.GetByUsername(seedUsername)
if err == nil {
return user, nil return user, nil
} }
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("failed to check if seed user exists: %w", err)
}
user := &database.User{ user = &database.User{
Username: seedUsername, Username: seedUsername,
Email: seedEmail, Email: seedEmail,
Password: passwordHash, Password: passwordHash,
@@ -243,10 +193,6 @@ func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (
return user, nil return user, nil
} }
// getVoteCounts forwards to voteRepo.GetVoteCountsByPostID for postID and
// returns its three results (up-vote count, down-vote count, error)
// unchanged.
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
	ups, downs, err := voteRepo.GetVoteCountsByPostID(postID)
	return ups, downs, err
}
func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error { func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDSet := make(map[uint]struct{}, len(users)) userIDSet := make(map[uint]struct{}, len(users))
for _, user := range users { for _, user := range users {
@@ -259,8 +205,11 @@ func validateSeedConsistency(voteRepo repositories.VoteRepository, users []datab
} }
for _, post := range posts { for _, post := range posts {
if err := validatePost(post, userIDSet); err != nil { if post.AuthorID == nil {
return err return fmt.Errorf("post %d has no author ID", post.ID)
}
if _, exists := userIDSet[*post.AuthorID]; !exists {
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
} }
votes, err := voteRepo.GetByPostID(post.ID) votes, err := voteRepo.GetByPostID(post.ID)
@@ -268,46 +217,293 @@ func validateSeedConsistency(voteRepo repositories.VoteRepository, users []datab
return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err) return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
} }
if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil { for _, vote := range votes {
return err if vote.PostID != post.ID {
} return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, post.ID)
} }
if _, exists := postIDSet[vote.PostID]; !exists {
return nil return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
} }
if vote.UserID != nil {
func validatePost(post database.Post, userIDSet map[uint]struct{}) error { if _, exists := userIDSet[*vote.UserID]; !exists {
if post.AuthorID == nil { return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
return fmt.Errorf("post %d has no author ID", post.ID) }
} }
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
if _, exists := userIDSet[*post.AuthorID]; !exists { return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
return nil
}
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
for _, vote := range votes {
if vote.PostID != postID {
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
}
if _, exists := postIDSet[vote.PostID]; !exists {
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
}
if vote.UserID != nil {
if _, exists := userIDSet[*vote.UserID]; !exists {
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
} }
} }
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
}
} }
return nil return nil
} }
// seedGenerator bundles the state shared by the seed-data helpers.
type seedGenerator struct {
	// passwordHash is stored as the Password of every generated user.
	passwordHash string
	// randSource supplies every random choice the generator makes.
	randSource *rand.Rand
}

// newSeedGenerator returns a generator that reuses passwordHash for each
// user it creates and draws randomness from a source seeded with the
// current time in nanoseconds.
func newSeedGenerator(passwordHash string) *seedGenerator {
	g := new(seedGenerator)
	g.passwordHash = passwordHash
	g.randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
	return g
}
// isRetryableError reports whether err is a duplicate-key failure that
// mentions at least one of keywords, i.e. a collision worth retrying with a
// freshly generated value.
//
// Three shapes of error are recognised, checked in this order:
//   - errors matching gorm.ErrDuplicatedKey, matched against the lowercased
//     error text;
//   - *pq.Error values with code "23505" (PostgreSQL unique_violation),
//     matched against the lowercased constraint name and message;
//   - any other error whose lowercased text contains "duplicate", matched
//     against that same text.
//
// Keywords are compared as given against lowercased text, so callers should
// pass them in lower case. A nil err is never retryable.
func isRetryableError(err error, keywords ...string) bool {
	if err == nil {
		return false
	}

	// mentionsKeyword reports whether any keyword occurs in any haystack.
	mentionsKeyword := func(haystacks ...string) bool {
		for _, keyword := range keywords {
			for _, haystack := range haystacks {
				if strings.Contains(haystack, keyword) {
					return true
				}
			}
		}
		return false
	}

	lowered := strings.ToLower(err.Error())

	if errors.Is(err, gorm.ErrDuplicatedKey) {
		return mentionsKeyword(lowered)
	}

	var pqErr *pq.Error
	if errors.As(err, &pqErr) && pqErr.Code == "23505" {
		return mentionsKeyword(strings.ToLower(pqErr.Constraint), strings.ToLower(pqErr.Message))
	}

	return strings.Contains(lowered, "duplicate") && mentionsKeyword(lowered)
}
// createUsers creates count users through g.createSingleUser and returns
// them in creation order. A count of zero returns (nil, nil) without
// touching the repository. When maybeProgress yields an indicator it is
// incremented after each user and completed once all users exist; desc is
// the indicator's description. The first creation failure aborts the run
// and is returned wrapped, with no users.
func createUsers(g *seedGenerator, userRepo repositories.UserRepository, count int, desc string) ([]database.User, error) {
	if count == 0 {
		return nil, nil
	}

	progress := maybeProgress(count, desc)
	created := make([]database.User, 0, count)

	// createSingleUser receives a 1-based position.
	for position := 1; position <= count; position++ {
		user, err := g.createSingleUser(userRepo, position)
		if err != nil {
			return nil, fmt.Errorf("create random user: %w", err)
		}
		created = append(created, user)
		if progress != nil {
			progress.Increment()
		}
	}

	if progress != nil {
		progress.Complete()
	}
	return created, nil
}
// createPosts creates count posts authored by authorID through
// g.createSinglePost and returns them in creation order. A count of zero
// returns (nil, nil). When maybeProgress yields an indicator it is
// incremented after each post and completed at the end; desc is the
// indicator's description. The first creation failure aborts the run and is
// returned wrapped, with no posts.
func createPosts(g *seedGenerator, postRepo repositories.PostRepository, authorID uint, count int, desc string) ([]database.Post, error) {
	if count == 0 {
		return nil, nil
	}

	progress := maybeProgress(count, desc)
	created := make([]database.Post, 0, count)

	// createSinglePost receives a 1-based position.
	for position := 1; position <= count; position++ {
		post, err := g.createSinglePost(postRepo, authorID, position)
		if err != nil {
			return nil, fmt.Errorf("create random post: %w", err)
		}
		created = append(created, post)
		if progress != nil {
			progress.Increment()
		}
	}

	if progress != nil {
		progress.Complete()
	}
	return created, nil
}
// createVotes generates votes for every post via g.createVotesForPost and
// returns the total number of votes created. An empty posts slice returns
// (0, nil). When maybeProgress yields an indicator it advances once per
// post and is completed at the end; desc is the indicator's description.
// The first failure aborts the run and is returned wrapped together with
// the post's ID; the returned count is then 0.
func createVotes(g *seedGenerator, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, desc string) (int, error) {
	if len(posts) == 0 {
		return 0, nil
	}

	progress := maybeProgress(len(posts), desc)
	total := 0

	for idx := range posts {
		created, err := g.createVotesForPost(voteRepo, users, posts[idx], avgVotesPerPost)
		if err != nil {
			return 0, fmt.Errorf("create random votes for post %d: %w", posts[idx].ID, err)
		}
		total += created
		if progress != nil {
			progress.Increment()
		}
	}

	if progress != nil {
		progress.Complete()
	}
	return total, nil
}
// updateScores recomputes and persists the vote totals of every post via
// g.updateSinglePostScore. An empty posts slice is a no-op. When
// maybeProgress yields an indicator it advances once per post and is
// completed at the end; desc is the indicator's description. The first
// failure aborts the run and is returned wrapped.
func updateScores(g *seedGenerator, postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, desc string) error {
	if len(posts) == 0 {
		return nil
	}

	progress := maybeProgress(len(posts), desc)

	for idx := range posts {
		err := g.updateSinglePostScore(postRepo, voteRepo, posts[idx])
		if err != nil {
			return fmt.Errorf("update post scores: %w", err)
		}
		if progress != nil {
			progress.Increment()
		}
	}

	if progress != nil {
		progress.Complete()
	}
	return nil
}
// maybeProgress returns a new ProgressIndicator for count items labelled
// desc, or nil when JSON output is active or count is not positive. Callers
// must nil-check the result before using it.
func maybeProgress(count int, desc string) *ProgressIndicator {
	if IsJSONOutput() || count <= 0 {
		return nil
	}
	return NewProgressIndicator(count, desc)
}
// generateRandomIdentifier returns a 12-character string of lowercase ASCII
// letters and digits, each character drawn independently from g.randSource.
func (g *seedGenerator) generateRandomIdentifier() string {
	const (
		alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
		size     = 12
	)

	var buf [size]byte
	for pos := 0; pos < size; pos++ {
		buf[pos] = alphabet[g.randSource.Intn(len(alphabet))]
	}
	return string(buf[:])
}
// createSingleUser inserts one email-verified user whose username and email
// are both derived from a fresh random identifier and whose password is
// g.passwordHash. The index argument is accepted but not used.
//
// If the insert fails with an error that isRetryableError attributes to a
// username/email collision, a new identifier is generated and the insert is
// retried, up to 10 attempts in total. Any other error is returned
// immediately, wrapped with the attempt number; exhausting all attempts
// returns the last error, wrapped.
func (g *seedGenerator) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
	const maxRetries = 10

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		suffix := g.generateRandomIdentifier()
		candidate := database.User{
			Username:      fmt.Sprintf("user_%s", suffix),
			Email:         fmt.Sprintf("user_%s@goyco.local", suffix),
			Password:      g.passwordHash,
			EmailVerified: true,
		}

		err := userRepo.Create(&candidate)
		if err == nil {
			return candidate, nil
		}

		lastErr = err
		if !isRetryableError(err, "username", "email", "users_username_key", "users_email_key") {
			return database.User{}, fmt.Errorf("failed to create user (attempt %d/%d): %w", attempt, maxRetries, err)
		}
	}

	return database.User{}, fmt.Errorf("failed to create user after %d attempts: %w", maxRetries, lastErr)
}
// Fixed sample data used when generating seed posts.
var (
// sampleTitles holds the post titles; createSinglePost picks one by index
// modulo the slice length.
sampleTitles = []string{"Amazing JavaScript Framework", "Python Best Practices", "Go Performance Tips", "Database Optimization", "Web Security Guide", "Machine Learning Basics", "Cloud Architecture", "DevOps Automation", "API Design Patterns", "Frontend Optimization", "Backend Scaling", "Container Orchestration", "Microservices Architecture", "Testing Strategies", "Code Review Process", "Version Control Best Practices", "Continuous Integration", "Monitoring and Alerting", "Error Handling Patterns", "Data Structures Explained"}
// sampleDomains holds the host names used to build generated post URLs;
// createSinglePost picks one by index modulo the slice length.
sampleDomains = []string{"example.com", "techblog.org", "devguide.net", "programming.io", "codeexamples.com", "tutorialhub.org", "bestpractices.dev", "learnprogramming.net", "codingtips.org", "softwareengineering.com"}
)
// createSinglePost inserts one post authored by authorID with zeroed vote
// counters. The title and URL host are picked from sampleTitles and
// sampleDomains by index modulo their lengths; once index reaches
// len(sampleTitles) the title gains a " - Part N" suffix. The URL path ends
// in a fresh random identifier.
//
// If the insert fails with an error that isRetryableError attributes to a
// URL collision, a new identifier is generated and the insert is retried,
// up to 10 attempts in total. Any other error is returned immediately,
// wrapped with the attempt number; exhausting all attempts returns the last
// error, wrapped.
func (g *seedGenerator) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
	const maxRetries = 10

	titleCount := len(sampleTitles)
	title := sampleTitles[index%titleCount]
	if index >= titleCount {
		title = fmt.Sprintf("%s - Part %d", title, (index/titleCount)+1)
	}
	host := sampleDomains[index%len(sampleDomains)]
	body := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		candidate := database.Post{
			Title:     title,
			URL:       fmt.Sprintf("https://%s/article/%s", host, g.generateRandomIdentifier()),
			Content:   body,
			AuthorID:  &authorID,
			UpVotes:   0,
			DownVotes: 0,
			Score:     0,
		}

		err := postRepo.Create(&candidate)
		if err == nil {
			return candidate, nil
		}

		lastErr = err
		if !isRetryableError(err, "url", "posts_url_key") {
			return database.Post{}, fmt.Errorf("failed to create post (attempt %d/%d): %w", attempt, maxRetries, err)
		}
	}

	return database.Post{}, fmt.Errorf("failed to create post after %d attempts: %w", maxRetries, lastErr)
}
// createVotesForPost casts a random number of votes on post, each from a
// distinct user in users, and returns how many votes were stored.
//
// The target vote count is drawn uniformly from [0, 2*avgVotesPerPost]. When
// that draw is zero and avgVotesPerPost is positive, it is bumped to one
// vote four times out of five. Each loop iteration picks a random user;
// picking a user who has already voted on this post consumes the iteration
// without storing a vote, so the result may be lower than the target. Roughly
// seven in ten votes are up-votes, the rest down-votes.
//
// A negative avgVotesPerPost produces no votes. Without this guard the
// argument to rand.Intn would be non-positive, which makes Intn panic.
//
// If storing a vote fails, the number of votes stored so far is returned
// together with the wrapped error.
func (g *seedGenerator) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
	if avgVotesPerPost < 0 {
		return 0, nil
	}

	numVotes := g.randSource.Intn(avgVotesPerPost*2 + 1)
	if numVotes == 0 && avgVotesPerPost > 0 {
		// Make vote-less posts rare when votes were requested at all.
		if g.randSource.Intn(5) > 0 {
			numVotes = 1
		}
	}

	totalVotes := 0
	usedUsers := make(map[uint]bool)

	// The second condition stops the loop once every user has voted, which
	// also prevents Intn(0) when users is empty.
	for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
		user := users[g.randSource.Intn(len(users))]
		if usedUsers[user.ID] {
			continue
		}
		usedUsers[user.ID] = true

		voteType := database.VoteDown
		if g.randSource.Intn(10) < 7 {
			voteType = database.VoteUp
		}

		vote := &database.Vote{
			UserID: &user.ID,
			PostID: post.ID,
			Type:   voteType,
		}
		if err := voteRepo.CreateOrUpdate(vote); err != nil {
			return totalVotes, fmt.Errorf("create or update vote: %w", err)
		}
		totalVotes++
	}

	return totalVotes, nil
}
// updateSinglePostScore reads the up- and down-vote counts for post from
// voteRepo, stores them on a copy of post together with Score = up - down,
// and persists that copy through postRepo.Update. The caller's post value is
// not modified. A failure to read the counts is returned wrapped; the result
// of Update is returned as is.
func (g *seedGenerator) updateSinglePostScore(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, post database.Post) error {
	ups, downs, err := voteRepo.GetVoteCountsByPostID(post.ID)
	if err != nil {
		return fmt.Errorf("get vote counts: %w", err)
	}

	updated := post
	updated.UpVotes, updated.DownVotes = ups, downs
	updated.Score = ups - downs
	return postRepo.Update(&updated)
}

View File

@@ -2,8 +2,11 @@ package commands
import ( import (
"fmt" "fmt"
"math/rand"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
@@ -13,6 +16,18 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// seedRandSource is the package-level random source; it stays nil until
// initSeedRand has run. seedRandOnce guards that one-time setup.
var (
	seedRandSource *rand.Rand
	seedRandOnce   sync.Once
)

// initSeedRand lazily creates seedRandSource, seeding it with the current
// time in nanoseconds. Repeated calls are no-ops because the setup runs
// through seedRandOnce.
func initSeedRand() {
	setup := func() {
		seedRandSource = rand.New(rand.NewSource(time.Now().UnixNano()))
	}
	seedRandOnce.Do(setup)
}
func TestSeedCommand(t *testing.T) { func TestSeedCommand(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil { if err != nil {
@@ -271,6 +286,22 @@ func TestSeedDatabaseFlagParsing(t *testing.T) {
t.Errorf("zero votes-per-post should be valid, got error: %v", err) t.Errorf("zero votes-per-post should be valid, got error: %v", err)
} }
}) })
t.Run("help flag returns no error", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--help"})
if err != nil {
t.Errorf("help flag should return no error, got: %v", err)
}
})
t.Run("short help flag returns no error", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"-h"})
if err != nil {
t.Errorf("short help flag should return no error, got: %v", err)
}
})
} }
func TestSeedCommandIdempotency(t *testing.T) { func TestSeedCommandIdempotency(t *testing.T) {
@@ -531,3 +562,25 @@ func TestEnsureSeedUser(t *testing.T) {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount) t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
} }
} }
// TestEnsureSeedUser_HandlesDatabaseErrors checks that ensureSeedUser
// returns an error when the username lookup fails with a generic error, and
// that the error text carries both the "failed to check if seed user
// exists" context and the underlying lookup error's message.
func TestEnsureSeedUser_HandlesDatabaseErrors(t *testing.T) {
	lookupErr := fmt.Errorf("database connection failed")

	repo := testutils.NewMockUserRepository()
	repo.SetGetByUsernameError(lookupErr)

	_, err := ensureSeedUser(repo, "test_password_hash")
	if err == nil {
		t.Fatal("Expected error when GetByUsername returns database error")
	}

	got := err.Error()
	if !strings.Contains(got, "failed to check if seed user exists") {
		t.Errorf("Expected error message about checking seed user, got: %v", err)
	}
	if !strings.Contains(got, lookupErr.Error()) {
		t.Errorf("Expected error to wrap original database error, got: %v", err)
	}
}