refactor: use go generics
This commit is contained in:
@@ -42,14 +42,23 @@ func (p *ParallelProcessor) SetPasswordHash(hash string) {
|
||||
p.passwordHash = hash
|
||||
}
|
||||
|
||||
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
defer cancel()
|
||||
type indexedResult[T any] struct {
|
||||
value T
|
||||
index int
|
||||
}
|
||||
|
||||
results := make(chan userResult, count)
|
||||
func processInParallel[T any](
|
||||
ctx context.Context,
|
||||
maxWorkers int,
|
||||
count int,
|
||||
processor func(index int) (T, error),
|
||||
errorPrefix string,
|
||||
progress *ProgressIndicator,
|
||||
) ([]T, error) {
|
||||
results := make(chan indexedResult[T], count)
|
||||
errors := make(chan error, count)
|
||||
|
||||
semaphore := make(chan struct{}, p.maxWorkers)
|
||||
semaphore := make(chan struct{}, maxWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := range count {
|
||||
@@ -65,13 +74,13 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
||||
}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
user, err := p.createSingleUser(userRepo, index+1)
|
||||
value, err := processor(index + 1)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("create user %d: %w", index+1, err)
|
||||
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
|
||||
return
|
||||
}
|
||||
|
||||
results <- userResult{user: user, index: index}
|
||||
results <- indexedResult[T]{value: value, index: index}
|
||||
}(i)
|
||||
}
|
||||
|
||||
@@ -81,7 +90,7 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
||||
close(errors)
|
||||
}()
|
||||
|
||||
users := make([]database.User, count)
|
||||
items := make([]T, count)
|
||||
completed := 0
|
||||
firstError := make(chan error, 1)
|
||||
|
||||
@@ -101,9 +110,9 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
||||
select {
|
||||
case result, ok := <-results:
|
||||
if !ok {
|
||||
return users, nil
|
||||
return items, nil
|
||||
}
|
||||
users[result.index] = result.user
|
||||
items[result.index] = result.value
|
||||
completed++
|
||||
if progress != nil {
|
||||
progress.Update(completed)
|
||||
@@ -111,26 +120,59 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
||||
case err := <-firstError:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
|
||||
return nil, fmt.Errorf("timeout: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
defer cancel()
|
||||
|
||||
return processInParallel(ctx, p.maxWorkers, count,
|
||||
func(index int) (database.User, error) {
|
||||
return p.createSingleUser(userRepo, index)
|
||||
},
|
||||
"create user",
|
||||
progress,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
defer cancel()
|
||||
|
||||
results := make(chan postResult, count)
|
||||
return processInParallel(ctx, p.maxWorkers, count,
|
||||
func(index int) (database.Post, error) {
|
||||
return p.createSinglePost(postRepo, authorID, index)
|
||||
},
|
||||
"create post",
|
||||
progress,
|
||||
)
|
||||
}
|
||||
|
||||
func processItemsInParallel[T any, R any](
|
||||
ctx context.Context,
|
||||
maxWorkers int,
|
||||
items []T,
|
||||
processor func(index int, item T) (R, error),
|
||||
errorPrefix string,
|
||||
aggregator func(accumulator R, value R) R,
|
||||
initialValue R,
|
||||
progress *ProgressIndicator,
|
||||
) (R, error) {
|
||||
count := len(items)
|
||||
results := make(chan indexedResult[R], count)
|
||||
errors := make(chan error, count)
|
||||
|
||||
semaphore := make(chan struct{}, p.maxWorkers)
|
||||
semaphore := make(chan struct{}, maxWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := range count {
|
||||
for i, item := range items {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
go func(index int, item T) {
|
||||
defer wg.Done()
|
||||
|
||||
select {
|
||||
@@ -141,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
|
||||
}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
post, err := p.createSinglePost(postRepo, authorID, index+1)
|
||||
value, err := processor(index, item)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("create post %d: %w", index+1, err)
|
||||
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
|
||||
return
|
||||
}
|
||||
|
||||
results <- postResult{post: post, index: index}
|
||||
}(i)
|
||||
results <- indexedResult[R]{value: value, index: index}
|
||||
}(i, item)
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -157,7 +199,7 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
|
||||
close(errors)
|
||||
}()
|
||||
|
||||
posts := make([]database.Post, count)
|
||||
accumulator := initialValue
|
||||
completed := 0
|
||||
firstError := make(chan error, 1)
|
||||
|
||||
@@ -177,111 +219,55 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
|
||||
select {
|
||||
case result, ok := <-results:
|
||||
if !ok {
|
||||
return posts, nil
|
||||
return accumulator, nil
|
||||
}
|
||||
posts[result.index] = result.post
|
||||
accumulator = aggregator(accumulator, result.value)
|
||||
completed++
|
||||
if progress != nil {
|
||||
progress.Update(completed)
|
||||
}
|
||||
case err := <-firstError:
|
||||
return nil, err
|
||||
return initialValue, err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err())
|
||||
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
return posts, nil
|
||||
return accumulator, nil
|
||||
}
|
||||
|
||||
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
defer cancel()
|
||||
|
||||
results := make(chan voteResult, len(posts))
|
||||
errors := make(chan error, len(posts))
|
||||
|
||||
semaphore := make(chan struct{}, p.maxWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, post := range posts {
|
||||
wg.Add(1)
|
||||
go func(index int, post database.Post) {
|
||||
defer wg.Done()
|
||||
|
||||
select {
|
||||
case semaphore <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
errors <- ctx.Err()
|
||||
return
|
||||
}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
results <- voteResult{votes: votes, index: index}
|
||||
}(i, post)
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
close(errors)
|
||||
}()
|
||||
|
||||
totalVotes := 0
|
||||
completed := 0
|
||||
firstError := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for err := range errors {
|
||||
if err != nil {
|
||||
select {
|
||||
case firstError <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for completed < len(posts) {
|
||||
select {
|
||||
case result, ok := <-results:
|
||||
if !ok {
|
||||
return totalVotes, nil
|
||||
}
|
||||
totalVotes += result.votes
|
||||
completed++
|
||||
if progress != nil {
|
||||
progress.Update(completed)
|
||||
}
|
||||
case err := <-firstError:
|
||||
return 0, err
|
||||
case <-ctx.Done():
|
||||
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
return totalVotes, nil
|
||||
return processItemsInParallel(ctx, p.maxWorkers, posts,
|
||||
func(index int, post database.Post) (int, error) {
|
||||
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
|
||||
},
|
||||
"create votes for post",
|
||||
func(acc, val int) int { return acc + val },
|
||||
0,
|
||||
progress,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
defer cancel()
|
||||
func processItemsInParallelNoResult[T any](
|
||||
ctx context.Context,
|
||||
maxWorkers int,
|
||||
items []T,
|
||||
processor func(index int, item T) error,
|
||||
errorFormatter func(index int, item T, err error) error,
|
||||
progress *ProgressIndicator,
|
||||
) error {
|
||||
count := len(items)
|
||||
errors := make(chan error, count)
|
||||
|
||||
errors := make(chan error, len(posts))
|
||||
|
||||
semaphore := make(chan struct{}, p.maxWorkers)
|
||||
semaphore := make(chan struct{}, maxWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, post := range posts {
|
||||
for i, item := range items {
|
||||
wg.Add(1)
|
||||
go func(index int, post database.Post) {
|
||||
go func(index int, item T) {
|
||||
defer wg.Done()
|
||||
|
||||
select {
|
||||
@@ -292,16 +278,20 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos
|
||||
}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
err := p.updateSinglePostScore(postRepo, voteRepo, post)
|
||||
err := processor(index, item)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
|
||||
if errorFormatter != nil {
|
||||
errors <- errorFormatter(index, item, err)
|
||||
} else {
|
||||
errors <- fmt.Errorf("process item %d: %w", index+1, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if progress != nil {
|
||||
progress.Update(index + 1)
|
||||
}
|
||||
}(i, post)
|
||||
}(i, item)
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -318,19 +308,19 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos
|
||||
return nil
|
||||
}
|
||||
|
||||
type userResult struct {
|
||||
user database.User
|
||||
index int
|
||||
}
|
||||
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
defer cancel()
|
||||
|
||||
type postResult struct {
|
||||
post database.Post
|
||||
index int
|
||||
}
|
||||
|
||||
type voteResult struct {
|
||||
votes int
|
||||
index int
|
||||
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
|
||||
func(index int, post database.Post) error {
|
||||
return p.updateSinglePostScore(postRepo, voteRepo, post)
|
||||
},
|
||||
func(index int, post database.Post, err error) error {
|
||||
return fmt.Errorf("update post %d scores: %w", post.ID, err)
|
||||
},
|
||||
progress,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *ParallelProcessor) generateRandomIdentifier() string {
|
||||
|
||||
Reference in New Issue
Block a user