refactor: use go generics

This commit is contained in:
2025-11-21 16:56:26 +01:00
parent 0cd428d5d9
commit 62d466e4fa

View File

@@ -42,14 +42,23 @@ func (p *ParallelProcessor) SetPasswordHash(hash string) {
p.passwordHash = hash p.passwordHash = hash
} }
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) { type indexedResult[T any] struct {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) value T
defer cancel() index int
}
results := make(chan userResult, count) func processInParallel[T any](
ctx context.Context,
maxWorkers int,
count int,
processor func(index int) (T, error),
errorPrefix string,
progress *ProgressIndicator,
) ([]T, error) {
results := make(chan indexedResult[T], count)
errors := make(chan error, count) errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers) semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := range count { for i := range count {
@@ -65,13 +74,13 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
} }
defer func() { <-semaphore }() defer func() { <-semaphore }()
user, err := p.createSingleUser(userRepo, index+1) value, err := processor(index + 1)
if err != nil { if err != nil {
errors <- fmt.Errorf("create user %d: %w", index+1, err) errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return return
} }
results <- userResult{user: user, index: index} results <- indexedResult[T]{value: value, index: index}
}(i) }(i)
} }
@@ -81,7 +90,7 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
close(errors) close(errors)
}() }()
users := make([]database.User, count) items := make([]T, count)
completed := 0 completed := 0
firstError := make(chan error, 1) firstError := make(chan error, 1)
@@ -101,9 +110,9 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
select { select {
case result, ok := <-results: case result, ok := <-results:
if !ok { if !ok {
return users, nil return items, nil
} }
users[result.index] = result.user items[result.index] = result.value
completed++ completed++
if progress != nil { if progress != nil {
progress.Update(completed) progress.Update(completed)
@@ -111,26 +120,59 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
case err := <-firstError: case err := <-firstError:
return nil, err return nil, err
case <-ctx.Done(): case <-ctx.Done():
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err()) return nil, fmt.Errorf("timeout: %w", ctx.Err())
} }
} }
return users, nil return items, nil
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.User, error) {
return p.createSingleUser(userRepo, index)
},
"create user",
progress,
)
} }
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) { func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel() defer cancel()
results := make(chan postResult, count) return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.Post, error) {
return p.createSinglePost(postRepo, authorID, index)
},
"create post",
progress,
)
}
func processItemsInParallel[T any, R any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) (R, error),
errorPrefix string,
aggregator func(accumulator R, value R) R,
initialValue R,
progress *ProgressIndicator,
) (R, error) {
count := len(items)
results := make(chan indexedResult[R], count)
errors := make(chan error, count) errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers) semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := range count { for i, item := range items {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int, item T) {
defer wg.Done() defer wg.Done()
select { select {
@@ -141,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
} }
defer func() { <-semaphore }() defer func() { <-semaphore }()
post, err := p.createSinglePost(postRepo, authorID, index+1) value, err := processor(index, item)
if err != nil { if err != nil {
errors <- fmt.Errorf("create post %d: %w", index+1, err) errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return return
} }
results <- postResult{post: post, index: index} results <- indexedResult[R]{value: value, index: index}
}(i) }(i, item)
} }
go func() { go func() {
@@ -157,7 +199,7 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
close(errors) close(errors)
}() }()
posts := make([]database.Post, count) accumulator := initialValue
completed := 0 completed := 0
firstError := make(chan error, 1) firstError := make(chan error, 1)
@@ -177,36 +219,55 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
select { select {
case result, ok := <-results: case result, ok := <-results:
if !ok { if !ok {
return posts, nil return accumulator, nil
} }
posts[result.index] = result.post accumulator = aggregator(accumulator, result.value)
completed++ completed++
if progress != nil { if progress != nil {
progress.Update(completed) progress.Update(completed)
} }
case err := <-firstError: case err := <-firstError:
return nil, err return initialValue, err
case <-ctx.Done(): case <-ctx.Done():
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err()) return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
} }
} }
return posts, nil return accumulator, nil
} }
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) { func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel() defer cancel()
results := make(chan voteResult, len(posts)) return processItemsInParallel(ctx, p.maxWorkers, posts,
errors := make(chan error, len(posts)) func(index int, post database.Post) (int, error) {
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
},
"create votes for post",
func(acc, val int) int { return acc + val },
0,
progress,
)
}
semaphore := make(chan struct{}, p.maxWorkers) func processItemsInParallelNoResult[T any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) error,
errorFormatter func(index int, item T, err error) error,
progress *ProgressIndicator,
) error {
count := len(items)
errors := make(chan error, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup var wg sync.WaitGroup
for i, post := range posts { for i, item := range items {
wg.Add(1) wg.Add(1)
go func(index int, post database.Post) { go func(index int, item T) {
defer wg.Done() defer wg.Done()
select { select {
@@ -217,91 +278,20 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
} }
defer func() { <-semaphore }() defer func() { <-semaphore }()
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost) err := processor(index, item)
if err != nil { if err != nil {
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err) if errorFormatter != nil {
return errors <- errorFormatter(index, item, err)
} else {
errors <- fmt.Errorf("process item %d: %w", index+1, err)
} }
results <- voteResult{votes: votes, index: index}
}(i, post)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
totalVotes := 0
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < len(posts) {
select {
case result, ok := <-results:
if !ok {
return totalVotes, nil
}
totalVotes += result.votes
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return 0, err
case <-ctx.Done():
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
}
}
return totalVotes, nil
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
err := p.updateSinglePostScore(postRepo, voteRepo, post)
if err != nil {
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
return return
} }
if progress != nil { if progress != nil {
progress.Update(index + 1) progress.Update(index + 1)
} }
}(i, post) }(i, item)
} }
go func() { go func() {
@@ -318,19 +308,19 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos
return nil return nil
} }
type userResult struct { func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
user database.User ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
index int defer cancel()
}
type postResult struct { return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
post database.Post func(index int, post database.Post) error {
index int return p.updateSinglePostScore(postRepo, voteRepo, post)
} },
func(index int, post database.Post, err error) error {
type voteResult struct { return fmt.Errorf("update post %d scores: %w", post.ID, err)
votes int },
index int progress,
)
} }
func (p *ParallelProcessor) generateRandomIdentifier() string { func (p *ParallelProcessor) generateRandomIdentifier() string {