From 62d466e4fa93c737cb1cb344e8ea81eeb61771cc Mon Sep 17 00:00:00 2001 From: Kharec Date: Fri, 21 Nov 2025 16:56:26 +0100 Subject: [PATCH] refactor: use go generics --- cmd/goyco/commands/parallel_processor.go | 230 +++++++++++------------ 1 file changed, 110 insertions(+), 120 deletions(-) diff --git a/cmd/goyco/commands/parallel_processor.go b/cmd/goyco/commands/parallel_processor.go index da1fb7b..00bca9d 100644 --- a/cmd/goyco/commands/parallel_processor.go +++ b/cmd/goyco/commands/parallel_processor.go @@ -42,14 +42,23 @@ func (p *ParallelProcessor) SetPasswordHash(hash string) { p.passwordHash = hash } -func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) { - ctx, cancel := context.WithTimeout(context.Background(), p.timeout) - defer cancel() +type indexedResult[T any] struct { + value T + index int +} - results := make(chan userResult, count) +func processInParallel[T any]( + ctx context.Context, + maxWorkers int, + count int, + processor func(index int) (T, error), + errorPrefix string, + progress *ProgressIndicator, +) ([]T, error) { + results := make(chan indexedResult[T], count) errors := make(chan error, count) - semaphore := make(chan struct{}, p.maxWorkers) + semaphore := make(chan struct{}, maxWorkers) var wg sync.WaitGroup for i := range count { @@ -65,13 +74,13 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo } defer func() { <-semaphore }() - user, err := p.createSingleUser(userRepo, index+1) + value, err := processor(index + 1) if err != nil { - errors <- fmt.Errorf("create user %d: %w", index+1, err) + errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err) return } - results <- userResult{user: user, index: index} + results <- indexedResult[T]{value: value, index: index} }(i) } @@ -81,7 +90,7 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo close(errors) }() - users := make([]database.User, count) + items := make([]T, count) completed := 0 firstError := make(chan error, 1) @@ -101,9 +110,9 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo select { case result, ok := <-results: if !ok { - return users, nil + return items, nil } - users[result.index] = result.user + items[result.index] = result.value completed++ if progress != nil { progress.Update(completed) @@ -111,26 +120,59 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo case err := <-firstError: return nil, err case <-ctx.Done(): - return nil, fmt.Errorf("timeout creating users: %w", ctx.Err()) + return nil, fmt.Errorf("timeout: %w", ctx.Err()) } } - return users, nil + return items, nil +} + +func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) { + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + defer cancel() + + return processInParallel(ctx, p.maxWorkers, count, + func(index int) (database.User, error) { + return p.createSingleUser(userRepo, index) + }, + "create user", + progress, + ) } func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) { ctx, cancel := context.WithTimeout(context.Background(), p.timeout) defer cancel() - results := make(chan postResult, count) + return processInParallel(ctx, p.maxWorkers, count, + func(index int) (database.Post, error) { + return p.createSinglePost(postRepo, authorID, index) + }, + "create post", + progress, + ) +} + +func processItemsInParallel[T any, R any]( + ctx context.Context, + maxWorkers int, + items []T, + processor func(index int, item T) (R, error), + errorPrefix string, + aggregator func(accumulator R, value R) R, + initialValue R, + progress *ProgressIndicator, +) (R, error) { + count := len(items) + results := make(chan indexedResult[R], count) errors := make(chan error, count) - semaphore := make(chan struct{}, p.maxWorkers) + semaphore := make(chan struct{}, maxWorkers) var wg sync.WaitGroup - for i := range count { + for i, item := range items { wg.Add(1) - go func(index int) { + go func(index int, item T) { defer wg.Done() select { @@ -141,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo } defer func() { <-semaphore }() - post, err := p.createSinglePost(postRepo, authorID, index+1) + value, err := processor(index, item) if err != nil { - errors <- fmt.Errorf("create post %d: %w", index+1, err) + errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err) return } - results <- postResult{post: post, index: index} - }(i) + results <- indexedResult[R]{value: value, index: index} + }(i, item) } go func() { @@ -157,7 +199,7 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo close(errors) }() - posts := make([]database.Post, count) + accumulator := initialValue completed := 0 firstError := make(chan error, 1) @@ -177,111 +219,55 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo select { case result, ok := <-results: if !ok { - return posts, nil + return accumulator, nil } - posts[result.index] = result.post + accumulator = aggregator(accumulator, result.value) completed++ if progress != nil { progress.Update(completed) } case err := <-firstError: - return nil, err + return initialValue, err case <-ctx.Done(): - return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err()) + return initialValue, fmt.Errorf("timeout: %w", ctx.Err()) } } - return posts, nil + return accumulator, nil } func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) { ctx, cancel := context.WithTimeout(context.Background(), p.timeout) defer cancel() - results := make(chan voteResult, len(posts)) - errors := make(chan error, len(posts)) - - semaphore := make(chan struct{}, p.maxWorkers) - var wg sync.WaitGroup - - for i, post := range posts { - wg.Add(1) - go func(index int, post database.Post) { - defer wg.Done() - - select { - case semaphore <- struct{}{}: - case <-ctx.Done(): - errors <- ctx.Err() - return - } - defer func() { <-semaphore }() - - votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost) - if err != nil { - errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err) - return - } - - results <- voteResult{votes: votes, index: index} - }(i, post) - } - - go func() { - wg.Wait() - close(results) - close(errors) - }() - - totalVotes := 0 - completed := 0 - firstError := make(chan error, 1) - - go func() { - for err := range errors { - if err != nil { - select { - case firstError <- err: - default: - } - return - } - } - }() - - for completed < len(posts) { - select { - case result, ok := <-results: - if !ok { - return totalVotes, nil - } - totalVotes += result.votes - completed++ - if progress != nil { - progress.Update(completed) - } - case err := <-firstError: - return 0, err - case <-ctx.Done(): - return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err()) - } - } - - return totalVotes, nil + return processItemsInParallel(ctx, p.maxWorkers, posts, + func(index int, post database.Post) (int, error) { + return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost) + }, + "create votes for post", + func(acc, val int) int { return acc + val }, + 0, + progress, + ) } -func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error { - ctx, cancel := context.WithTimeout(context.Background(), p.timeout) - defer cancel() +func processItemsInParallelNoResult[T any]( + ctx context.Context, + maxWorkers int, + items []T, + processor func(index int, item T) error, + errorFormatter func(index int, item T, err error) error, + progress *ProgressIndicator, +) error { + count := len(items) + errors := make(chan error, count) - errors := make(chan error, len(posts)) - - semaphore := make(chan struct{}, p.maxWorkers) + semaphore := make(chan struct{}, maxWorkers) var wg sync.WaitGroup - for i, post := range posts { + for i, item := range items { wg.Add(1) - go func(index int, post database.Post) { + go func(index int, item T) { defer wg.Done() select { @@ -292,16 +278,20 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos } defer func() { <-semaphore }() - err := p.updateSinglePostScore(postRepo, voteRepo, post) + err := processor(index, item) if err != nil { - errors <- fmt.Errorf("update post %d scores: %w", post.ID, err) + if errorFormatter != nil { + errors <- errorFormatter(index, item, err) + } else { + errors <- fmt.Errorf("process item %d: %w", index+1, err) + } return } if progress != nil { progress.Update(index + 1) } - }(i, post) + }(i, item) } go func() { @@ -318,19 +308,19 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos return nil } -type userResult struct { - user database.User - index int -} +func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error { + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + defer cancel() -type postResult struct { - post database.Post - index int -} - -type voteResult struct { - votes int - index int + return processItemsInParallelNoResult(ctx, p.maxWorkers, posts, + func(index int, post database.Post) error { + return p.updateSinglePostScore(postRepo, voteRepo, post) + }, + func(index int, post database.Post, err error) error { + return fmt.Errorf("update post %d scores: %w", post.ID, err) + }, + progress, + ) } func (p *ParallelProcessor) generateRandomIdentifier() string {