[torch.compile] improve allreduce registration #9061

Merged
merged 2 commits on Oct 4, 2024

Changes from 1 commit

38 changes: 15 additions & 23 deletions vllm/distributed/parallel_state.py
@@ -105,7 +105,7 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    group._all_reduce(tensor)
+    group._all_reduce_in_place(tensor)
 
 @inplace_all_reduce.register_fake
 def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ def outplace_all_reduce(tensor: torch.Tensor,
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    return group._all_reduce(tensor)
+    return group._all_reduce_out_place(tensor)
 
 @outplace_all_reduce.register_fake
 def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
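These first two hunks only redirect the registered ops to the renamed helpers. The `register_fake` decorators visible here suggest the ops are defined with `torch.library.custom_op`. As background, here is a minimal self-contained sketch of that registration pattern (assuming PyTorch >= 2.4; the `mylib` namespace and the toy op bodies are illustrative assumptions, not vLLM's implementation):

```python
import torch

# Toy in-place collective: declares that it mutates `tensor` and returns nothing.
@torch.library.custom_op("mylib::inplace_all_reduce", mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
    # Stand-in body; real code would dispatch to NCCL / torch.distributed here.
    tensor.add_(1.0)

@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
    # Fake (meta) kernel: nothing to allocate for an in-place op.
    return None

# Toy out-of-place collective: produces a fresh output tensor.
@torch.library.custom_op("mylib::outplace_all_reduce", mutates_args=())
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # Stand-in body; real code would call a custom all-reduce kernel.
    return tensor + 1.0

@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # Fake kernel only describes the output's shape/dtype/device for tracing.
    return torch.empty_like(tensor)

if __name__ == "__main__":
    x = torch.zeros(4)
    torch.ops.mylib.inplace_all_reduce(x, group_name="tp")
    y = torch.ops.mylib.outplace_all_reduce(x, group_name="tp")
    print(x, y)  # both reflect the toy "+1" reduction
```

Splitting the op into an in-place and an out-of-place variant lets each fake kernel describe its output honestly, instead of one op that sometimes mutates and sometimes allocates.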
@@ -338,40 +338,33 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if not supports_custom_op():
-            return self._all_reduce(input_)
+            self._all_reduce_in_place(input_)
+            return input_
 
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
-            return self._all_reduce(input_)
+            return self.tpu_communicator.all_reduce(input_)
 
-        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+        if self.ca_comm is not None and \
+            not self.ca_comm.disabled and \
+            self.ca_comm.should_custom_ar(input_):
Collaborator review comment on the `should_custom_ar` condition above:
This line is usually False during model capture, since we tend to use large prefill inputs for profiling, which is when the model capture happens.

             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
         else:
             torch.ops.vllm.inplace_all_reduce(input_,
                                               group_name=self.unique_name)
             return input_
 
-    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        """
-        The actual all-reduce implementation.
-
-        NOTE: This operation will be applied in-place or out-of-place.
-        Always assume this function modifies its input, but use the return
-        value as the output.
-        """
+    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
+        assert ca_comm is not None
+        assert not ca_comm.disabled
+        out = ca_comm.custom_all_reduce(input_)
+        assert out is not None
+        return out
 
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_)
-
-        if ca_comm is not None:
-            out = ca_comm.custom_all_reduce(input_)
-            if out is not None:
-                return out
+    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
@@ -380,7 +373,6 @@ def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
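Taken together, the split means torch.compile sees each collective as one opaque registered op: the out-of-place variant's fake kernel just allocates a same-shaped output, while the in-place variant declares its mutation up front. A short self-contained sketch of how such an op traces under torch.compile (illustrative only; the `demo` namespace and the toy reduction are assumptions, not this PR's code):

```python
import torch

@torch.library.custom_op("demo::outplace_all_reduce", mutates_args=())
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # Stand-in for the real collective; a real kernel would return the reduced tensor.
    return tensor * 2.0

@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # The fake kernel only describes output metadata (shape/dtype/device).
    return torch.empty_like(tensor)

def layer(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the dispatch in GroupCoordinator.all_reduce: the eligibility checks run
    # in Python, and only the registered op ends up in the captured graph.
    return torch.ops.demo.outplace_all_reduce(x, group_name="tp")

compiled = torch.compile(layer, fullgraph=True)
print(compiled(torch.ones(4)))  # tensor([2., 2., 2., 2.]) with no graph break
```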