Compare commits

..

2 Commits

2 changed files with 47 additions and 33 deletions

View File

@@ -123,46 +123,27 @@ func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
return claims.UserID, nil
}
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) {
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
func (j *JWTService) RefreshAccessTokenWithRotation(refreshTokenString string) (string, string, error) {
refreshToken, user, err := j.validateRefreshToken(refreshTokenString)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup refresh token: %w", err)
return "", "", err
}
if time.Now().After(refreshToken.ExpiresAt) {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenExpired
}
user, err := j.userRepo.GetByID(refreshToken.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrAccountLocked
if err := j.refreshRepo.DeleteByID(refreshToken.ID); err != nil {
return "", "", fmt.Errorf("revoke refresh token: %w", err)
}
accessToken, err := j.GenerateAccessToken(user)
if err != nil {
return "", fmt.Errorf("generate access token: %w", err)
return "", "", fmt.Errorf("generate access token: %w", err)
}
return accessToken, nil
newRefreshToken, err := j.GenerateRefreshToken(user)
if err != nil {
return "", "", fmt.Errorf("generate refresh token: %w", err)
}
return accessToken, newRefreshToken, nil
}
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
@@ -354,6 +335,39 @@ func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims
return nil
}
func (j *JWTService) validateRefreshToken(refreshTokenString string) (*database.RefreshToken, *database.User, error) {
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrRefreshTokenInvalid
}
return nil, nil, fmt.Errorf("lookup refresh token: %w", err)
}
if time.Now().After(refreshToken.ExpiresAt) {
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
return nil, nil, ErrRefreshTokenExpired
}
user, err := j.userRepo.GetByID(refreshToken.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
return nil, nil, ErrRefreshTokenInvalid
}
return nil, nil, fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
return nil, nil, ErrAccountLocked
}
return refreshToken, user, nil
}
func (j *JWTService) hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])

View File

@@ -71,7 +71,7 @@ func (s *SessionService) issueAuthResult(user *database.User) (*AuthResult, erro
}
func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
accessToken, err := s.jwtService.RefreshAccessToken(refreshToken)
accessToken, newRefreshToken, err := s.jwtService.RefreshAccessTokenWithRotation(refreshToken)
if err != nil {
return nil, err
}
@@ -88,7 +88,7 @@ func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, e
return &AuthResult{
AccessToken: accessToken,
RefreshToken: refreshToken,
RefreshToken: newRefreshToken,
User: sanitizeUser(user),
}, nil
}