diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 4a6825c01fc..7b9f5e6ce0f 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -1,6 +1,7 @@
 import asyncio
 import os
 from collections import defaultdict
+from concurrent import futures
 from itertools import islice, repeat
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
@@ -239,7 +240,10 @@ def sort_by_driver_then_worker_ip(worker):
         ]
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
-        self._run_workers("init_device")
+        # must run driver in background thread if len(workers) > 0 to avoid
+        # distributed init deadlock.
+        # (https://github.com/vllm-project/vllm/pull/7159)
+        self._run_workers("init_device", run_driver_in_background_thread=True)
         self._run_workers("load_model",
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
@@ -309,6 +313,7 @@ def _run_workers(
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
+        run_driver_in_background_thread: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers. Can be used in the following
@@ -358,7 +363,6 @@ def _run_workers(
             # Just return futures
             return ray_worker_outputs
 
-        driver_worker_output = []
         # In SPMD mode, the driver worker is the same as any other worker,
         # so we only explicitly execute on the driver worker if using a
         # non-SPMD worker class.
@@ -366,25 +370,50 @@ def _run_workers(
             driver_args = args if all_args is None else all_args[0]
             driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
 
-            # Start the driver worker after all the ray workers.
+            # Start the driver worker task after all the ray workers'.
             if not use_dummy_driver:
-                driver_worker_output = [
-                    self.driver_worker.execute_method(method, *driver_args,
-                                                      **driver_kwargs)
-                ]
+                # Driver task will run in this python process
+                if run_driver_in_background_thread and ray_worker_outputs:
+                    # Poll driver and worker tasks concurrently in background
+                    # threads.
+                    #
+                    # This can avoid deadlock if the driver task is
+                    # blocking on some out of band comm (e.g. torch.dist.init)
+                    # that is invalidated by a Ray worker exception.
+                    #
+                    # See: https://github.com/vllm-project/vllm/issues/3455
+
+                    with futures.ThreadPoolExecutor(max_workers=2) as executor:
+                        driver_poll_thread = executor.submit(
+                            self.driver_worker.execute_method, method,
+                            *driver_args, **driver_kwargs)
+                        worker_poll_thread = executor.submit(
+                            ray.get, ray_worker_outputs)
+
+                        for completed_future in futures.as_completed(
+                            [driver_poll_thread, worker_poll_thread]):
+                            # Will raise exception if underlying thread raises
+                            res = completed_future.result()
+                            if not isinstance(res, list):
+                                driver_output = [res]
+                            else:
+                                worker_outputs = res
+                    all_worker_outputs = driver_output + worker_outputs
+                else:
+                    driver_output = self.driver_worker.execute_method(
+                        method, *driver_args, **driver_kwargs)
+                    all_worker_outputs = [driver_output
+                                          ] + ray.get(ray_worker_outputs)
             else:
                 assert self.driver_dummy_worker is not None
-                driver_worker_output = [
-                    ray.get(
-                        self.driver_dummy_worker.execute_method.remote(
-                            method, *driver_args, **driver_kwargs))
-                ]
-
-        # Get the results of the ray workers.
-        if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
+                driver_output = self.driver_dummy_worker.execute_method.remote(
+                    method, *driver_args, **driver_kwargs)
+                all_worker_outputs = ray.get([driver_output] +
+                                             ray_worker_outputs)
+        else:
+            all_worker_outputs = ray.get(ray_worker_outputs)
 
-        return driver_worker_output + ray_worker_outputs
+        return all_worker_outputs
 
     def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
         """Wait for futures returned from _run_workers() with