refactor: use go generics
This commit is contained in:
@@ -42,14 +42,23 @@ func (p *ParallelProcessor) SetPasswordHash(hash string) {
|
|||||||
p.passwordHash = hash
|
p.passwordHash = hash
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
|
type indexedResult[T any] struct {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
value T
|
||||||
defer cancel()
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
results := make(chan userResult, count)
|
func processInParallel[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
maxWorkers int,
|
||||||
|
count int,
|
||||||
|
processor func(index int) (T, error),
|
||||||
|
errorPrefix string,
|
||||||
|
progress *ProgressIndicator,
|
||||||
|
) ([]T, error) {
|
||||||
|
results := make(chan indexedResult[T], count)
|
||||||
errors := make(chan error, count)
|
errors := make(chan error, count)
|
||||||
|
|
||||||
semaphore := make(chan struct{}, p.maxWorkers)
|
semaphore := make(chan struct{}, maxWorkers)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for i := range count {
|
for i := range count {
|
||||||
@@ -65,13 +74,13 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
|||||||
}
|
}
|
||||||
defer func() { <-semaphore }()
|
defer func() { <-semaphore }()
|
||||||
|
|
||||||
user, err := p.createSingleUser(userRepo, index+1)
|
value, err := processor(index + 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- fmt.Errorf("create user %d: %w", index+1, err)
|
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
results <- userResult{user: user, index: index}
|
results <- indexedResult[T]{value: value, index: index}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,7 +90,7 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
|||||||
close(errors)
|
close(errors)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
users := make([]database.User, count)
|
items := make([]T, count)
|
||||||
completed := 0
|
completed := 0
|
||||||
firstError := make(chan error, 1)
|
firstError := make(chan error, 1)
|
||||||
|
|
||||||
@@ -101,9 +110,9 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
|||||||
select {
|
select {
|
||||||
case result, ok := <-results:
|
case result, ok := <-results:
|
||||||
if !ok {
|
if !ok {
|
||||||
return users, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
users[result.index] = result.user
|
items[result.index] = result.value
|
||||||
completed++
|
completed++
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress.Update(completed)
|
progress.Update(completed)
|
||||||
@@ -111,26 +120,59 @@ func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepo
|
|||||||
case err := <-firstError:
|
case err := <-firstError:
|
||||||
return nil, err
|
return nil, err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
|
return nil, fmt.Errorf("timeout: %w", ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return processInParallel(ctx, p.maxWorkers, count,
|
||||||
|
func(index int) (database.User, error) {
|
||||||
|
return p.createSingleUser(userRepo, index)
|
||||||
|
},
|
||||||
|
"create user",
|
||||||
|
progress,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
|
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
results := make(chan postResult, count)
|
return processInParallel(ctx, p.maxWorkers, count,
|
||||||
|
func(index int) (database.Post, error) {
|
||||||
|
return p.createSinglePost(postRepo, authorID, index)
|
||||||
|
},
|
||||||
|
"create post",
|
||||||
|
progress,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func processItemsInParallel[T any, R any](
|
||||||
|
ctx context.Context,
|
||||||
|
maxWorkers int,
|
||||||
|
items []T,
|
||||||
|
processor func(index int, item T) (R, error),
|
||||||
|
errorPrefix string,
|
||||||
|
aggregator func(accumulator R, value R) R,
|
||||||
|
initialValue R,
|
||||||
|
progress *ProgressIndicator,
|
||||||
|
) (R, error) {
|
||||||
|
count := len(items)
|
||||||
|
results := make(chan indexedResult[R], count)
|
||||||
errors := make(chan error, count)
|
errors := make(chan error, count)
|
||||||
|
|
||||||
semaphore := make(chan struct{}, p.maxWorkers)
|
semaphore := make(chan struct{}, maxWorkers)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for i := range count {
|
for i, item := range items {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(index int) {
|
go func(index int, item T) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -141,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
|
|||||||
}
|
}
|
||||||
defer func() { <-semaphore }()
|
defer func() { <-semaphore }()
|
||||||
|
|
||||||
post, err := p.createSinglePost(postRepo, authorID, index+1)
|
value, err := processor(index, item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- fmt.Errorf("create post %d: %w", index+1, err)
|
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
results <- postResult{post: post, index: index}
|
results <- indexedResult[R]{value: value, index: index}
|
||||||
}(i)
|
}(i, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -157,7 +199,7 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
|
|||||||
close(errors)
|
close(errors)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
posts := make([]database.Post, count)
|
accumulator := initialValue
|
||||||
completed := 0
|
completed := 0
|
||||||
firstError := make(chan error, 1)
|
firstError := make(chan error, 1)
|
||||||
|
|
||||||
@@ -177,36 +219,55 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
|
|||||||
select {
|
select {
|
||||||
case result, ok := <-results:
|
case result, ok := <-results:
|
||||||
if !ok {
|
if !ok {
|
||||||
return posts, nil
|
return accumulator, nil
|
||||||
}
|
}
|
||||||
posts[result.index] = result.post
|
accumulator = aggregator(accumulator, result.value)
|
||||||
completed++
|
completed++
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress.Update(completed)
|
progress.Update(completed)
|
||||||
}
|
}
|
||||||
case err := <-firstError:
|
case err := <-firstError:
|
||||||
return nil, err
|
return initialValue, err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err())
|
return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return posts, nil
|
return accumulator, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
|
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
results := make(chan voteResult, len(posts))
|
return processItemsInParallel(ctx, p.maxWorkers, posts,
|
||||||
errors := make(chan error, len(posts))
|
func(index int, post database.Post) (int, error) {
|
||||||
|
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
|
||||||
|
},
|
||||||
|
"create votes for post",
|
||||||
|
func(acc, val int) int { return acc + val },
|
||||||
|
0,
|
||||||
|
progress,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
semaphore := make(chan struct{}, p.maxWorkers)
|
func processItemsInParallelNoResult[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
maxWorkers int,
|
||||||
|
items []T,
|
||||||
|
processor func(index int, item T) error,
|
||||||
|
errorFormatter func(index int, item T, err error) error,
|
||||||
|
progress *ProgressIndicator,
|
||||||
|
) error {
|
||||||
|
count := len(items)
|
||||||
|
errors := make(chan error, count)
|
||||||
|
|
||||||
|
semaphore := make(chan struct{}, maxWorkers)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for i, post := range posts {
|
for i, item := range items {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(index int, post database.Post) {
|
go func(index int, item T) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -217,91 +278,20 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
|
|||||||
}
|
}
|
||||||
defer func() { <-semaphore }()
|
defer func() { <-semaphore }()
|
||||||
|
|
||||||
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
|
err := processor(index, item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err)
|
if errorFormatter != nil {
|
||||||
return
|
errors <- errorFormatter(index, item, err)
|
||||||
|
} else {
|
||||||
|
errors <- fmt.Errorf("process item %d: %w", index+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
results <- voteResult{votes: votes, index: index}
|
|
||||||
}(i, post)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
close(results)
|
|
||||||
close(errors)
|
|
||||||
}()
|
|
||||||
|
|
||||||
totalVotes := 0
|
|
||||||
completed := 0
|
|
||||||
firstError := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for err := range errors {
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case firstError <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for completed < len(posts) {
|
|
||||||
select {
|
|
||||||
case result, ok := <-results:
|
|
||||||
if !ok {
|
|
||||||
return totalVotes, nil
|
|
||||||
}
|
|
||||||
totalVotes += result.votes
|
|
||||||
completed++
|
|
||||||
if progress != nil {
|
|
||||||
progress.Update(completed)
|
|
||||||
}
|
|
||||||
case err := <-firstError:
|
|
||||||
return 0, err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalVotes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
errors := make(chan error, len(posts))
|
|
||||||
|
|
||||||
semaphore := make(chan struct{}, p.maxWorkers)
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
for i, post := range posts {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(index int, post database.Post) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case semaphore <- struct{}{}:
|
|
||||||
case <-ctx.Done():
|
|
||||||
errors <- ctx.Err()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() { <-semaphore }()
|
|
||||||
|
|
||||||
err := p.updateSinglePostScore(postRepo, voteRepo, post)
|
|
||||||
if err != nil {
|
|
||||||
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress.Update(index + 1)
|
progress.Update(index + 1)
|
||||||
}
|
}
|
||||||
}(i, post)
|
}(i, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -318,19 +308,19 @@ func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.Pos
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type userResult struct {
|
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
|
||||||
user database.User
|
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||||
index int
|
defer cancel()
|
||||||
}
|
|
||||||
|
|
||||||
type postResult struct {
|
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
|
||||||
post database.Post
|
func(index int, post database.Post) error {
|
||||||
index int
|
return p.updateSinglePostScore(postRepo, voteRepo, post)
|
||||||
}
|
},
|
||||||
|
func(index int, post database.Post, err error) error {
|
||||||
type voteResult struct {
|
return fmt.Errorf("update post %d scores: %w", post.ID, err)
|
||||||
votes int
|
},
|
||||||
index int
|
progress,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ParallelProcessor) generateRandomIdentifier() string {
|
func (p *ParallelProcessor) generateRandomIdentifier() string {
|
||||||
|
|||||||
Reference in New Issue
Block a user