[Core] Concurrently Poll Ray Driver and Worker Results to Avoid Distributed Init Deadlock #7159

Closed
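Before this change, `RayGPUExecutor._run_workers` executed the driver worker's method synchronously and only afterwards called `ray.get()` on the remote workers' futures. During distributed (NCCL) initialization the driver blocks until every rank joins, so if a remote worker fails or hangs, its error is never observed and initialization deadlocks (see https://github.com/vllm-project/vllm/issues/3455). The diff below fixes this by polling the driver call and the Ray worker futures concurrently from a small thread pool. The sketch below illustrates that pattern in isolation; `init_driver`, `init_worker`, and `run_workers_concurrently` are placeholder names, not vLLM APIs, and the `time.sleep` calls merely stand in for blocking initialization work.

```python
import concurrent.futures
import time

import ray


@ray.remote
def init_worker(rank: int) -> str:
    # Stand-in for a remote worker's init_device(): blocking init work that
    # may raise or hang, e.g. when a collective init never completes.
    time.sleep(1)
    return f"worker {rank} ready"


def init_driver() -> str:
    # Stand-in for the driver worker's init_device(), executed in-process.
    time.sleep(1)
    return "driver ready"


def run_workers_concurrently(num_workers: int = 2) -> list:
    worker_refs = [init_worker.remote(rank) for rank in range(1, num_workers + 1)]

    # Calling init_driver() inline could block indefinitely before ray.get()
    # ever runs, hiding a failure on the worker side. Submitting both the
    # driver call and ray.get(worker_refs) to a thread pool and waiting with
    # as_completed() lets whichever side fails raise first.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        driver_future = executor.submit(init_driver)
        workers_future = executor.submit(ray.get, worker_refs)
        for future in concurrent.futures.as_completed([driver_future, workers_future]):
            future.result()  # re-raises any exception from either side

    return [driver_future.result()] + workers_future.result()


if __name__ == "__main__":
    ray.init()
    print(run_workers_concurrently())
```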
61 changes: 44 additions & 17 deletions vllm/executor/ray_gpu_executor.py
@@ -1,4 +1,5 @@
import asyncio
import concurrent.futures
import os
from collections import defaultdict
from itertools import islice, repeat
@@ -239,7 +240,9 @@ def sort_by_driver_then_worker_ip(worker):
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

self._run_workers("init_device")
# must run driver in background thread if len(workers) > 0 to avoid
# NCCL init deadlock (https://github.com/vllm-project/vllm/pull/7159)
self._run_workers("init_device", run_driver_in_background_thread=True)
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
@@ -309,6 +312,7 @@ def _run_workers(
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
run_driver_in_background_thread: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
@@ -358,33 +362,56 @@ def _run_workers(
# Just return futures
return ray_worker_outputs

driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
# Concurrently poll driver worker and remote ray workers
# to avoid deadlock when performing distributed init
# (see: https://github.com/vllm-project/vllm/issues/3455)
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

# Start the driver worker after all the ray workers.
# Start the driver worker task after all the ray worker tasks.
if not use_dummy_driver:
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
if (not run_driver_in_background_thread
# no background thread required when there are
# no concurrent worker tasks
or not ray_worker_outputs):
all_worker_outputs = [
self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
]
else:
# Poll driver and worker tasks concurrently
# in background threads
with concurrent.futures.ThreadPoolExecutor(
max_workers=2) as executor:
driver_poll_thread = executor.submit(
self.driver_worker.execute_method, method,
*driver_args, **driver_kwargs)
worker_poll_thread = executor.submit(
ray.get, ray_worker_outputs)

for completed_future in concurrent.futures.as_completed(
[driver_poll_thread, worker_poll_thread]):
# Will raise exception if underlying thread raises
res = completed_future.result()
if not isinstance(res, list):
driver_output = [res]
else:
worker_outputs = res
all_worker_outputs = driver_output + worker_outputs
else:
assert self.driver_dummy_worker is not None
driver_worker_output = [
ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
]

# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
driver_output = self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs)
all_worker_outputs = ray.get([driver_output] +
ray_worker_outputs)
else:
all_worker_outputs = ray.get(ray_worker_outputs)

return driver_worker_output + ray_worker_outputs
return all_worker_outputs

def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
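One subtlety in the merged-result handling near the end of the diff: `concurrent.futures.as_completed()` yields futures in completion order, so the code cannot rely on which task finishes first. It instead distinguishes the two results by type, since `ray.get()` over the worker futures returns a list while the driver's `execute_method` returns a single object, and then places the driver output first to preserve the pre-existing return convention of `_run_workers`. A minimal sketch of that convention, using placeholder callables rather than vLLM code:

```python
import concurrent.futures


def fake_driver_call() -> str:
    # Stands in for self.driver_worker.execute_method(...): a single result.
    return "driver-result"


def fake_ray_get() -> list:
    # Stands in for ray.get(ray_worker_outputs): always a list of results.
    return ["worker-1-result", "worker-2-result"]


with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(fake_driver_call), executor.submit(fake_ray_get)]
    driver_output: list = []
    worker_outputs: list = []
    for future in concurrent.futures.as_completed(futures):
        result = future.result()  # re-raises if either thread raised
        if isinstance(result, list):
            worker_outputs = result
        else:
            driver_output = [result]

# Driver output comes first, matching the original return order.
all_worker_outputs = driver_output + worker_outputs
assert all_worker_outputs == ["driver-result", "worker-1-result", "worker-2-result"]
```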