From 6045d22b0efef900d7f42f1aeeafc143684a5b78 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 3 Oct 2024 22:32:52 -0700
Subject: [PATCH 1/2] improve custom allreduce registration

---
 vllm/distributed/parallel_state.py | 38 ++++++++++++------------------
 1 file changed, 15 insertions(+), 23 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d3ac4eb78b1..6e1970bfed9 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -105,7 +105,7 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce(tensor)
+        group._all_reduce_in_place(tensor)
 
     @inplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ def outplace_all_reduce(tensor: torch.Tensor,
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce(tensor)
+        return group._all_reduce_out_place(tensor)
 
     @outplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
@@ -338,14 +338,17 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if not supports_custom_op():
-            return self._all_reduce(input_)
+            self._all_reduce_in_place(input_)
+            return input_
 
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
-            return self._all_reduce(input_)
+            return self.tpu_communicator.all_reduce(input_)
 
-        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+        if self.ca_comm is not None and \
+            not self.ca_comm.disabled and \
+            self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
         else:
@@ -353,25 +356,15 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
                 group_name=self.unique_name)
             return input_
 
-    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        """
-        The actual all-reduce implementation.
-
-        NOTE: This operation will be applied in-place or out-of-place.
-        Always assume this function modifies its input, but use the return
-        value as the output.
-        """
+    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
+        assert ca_comm is not None
+        assert not ca_comm.disabled
+        out = ca_comm.custom_all_reduce(input_)
+        assert out is not None
+        return out
 
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_)
-
-        if ca_comm is not None:
-            out = ca_comm.custom_all_reduce(input_)
-            if out is not None:
-                return out
+    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
@@ -380,7 +373,6 @@ def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size

From c3919d9ae4256d8d96bb8109876858bc4ab9ed93 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 4 Oct 2024 14:36:48 -0700
Subject: [PATCH 2/2] clean up code

---
 .../device_communicators/custom_all_reduce.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index c95192a5a1b..7de5b05a0b0 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -265,24 +265,21 @@ def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
 
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         # when custom allreduce is disabled, this will be None
-        if self.disabled:
+        if self.disabled or not self.should_custom_ar(input):
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if self.should_custom_ar(input):
-                    return self.all_reduce_reg(input)
+                return self.all_reduce_reg(input)
             else:
-                if self.should_custom_ar(input):
-                    # if warm up, mimic the allocation pattern
-                    # since custom allreduce is out-of-place
-                    return torch.empty_like(input)
+                # if warm up, mimic the allocation pattern
+                # since custom allreduce is out-of-place
+                return torch.empty_like(input)
         else:
             # note: outside of cuda graph context,
             # custom allreduce incurs a cost of cudaMemcpy, which should
             # be small(<=1% of overall latency) compared to the performance
             # gains of using custom kernels
-            if self.should_custom_ar(input):
-                return self.all_reduce_unreg(input)
+            return self.all_reduce_unreg(input)
         return None
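For reference, the custom-op registration pattern that the first patch builds on looks roughly like the sketch below. This is a minimal illustration, not vLLM's code: the "mylib" namespace, the torch.distributed-backed bodies, and the unused group_name parameter are assumptions made for the sketch; vLLM registers vllm::inplace_all_reduce and vllm::outplace_all_reduce against its GroupCoordinator and dispatches to them through torch.ops.vllm.* as shown in the diff above.

    # Minimal sketch of the in-place / out-of-place custom-op split (assumes
    # PyTorch >= 2.4 and an already-initialized process group; "mylib" and the
    # bodies are illustrative, not vLLM's implementation).
    import torch
    from torch.library import custom_op

    @custom_op("mylib::inplace_all_reduce", mutates_args=("tensor",))
    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
        # In-place variant: the input buffer is mutated and nothing is
        # returned, so the compiler must treat `tensor` as written to.
        torch.distributed.all_reduce(tensor)

    @inplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> None:
        # Fake (meta) implementation: no computation, no new tensor.
        return None

    @custom_op("mylib::outplace_all_reduce", mutates_args=())
    def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        # Out-of-place variant: the result is a fresh buffer, mirroring a
        # custom all-reduce kernel that cannot write back into `tensor`.
        out = tensor.clone()
        torch.distributed.all_reduce(out)
        return out

    @outplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        # The fake impl only needs to produce a correctly-shaped output.
        return torch.empty_like(tensor)

    # Callers then go through the dispatcher, e.g.
    # torch.ops.mylib.outplace_all_reduce(t, group_name="world")

Splitting the op this way gives the compiler accurate mutation information: the in-place op declares that it writes its input, while the out-of-place op is pure and returns a new buffer, which is also why the warm-up path in the second patch mimics the allocation with torch.empty_like.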