diff --git a/internal/services/jwt_service_test.go b/internal/services/jwt_service_test.go index 267dc09..ba675f8 100644 --- a/internal/services/jwt_service_test.go +++ b/internal/services/jwt_service_test.go @@ -348,39 +348,54 @@ func TestJWTService_VerifyAccessToken(t *testing.T) { }) } -func TestJWTService_RefreshAccessToken(t *testing.T) { +func TestJWTService_RefreshAccessTokenWithRotation(t *testing.T) { jwtService, userRepo, refreshRepo := createTestJWTService() user := createTestUser() userRepo.users[user.ID] = user - t.Run("Successful_Refresh", func(t *testing.T) { + refreshToken, err := jwtService.GenerateRefreshToken(user) + if err != nil { + t.Fatalf("Failed to generate refresh token: %v", err) + } - refreshToken, err := jwtService.GenerateRefreshToken(user) - if err != nil { - t.Fatalf("Failed to generate refresh token: %v", err) - } + accessToken, newRefreshToken, err := jwtService.RefreshAccessTokenWithRotation(refreshToken) + if err != nil { + t.Fatalf("Expected successful token refresh, got error: %v", err) + } - accessToken, err := jwtService.RefreshAccessToken(refreshToken) - if err != nil { - t.Fatalf("Expected successful token refresh, got error: %v", err) - } + if accessToken == "" { + t.Error("Expected non-empty access token") + } + if newRefreshToken == "" { + t.Error("Expected non-empty refresh token") + } + if newRefreshToken == refreshToken { + t.Error("Expected refresh token to rotate") + } - if accessToken == "" { - t.Error("Expected non-empty access token") - } + userID, err := jwtService.VerifyAccessToken(accessToken) + if err != nil { + t.Fatalf("Expected valid access token, got error: %v", err) + } + if userID != user.ID { + t.Errorf("Expected user ID %d, got %d", user.ID, userID) + } - userID, err := jwtService.VerifyAccessToken(accessToken) - if err != nil { - t.Fatalf("Expected valid access token, got error: %v", err) - } + _, _, err = jwtService.RefreshAccessTokenWithRotation(refreshToken) + if err == nil { + t.Fatal("Expected error for rotated refresh token") + } + if !errors.Is(err, ErrRefreshTokenInvalid) { + t.Errorf("Expected ErrRefreshTokenInvalid, got %v", err) + } - if userID != user.ID { - t.Errorf("Expected user ID %d, got %d", user.ID, userID) - } - }) + _, _, err = jwtService.RefreshAccessTokenWithRotation(newRefreshToken) + if err != nil { + t.Fatalf("Expected new refresh token to be usable, got error: %v", err) + } t.Run("Invalid_Refresh_Token", func(t *testing.T) { - _, err := jwtService.RefreshAccessToken("invalid-refresh-token") + _, _, err := jwtService.RefreshAccessTokenWithRotation("invalid-refresh-token") if err == nil { t.Error("Expected error for invalid refresh token") } @@ -390,7 +405,6 @@ func TestJWTService_RefreshAccessToken(t *testing.T) { }) t.Run("Expired_Refresh_Token", func(t *testing.T) { - refreshToken := &database.RefreshToken{ UserID: user.ID, TokenHash: "expired-token-hash", @@ -403,7 +417,7 @@ func TestJWTService_RefreshAccessToken(t *testing.T) { refreshToken.TokenHash = tokenHash refreshRepo.tokens[tokenHash] = refreshToken - _, err := jwtService.RefreshAccessToken(testToken) + _, _, err := jwtService.RefreshAccessTokenWithRotation(testToken) if err == nil { t.Error("Expected error for expired refresh token") } @@ -413,7 +427,6 @@ func TestJWTService_RefreshAccessToken(t *testing.T) { }) t.Run("User_Not_Found", func(t *testing.T) { - refreshToken, err := jwtService.GenerateRefreshToken(user) if err != nil { t.Fatalf("Failed to generate refresh token: %v", err) @@ -421,7 +434,7 @@ func TestJWTService_RefreshAccessToken(t *testing.T) { delete(userRepo.users, user.ID) - _, err = jwtService.RefreshAccessToken(refreshToken) + _, _, err = jwtService.RefreshAccessTokenWithRotation(refreshToken) if err == nil { t.Error("Expected error for non-existent user") } @@ -439,7 +452,7 @@ func TestJWTService_RefreshAccessToken(t *testing.T) { t.Fatalf("Failed to generate refresh token: %v", err) } - _, err = jwtService.RefreshAccessToken(refreshToken) + _, _, err = jwtService.RefreshAccessTokenWithRotation(refreshToken) if err == nil { t.Error("Expected error for locked user") } @@ -937,7 +950,7 @@ func TestJWTService_Integration(t *testing.T) { t.Errorf("Expected user ID %d, got %d", user.ID, userID) } - newAccessToken, err := jwtService.RefreshAccessToken(refreshToken) + newAccessToken, rotatedRefreshToken, err := jwtService.RefreshAccessTokenWithRotation(refreshToken) if err != nil { t.Fatalf("Failed to refresh access token: %v", err) } @@ -950,12 +963,12 @@ func TestJWTService_Integration(t *testing.T) { t.Errorf("Expected user ID %d, got %d", user.ID, userID) } - err = jwtService.RevokeRefreshToken(refreshToken) + err = jwtService.RevokeRefreshToken(rotatedRefreshToken) if err != nil { t.Fatalf("Failed to revoke refresh token: %v", err) } - _, err = jwtService.RefreshAccessToken(refreshToken) + _, _, err = jwtService.RefreshAccessTokenWithRotation(rotatedRefreshToken) if err == nil { t.Error("Expected error when using revoked refresh token") }