Compare commits

..

2 Commits

2 changed files with 47 additions and 33 deletions

View File

@@ -123,46 +123,27 @@ func (j *JWTService) VerifyAccessToken(tokenString string) (uint, error) {
return claims.UserID, nil return claims.UserID, nil
} }
func (j *JWTService) RefreshAccessToken(refreshTokenString string) (string, error) { func (j *JWTService) RefreshAccessTokenWithRotation(refreshTokenString string) (string, string, error) {
refreshToken, user, err := j.validateRefreshToken(refreshTokenString)
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { return "", "", err
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup refresh token: %w", err)
} }
if time.Now().After(refreshToken.ExpiresAt) { if err := j.refreshRepo.DeleteByID(refreshToken.ID); err != nil {
return "", "", fmt.Errorf("revoke refresh token: %w", err)
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenExpired
}
user, err := j.userRepo.GetByID(refreshToken.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrRefreshTokenInvalid
}
return "", fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
j.refreshRepo.DeleteByID(refreshToken.ID)
return "", ErrAccountLocked
} }
accessToken, err := j.GenerateAccessToken(user) accessToken, err := j.GenerateAccessToken(user)
if err != nil { if err != nil {
return "", fmt.Errorf("generate access token: %w", err) return "", "", fmt.Errorf("generate access token: %w", err)
} }
return accessToken, nil newRefreshToken, err := j.GenerateRefreshToken(user)
if err != nil {
return "", "", fmt.Errorf("generate refresh token: %w", err)
}
return accessToken, newRefreshToken, nil
} }
func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error { func (j *JWTService) RevokeRefreshToken(refreshTokenString string) error {
@@ -354,6 +335,39 @@ func (j *JWTService) validateTokenMetadata(token *jwt.Token, claims *TokenClaims
return nil return nil
} }
func (j *JWTService) validateRefreshToken(refreshTokenString string) (*database.RefreshToken, *database.User, error) {
tokenHash := j.hashToken(refreshTokenString)
refreshToken, err := j.refreshRepo.GetByTokenHash(tokenHash)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrRefreshTokenInvalid
}
return nil, nil, fmt.Errorf("lookup refresh token: %w", err)
}
if time.Now().After(refreshToken.ExpiresAt) {
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
return nil, nil, ErrRefreshTokenExpired
}
user, err := j.userRepo.GetByID(refreshToken.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
return nil, nil, ErrRefreshTokenInvalid
}
return nil, nil, fmt.Errorf("lookup user: %w", err)
}
if user.Locked {
_ = j.refreshRepo.DeleteByID(refreshToken.ID)
return nil, nil, ErrAccountLocked
}
return refreshToken, user, nil
}
func (j *JWTService) hashToken(token string) string { func (j *JWTService) hashToken(token string) string {
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])

View File

@@ -71,7 +71,7 @@ func (s *SessionService) issueAuthResult(user *database.User) (*AuthResult, erro
} }
func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, error) { func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, error) {
accessToken, err := s.jwtService.RefreshAccessToken(refreshToken) accessToken, newRefreshToken, err := s.jwtService.RefreshAccessTokenWithRotation(refreshToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -88,7 +88,7 @@ func (s *SessionService) RefreshAccessToken(refreshToken string) (*AuthResult, e
return &AuthResult{ return &AuthResult{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: newRefreshToken,
User: sanitizeUser(user), User: sanitizeUser(user),
}, nil }, nil
} }