diff --git a/internal/services/jwt_service.go b/internal/services/jwt_service.go index c7c2c88..3f43ae2 100644 --- a/internal/services/jwt_service.go +++ b/internal/services/jwt_service.go @@ -123,46 +123,27 @@ func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) { return claims.UserID, nil } -func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) { - - tokenHash := j.hashToken(refreshTokenString) - - refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash) +func (j *JWTService) RefreshAccessTokenWithRotation(refreshTokenString string) (string, string, error) { + refreshToken, user, err := j.validateRefreshToken(refreshTokenString) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return "", ErrRefreshTokenInvalid - } - return "", fmt.Errorf("lookup refresh token: %w", err) + return "", "", err } - if time.Now().After(refreshToken.ExpiresAt) { - - j.refreshRepo.DeleteByID(refreshToken.ID) - return "", ErrRefreshTokenExpired - } - - user, err := j.userRepo.GetByID(refreshToken.UserID) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - - j.refreshRepo.DeleteByID(refreshToken.ID) - return "", ErrRefreshTokenInvalid - } - return "", fmt.Errorf("lookup user: %w", err) - } - - if user.Locked { - - j.refreshRepo.DeleteByID(refreshToken.ID) - return "", ErrAccountLocked + if err := j.refreshRepo.DeleteByID(refreshToken.ID); err != nil { + return "", "", fmt.Errorf("revoke refresh token: %w", err) } accessToken, err := j.GenerateAccessToken(user) if err != nil { - return "", fmt.Errorf("generate access token: %w", err) + return "", "", fmt.Errorf("generate access token: %w", err) } - return accessToken, nil + newRefreshToken, err := j.GenerateRefreshToken(user) + if err != nil { + return "", "", fmt.Errorf("generate refresh token: %w", err) + } + + return accessToken, newRefreshToken, nil } func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error { @@ -354,6 +335,39 @@ func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims return nil } +func (j *JWTService) validateRefreshToken(refreshTokenString string) (*database.RefreshToken, *database.User, error) { + tokenHash := j.hashToken(refreshTokenString) + + refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil, ErrRefreshTokenInvalid + } + return nil, nil, fmt.Errorf("lookup refresh token: %w", err) + } + + if time.Now().After(refreshToken.ExpiresAt) { + _ = j.refreshRepo.DeleteByID(refreshToken.ID) + return nil, nil, ErrRefreshTokenExpired + } + + user, err := j.userRepo.GetByID(refreshToken.UserID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + _ = j.refreshRepo.DeleteByID(refreshToken.ID) + return nil, nil, ErrRefreshTokenInvalid + } + return nil, nil, fmt.Errorf("lookup user: %w", err) + } + + if user.Locked { + _ = j.refreshRepo.DeleteByID(refreshToken.ID) + return nil, nil, ErrAccountLocked + } + + return refreshToken, user, nil +} + func (j *JWTService) hashToken(token string) string { hash := sha256.Sum256([]byte(token)) return hex.EncodeToString(hash[:])