refactor: use go generics

This commit is contained in:
2025-11-21 16:56:26 +01:00
parent 0cd428d5d9
commit 62d466e4fa

View File

@@ -42,14 +42,23 @@ func (p *ParallelProcessor) SetPasswordHash(hash string) {
p.passwordHash = hash
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
type indexedResult[T any] struct {
value T
index int
}
results := make(chan userResult, count)
func processInParallel[T any](
ctx context.Context,
maxWorkers int,
count int,
processor func(index int) (T, error),
errorPrefix string,
progress *ProgressIndicator,
) ([]T, error) {
results := make(chan indexedResult[T], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
@@ -65,13 +74,13 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
}
defer func() { <-semaphore }()
user, err := p.createSingleUser(userRepo, index+1)
value, err := processor(index + 1)
if err != nil {
errors <- fmt.Errorf("create user %d: %w", index+1, err)
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- userResult{user: user, index: index}
results <- indexedResult[T]{value: value, index: index}
}(i)
}
@@ -81,7 +90,7 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
close(errors)
}()
users := make([]database.User, count)
items := make([]T, count)
completed := 0
firstError := make(chan error, 1)
@@ -101,9 +110,9 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
select {
case result, ok := <-results:
if !ok {
return users, nil
return items, nil
}
users[result.index] = result.user
items[result.index] = result.value
completed++
if progress != nil {
progress.Update(completed)
@@ -111,26 +120,59 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
case err := <-firstError:
return nil, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
return nil, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return users, nil
return items, nil
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.User, error) {
return p.createSingleUser(userRepo, index)
},
"create user",
progress,
)
}
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan postResult, count)
return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.Post, error) {
return p.createSinglePost(postRepo, authorID, index)
},
"create post",
progress,
)
}
func processItemsInParallel[T any, R any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) (R, error),
errorPrefix string,
aggregator func(accumulator R, value R) R,
initialValue R,
progress *ProgressIndicator,
) (R, error) {
count := len(items)
results := make(chan indexedResult[R], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
for i, item := range items {
wg.Add(1)
go func(index int) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -141,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
}
defer func() { <-semaphore }()
post, err := p.createSinglePost(postRepo, authorID, index+1)
value, err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("create post %d: %w", index+1, err)
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- postResult{post: post, index: index}
}(i)
results <- indexedResult[R]{value: value, index: index}
}(i, item)
}
go func() {
@@ -157,7 +199,7 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
close(errors)
}()
posts := make([]database.Post, count)
accumulator := initialValue
completed := 0
firstError := make(chan error, 1)
@@ -177,111 +219,55 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
select {
case result, ok := <-results:
if !ok {
return posts, nil
return accumulator, nil
}
posts[result.index] = result.post
accumulator = aggregator(accumulator, result.value)
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return nil, err
return initialValue, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err())
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return posts, nil
return accumulator, nil
}
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
results := make(chan voteResult, len(posts))
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
if err != nil {
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err)
return
}
results <- voteResult{votes: votes, index: index}
}(i, post)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
totalVotes := 0
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < len(posts) {
select {
case result, ok := <-results:
if !ok {
return totalVotes, nil
}
totalVotes += result.votes
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return 0, err
case <-ctx.Done():
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
}
}
return totalVotes, nil
return processItemsInParallel(ctx, p.maxWorkers, posts,
func(index int, post database.Post) (int, error) {
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
},
"create votes for post",
func(acc, val int) int { return acc + val },
0,
progress,
)
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
func processItemsInParallelNoResult[T any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) error,
errorFormatter func(index int, item T, err error) error,
progress *ProgressIndicator,
) error {
count := len(items)
errors := make(chan error, count)
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
for i, item := range items {
wg.Add(1)
go func(index int, post database.Post) {
go func(index int, item T) {
defer wg.Done()
select {
@@ -292,16 +278,20 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos
}
defer func() { <-semaphore }()
err := p.updateSinglePostScore(postRepo, voteRepo, post)
err := processor(index, item)
if err != nil {
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
if errorFormatter != nil {
errors <- errorFormatter(index, item, err)
} else {
errors <- fmt.Errorf("process item %d: %w", index+1, err)
}
return
}
if progress != nil {
progress.Update(index + 1)
}
}(i, post)
}(i, item)
}
go func() {
@@ -318,19 +308,19 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos
return nil
}
type userResult struct {
user database.User
index int
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
type postResult struct {
post database.Post
index int
}
type voteResult struct {
votes int
index int
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
func(index int, post database.Post) error {
return p.updateSinglePostScore(postRepo, voteRepo, post)
},
func(index int, post database.Post, err error) error {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
},
progress,
)
}
func (p *ParallelProcessor) generateRandomIdentifier() string {