[Bug]: 'FutureWrapper' object has no attribute 'sampled_token_ids' when using ray to perform pipeline parallelism #19063

havever opened this issue Jun 3, 2025 · 2 comments

havever commented Jun 3, 2025

Your current environment

The output of python collect_env.py
==============================
Versions of relevant libraries
==============================
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pynvml==11.5.0
[pip3] pyzmq==26.4.0
[pip3] torch==2.6.0
[pip3] torchaudio==2.6.0
[pip3] torchvision==0.21.0
[pip3] transformers==4.52.3
[pip3] triton==3.2.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pynvml                    11.5.0                   pypi_0    pypi
[conda] pyzmq                     26.4.0                   pypi_0    pypi
[conda] torch                     2.6.0                    pypi_0    pypi
[conda] torchaudio                2.6.0                    pypi_0    pypi
[conda] torchvision               0.21.0                   pypi_0    pypi
[conda] transformers              4.52.3                   pypi_0    pypi
[conda] triton                    3.2.0                    pypi_0    pypi

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
Neuron SDK Version           : N/A
vLLM Version                 : 0.8.5.post1
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      SYS     0-13,28-41      0               N/A
GPU1    SYS      X      14-27,42-55     1               N/A

🐛 Describe the bug

I ran the first test function in vllm/tests/v1/engine/test_engine_core.py, modifying the EngineArgs to add pipeline_parallel_size=2 and distributed_executor_backend='ray'. The test then fails with: AttributeError: 'FutureWrapper' object has no attribute 'sampled_token_ids'.

I also tried modifying the update_from_output method of the Scheduler class in vllm/v1/core/sched/scheduler.py to resolve the future by calling result() on model_runner_output, but that did not fix it: rerunning the test produced a different error, 'ModelRunnerOutput' object has no attribute 'finished_req_ids'. It looks as though, somewhere on the Ray executor path, a SchedulerOutput is expected but a ModelRunnerOutput is passed instead. A rough sketch of the change I tried is below.
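For reference, this is roughly the change I made (a sketch, not the exact diff; the hasattr guard is only illustrative and is not existing vLLM code):

# Sketch of my edit to Scheduler.update_from_output in
# vllm/v1/core/sched/scheduler.py. FutureWrapper is the async handle
# returned by the Ray pipeline-parallel executor
# (vllm/v1/executor/ray_distributed_executor.py).
def update_from_output(self, scheduler_output, model_runner_output):
    if hasattr(model_runner_output, "result"):
        # Resolve the Ray compiled-DAG future before touching its fields.
        model_runner_output = model_runner_output.result()
    sampled_token_ids = model_runner_output.sampled_token_ids
    ...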

Below are the test function I used (from the test folder, with my modifications) and the two resulting tracebacks.

def test_engine_core(monkeypatch: pytest.MonkeyPatch):

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        """Setup the EngineCore."""
        engine_args = EngineArgs(model=MODEL_NAME, pipeline_parallel_size=2,
                                 distributed_executor_backend='ray')
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)

        engine_core = EngineCore(vllm_config=vllm_config,
                                 executor_class=executor_class,
                                 log_stats=True)
        """Test basic request lifecycle."""

        # First request.
        engine_core.add_request(make_request())
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 1

        # Second request.
        engine_core.add_request(make_request())
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 1

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        # Add two requests in a row.
        engine_core.add_request(make_request())
        engine_core.add_request(make_request())
        assert len(engine_core.scheduler.waiting) == 2
        assert len(engine_core.scheduler.running) == 2

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 4

        # Loop through until they are all done.
        while len(engine_core.step().outputs) > 0:
            pass

        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0
        """Test abort cycle."""

        # Basic abort.
        req = make_request()
        request_id = req.request_id

        engine_core.add_request(req)
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0
        assert engine_core.scheduler.has_unfinished_requests()
        assert not engine_core.scheduler.has_finished_requests()

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 1
        assert engine_core.scheduler.has_unfinished_requests()
        assert not engine_core.scheduler.has_finished_requests()

        engine_core.abort_requests([request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0
        assert not engine_core.scheduler.has_unfinished_requests()
        assert engine_core.scheduler.has_finished_requests()

        _ = engine_core.step()
        assert not engine_core.scheduler.has_unfinished_requests()
        assert not engine_core.scheduler.has_finished_requests()

        # Add, step, abort 1 of the 3.
        req0 = make_request()
        req1 = make_request()
        req2 = make_request()

        engine_core.add_request(req0)
        engine_core.add_request(req1)
        assert len(engine_core.scheduler.waiting) == 2
        assert len(engine_core.scheduler.running) == 0

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        engine_core.add_request(req2)
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 2

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 3

        # Abort just one.
        engine_core.abort_requests([req1.request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        _ = engine_core.step()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 2

        # Abort the other requests at the same time.
        engine_core.abort_requests([req2.request_id, req0.request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0

        # Sending duplicate requests with same request_id
        req0 = make_request()
        req1 = make_request()
        req0.request_id = req1.request_id = "test"
        engine_core.add_request(req0)

        while len(engine_core.step().outputs) > 0:
            pass

        engine_core.add_request(req1)
        while len(engine_core.step().outputs) > 0:
            pass

        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0

Traceback (most recent call last):
  File "/data/user/test/vllm/tests/utils.py", line 727, in wrapper
    f(*args, **kwargs)
  File "/data/user/test/vllm/tests/v1/engine/test_engine_core.py", line 67, in test_engine_core
    _ = engine_core.step()
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 208, in step
    engine_core_outputs = self.scheduler.update_from_output(
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/v1/core/sched/scheduler.py", line 630, in update_from_output
    sampled_token_ids = model_runner_output.sampled_token_ids
AttributeError: 'FutureWrapper' object has no attribute 'sampled_token_ids'

Traceback (most recent call last):
  File "/data/user/test/vllm/tests/utils.py", line 727, in wrapper
    f(*args, **kwargs)
  File "/data/user/test/vllm/tests/v1/engine/test_engine_core.py", line 91, in test_engine_core
    while len(engine_core.step().outputs) > 0:
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 208, in step
    engine_core_outputs = self.scheduler.update_from_output(
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/v1/core/sched/scheduler.py", line 630, in update_from_output
    model_runner_output = model_runner_output.result()
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/v1/executor/ray_distributed_executor.py", line 24, in result
    return self.ref.get()
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 150, in get
    return _process_return_vals(return_vals, True)
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 27, in _process_return_vals
    raise val.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::RayWorkerWrapper.ray_call() (pid=2675487, ip=
)
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/executor/ray_utils.py", line 139, in execute_model_ray
    output = self.worker.model_runner.execute_model(
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1012, in execute_model
    self._update_states(scheduler_output)
  File "/data/conda_envs/sllm-store/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 296, in _update_states
    for req_id in scheduler_output.finished_req_ids:
AttributeError: 'ModelRunnerOutput' object has no attribute 'finished_req_ids'

havever added the bug label on Jun 3, 2025
chaunceyjiang (Contributor) commented

What version of vLLM are you using? Have you tried the latest version?

havever commented Jun 3, 2025

> What version of vLLM are you using? Have you tried the latest version?

Thanks for the reminder. I just tried the latest version, 0.9.1, but the issue still persists. As far as I can tell, 0.9.1 doesn't fix this pipeline-parallelism bug yet.

I tried both 0.9.1 and 0.8.5.post1.
