[Bug]: Dead lock in distributed inference when ray worker raises an exception #3455

Closed as not planned
@youkaichao

Your current environment

Any distributed inference task that uses ray currently suffers from this issue.

🐛 Describe the bug

Basic background of ray

ray provides an easy-to-use asynchronous execution framework:

def f():
    print(1)

import ray
ray.init()
marked_function = ray.remote(f) # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote() # schedule a worker to asynchronously execute the function, immediately return a handle
result = ray.get(handle) # synchronously wait for the worker to finish and return the result

The way it handles exceptions is noteworthy; see the comments in the code below:

def f():
    print(1)
    raise RuntimeError("test")
    # the following line will not be executed
    print(2)

import ray
ray.init()
marked_function = ray.remote(f) # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote() # schedule a worker to asynchronously execute the function, immediately return a handle

# ... do other work in the meantime ...
# the main process will not be notified if the worker fails

# only when we call `ray.get` will we be notified of the error
result = ray.get(handle) # raise the error that was thrown in the worker, wrapping it in a RayTaskError
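The same deferred-error behavior, and one way to notice a failure without blocking, can be sketched with the standard library's `concurrent.futures` as a stand-in for ray. This is an analogy, not ray's API: `pool.submit` plays the role of `marked_function.remote()`, and the helper name `failed` is hypothetical. In ray itself, `ray.wait` with a `timeout` plays roughly the role of `Future.done()` here.

```python
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Optional


def f():
    raise RuntimeError("test")


def failed(handle: Future) -> Optional[BaseException]:
    """Return the task's exception if it has already failed, without blocking.

    Returns None if the task is still running or finished successfully.
    """
    if handle.done():
        return handle.exception()
    return None


pool = ThreadPoolExecutor(max_workers=1)
handle = pool.submit(f)  # returns immediately; the error is stored, not raised

# `wait` blocks until the task has finished but does not re-raise its error,
# so the caller still has to ask the handle what happened.
wait([handle])

error = failed(handle)
print(type(error).__name__, error)  # prints: RuntimeError test
pool.shutdown()
```

Polling the handle like this lets the caller learn about a worker failure before committing to a blocking call that depends on that worker.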

The deadlock in distributed inference

The deadlock happens during initialization of distributed inference, i.e. while creating the process group that the processes use to communicate.

A minimal reproducible example looks like this:

import torch
import torch.distributed as dist

def f(rank, world_size, distributed_init_method):
    # raise RuntimeError # uncomment this line to see a deadlock
    dist.init_process_group(
        backend="gloo",
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank,
    )
    tensor = torch.zeros(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"Rank {rank} has data {tensor.item()}")

import ray
ray.init()
marked_function = ray.remote(f)

distributed_init_method = "tcp://127.0.0.1:29500"
world_size = 2

# start the first process
handle = marked_function.remote(rank=0, world_size=world_size, distributed_init_method=distributed_init_method)

# the main process is the second process
# wait for the first process to join here to initialize the process group for distributed environment
dist.init_process_group(backend="gloo", init_method=distributed_init_method, world_size=world_size, rank=1)

# two processes are ready to communicate
tensor = torch.ones(1)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"Rank 1 has data {tensor.item()}")

result = ray.get(handle)

Normally it works with the following output:

2024-03-17 10:24:23,293 INFO worker.py:1724 -- Started a local Ray instance.
Rank 1 has data 1.0
(f pid=14616) Rank 0 has data 1.0

However, if f raises an exception before calling dist.init_process_group, the worker stays in an error state until the main process calls ray.get, which is the only point where the error surfaces. Meanwhile, the main process is stuck in dist.init_process_group, waiting for the worker to join the process group, so it never reaches ray.get. The result is a deadlock.
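One possible way out is for the driver to keep an eye on the worker handle while it waits at the rendezvous. The sketch below simulates that idea with the standard library only: a thread stands in for the ray worker, a `threading.Event` stands in for the process-group rendezvous, and every name in it (`WorkerFailed`, `wait_for_rendezvous`, `worker_fn`, `run`) is hypothetical. It is not vLLM's or ray's code. The real `dist.init_process_group` is a single blocking call that cannot be polled like this, so in practice it would have to run in a separate thread or be given its `timeout` argument.

```python
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor


class WorkerFailed(RuntimeError):
    """Raised on the driver when a worker died before the rendezvous."""


def wait_for_rendezvous(joined: threading.Event, worker: Future,
                        timeout: float = 5.0, poll: float = 0.05) -> None:
    """Wait until `joined` is set, but give up early if `worker` has failed."""
    deadline = time.monotonic() + timeout
    while not joined.wait(poll):
        # The worker has not joined yet: check whether it already crashed.
        if worker.done() and worker.exception() is not None:
            raise WorkerFailed("worker failed before joining") from worker.exception()
        if time.monotonic() >= deadline:
            raise TimeoutError("rendezvous timed out")


def worker_fn(joined: threading.Event, fail: bool) -> str:
    if fail:
        # Simulates an error raised before the worker reaches the rendezvous.
        raise RuntimeError("simulated failure before init")
    joined.set()  # stands in for joining the process group
    return "worker ok"


def run(fail: bool) -> str:
    joined = threading.Event()
    with ThreadPoolExecutor(max_workers=1) as pool:
        worker = pool.submit(worker_fn, joined, fail)
        wait_for_rendezvous(joined, worker)  # driver side of the rendezvous
        return worker.result()


print(run(fail=False))
try:
    run(fail=True)
except WorkerFailed as exc:
    print("driver saw:", repr(exc.__cause__))
```

With `fail=True` the driver raises `WorkerFailed` shortly after the worker dies instead of waiting indefinitely, and the original exception is preserved as `__cause__`.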

How this relates to vLLM

vLLM uses ray for distributed inference, and the core code is attached below:

def _run_workers(
    self,
    method: str,
    *args,
    driver_args: Optional[List[Any]] = None,
    driver_kwargs: Optional[Dict[str, Any]] = None,
    max_concurrent_workers: Optional[int] = None,
    use_ray_compiled_dag: bool = False,
    **kwargs,
) -> Any:
    """Runs the given method on all workers."""
    if max_concurrent_workers:
        raise NotImplementedError(
            "max_concurrent_workers is not supported yet.")
    if use_ray_compiled_dag:
        # Right now, compiled DAG can only accept a single
        # input. TODO(sang): Fix it.
        output_channels = self.forward_dag.execute(1)
    else:
        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]
    if driver_args is None:
        driver_args = args
    if driver_kwargs is None:
        driver_kwargs = kwargs
    # Start the driver worker after all the ray workers.
    driver_worker_output = getattr(self.driver_worker,
                                   method)(*driver_args, **driver_kwargs)
    # Get the results of the ray workers.
    if self.workers:
        if use_ray_compiled_dag:
            try:
                ray_worker_outputs = [
                    pickle.loads(chan.begin_read())
                    for chan in output_channels
                ]
            finally:
                # Has to call end_read in order to reuse the DAG.
                for chan in output_channels:
                    chan.end_read()
        else:
            ray_worker_outputs = ray.get(ray_worker_outputs)
    return [driver_worker_output] + ray_worker_outputs

When init_model is called through _run_workers, both the ray workers and the main process run the following function:

def init_model(self, cupy_port: Optional[int] = None) -> None:
    if self.device_config.device.type == "cuda":
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)
        _check_if_gpu_supports_dtype(self.model_config.dtype)
        torch.cuda.empty_cache()
        self.init_gpu_memory = torch.cuda.mem_get_info()[0]
    else:
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")
    # Initialize the distributed environment.
    init_distributed_environment(self.parallel_config, self.rank,
                                 cupy_port, self.distributed_init_method)
    # Initialize the model.
    set_random_seed(self.model_config.seed)

This is essentially the minimal reproducible example above: any exception raised before init_distributed_environment can cause a deadlock.

In my case, the GPU driver had a problem and torch.cuda.set_device raised an exception, which caused the deadlock.

Solution to be discussed

Any suggestion to fix this is welcome.

Might be related: #2466.

Labels: bug, stale
