Compare commits
3 Commits
39598a166d
...
01f2b1fe75
| Author | SHA1 | Date | |
|---|---|---|---|
| 01f2b1fe75 | |||
| 28134c101c | |||
| 2f78370d43 |
@@ -291,24 +291,7 @@ func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (
|
||||
|
||||
|
||||
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
|
||||
votes, err := voteRepo.GetByPostID(postID)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
upVotes := 0
|
||||
downVotes := 0
|
||||
|
||||
for _, vote := range votes {
|
||||
switch vote.Type {
|
||||
case database.VoteUp:
|
||||
upVotes++
|
||||
case database.VoteDown:
|
||||
downVotes++
|
||||
}
|
||||
}
|
||||
|
||||
return upVotes, downVotes, nil
|
||||
return voteRepo.GetVoteCountsByPostID(postID)
|
||||
}
|
||||
|
||||
func validateSeedConsistency(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
|
||||
|
||||
@@ -20,6 +20,7 @@ type VoteRepository interface {
|
||||
Count() (int64, error)
|
||||
CountByPostID(postID uint) (int64, error)
|
||||
CountByUserID(userID uint) (int64, error)
|
||||
GetVoteCountsByPostID(postID uint) (upVotes int, downVotes int, err error)
|
||||
WithTx(tx *gorm.DB) VoteRepository
|
||||
}
|
||||
|
||||
@@ -144,3 +145,20 @@ func (r *voteRepository) Count() (int64, error) {
|
||||
err := r.db.Model(&database.Vote{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *voteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
|
||||
var result struct {
|
||||
UpVotes int64
|
||||
DownVotes int64
|
||||
}
|
||||
|
||||
err := r.db.Model(&database.Vote{}).
|
||||
Select("COUNT(CASE WHEN type = ? THEN 1 END) as up_votes, COUNT(CASE WHEN type = ? THEN 1 END) as down_votes", database.VoteUp, database.VoteDown).
|
||||
Where("post_id = ?", postID).
|
||||
Scan(&result).Error
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return int(result.UpVotes), int(result.DownVotes), nil
|
||||
}
|
||||
|
||||
@@ -965,6 +965,25 @@ func (m *MockVoteRepository) CountByUserID(userID uint) (int64, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
upVotes := 0
|
||||
downVotes := 0
|
||||
for _, vote := range m.votes {
|
||||
if vote.PostID == postID {
|
||||
switch vote.Type {
|
||||
case database.VoteUp:
|
||||
upVotes++
|
||||
case database.VoteDown:
|
||||
downVotes++
|
||||
}
|
||||
}
|
||||
}
|
||||
return upVotes, downVotes, nil
|
||||
}
|
||||
|
||||
func (m *MockVoteRepository) Count() (int64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
Reference in New Issue
Block a user