diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go index 0ec1e66..6c11041 100644 --- a/internal/middleware/auth_test.go +++ b/internal/middleware/auth_test.go @@ -28,8 +28,8 @@ func TestNewAuthWithoutAuthorization(t *testing.T) { middleware := NewAuth(verifier) handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called = true - if id := GetUserIDFromContext(r.Context()); id != 0 { - t.Fatalf("unexpected user id %d", id) + if id := GetUserIDFromContext(r.Context()); id != nil { + t.Fatalf("unexpected user id %v", id) } })) @@ -54,8 +54,13 @@ func TestNewAuthValidToken(t *testing.T) { handlerCalled := false handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerCalled = true - if id := GetUserIDFromContext(r.Context()); id != 99 { - t.Fatalf("expected user id 99, got %d", id) + id := GetUserIDFromContext(r.Context()) + if id == nil || *id != 99 { + v := uint(0) + if id != nil { + v = *id + } + t.Fatalf("expected user id 99, got %d", v) } })) @@ -131,11 +136,12 @@ func TestNewAuthVerifierError(t *testing.T) { func TestGetUserIDFromContext(t *testing.T) { ctx := context.WithValue(context.Background(), UserIDKey, uint(55)) - if id := GetUserIDFromContext(ctx); id != 55 { - t.Fatalf("expected id 55, got %d", id) + id := GetUserIDFromContext(ctx) + if id == nil || *id != 55 { + t.Fatalf("expected id 55, got %v", id) } - if id := GetUserIDFromContext(context.Background()); id != 0 { - t.Fatalf("expected zero when id missing, got %d", id) + if ptr := GetUserIDFromContext(context.Background()); ptr != nil { + t.Fatalf("expected nil when id missing, got %v", ptr) } }