
Commit 0172a58

youkaichao authored and sumitd2 committed
[torch.compile] improve allreduce registration (vllm-project#9061)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 0cce3f0 commit 0172a58

2 files changed: +21 −32 lines changed


vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 6 additions & 9 deletions
@@ -265,24 +265,21 @@ def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
 
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         # when custom allreduce is disabled, this will be None
-        if self.disabled:
+        if self.disabled or not self.should_custom_ar(input):
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if self.should_custom_ar(input):
-                    return self.all_reduce_reg(input)
+                return self.all_reduce_reg(input)
             else:
-                if self.should_custom_ar(input):
-                    # if warm up, mimic the allocation pattern
-                    # since custom allreduce is out-of-place
-                    return torch.empty_like(input)
+                # if warm up, mimic the allocation pattern
+                # since custom allreduce is out-of-place
+                return torch.empty_like(input)
         else:
             # note: outside of cuda graph context,
             # custom allreduce incurs a cost of cudaMemcpy, which should
             # be small(<=1% of overall latency) compared to the performance
             # gains of using custom kernels
-            if self.should_custom_ar(input):
-                return self.all_reduce_unreg(input)
+            return self.all_reduce_unreg(input)
 
         return None
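The net effect of this hunk is that the should_custom_ar eligibility check is hoisted into a single early-return guard, so custom_all_reduce now has a simple contract: None when it cannot handle the tensor, otherwise a freshly allocated output (the kernel is out-of-place). The asserts added to parallel_state.py below rely on that contract. A minimal, self-contained sketch of it; FakeCustomAllreduce and max_bytes are hypothetical stand-ins, not vLLM source:

```python
from typing import Optional

import torch


class FakeCustomAllreduce:
    """Hypothetical stand-in for CustomAllreduce; not vLLM source."""

    def __init__(self, disabled: bool = False,
                 max_bytes: int = 8 * 1024 * 1024) -> None:
        self.disabled = disabled
        self.max_bytes = max_bytes  # made-up size cutoff for illustration

    def should_custom_ar(self, t: torch.Tensor) -> bool:
        # Stand-in for the real eligibility check (size/dtype/contiguity).
        size = t.numel() * t.element_size()
        return t.is_contiguous() and size <= self.max_bytes

    def custom_all_reduce(self, inp: torch.Tensor) -> Optional[torch.Tensor]:
        # One guard replaces the per-branch should_custom_ar checks.
        if self.disabled or not self.should_custom_ar(inp):
            return None
        # Out-of-place stand-in for the custom kernel: returns a new tensor.
        return inp.clone()


comm = FakeCustomAllreduce()
out = comm.custom_all_reduce(torch.ones(16))
assert out is not None  # eligible inputs now always yield an output
```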

vllm/distributed/parallel_state.py

Lines changed: 15 additions & 23 deletions
@@ -105,7 +105,7 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce(tensor)
+        group._all_reduce_in_place(tensor)
 
     @inplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ def outplace_all_reduce(tensor: torch.Tensor,
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce(tensor)
+        return group._all_reduce_out_place(tensor)
 
     @outplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
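The two hunks above only repoint the registered ops at the new split helpers. For context, the sketch below shows the kind of registration pattern these callers and register_fake decorators belong to: an out-of-place custom op whose real implementation returns a fresh tensor and whose fake (meta) implementation only describes the output, so torch.compile can trace through it. This is a hedged illustration, not vLLM source; it assumes PyTorch 2.4+, and the demo::scale op is made up as a stand-in for vllm::outplace_all_reduce:

```python
import torch
from torch.library import custom_op


@custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation: out-of-place, returns a new tensor
    # (stand-in for group._all_reduce_out_place).
    return x * factor


@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake implementation: describes only the output's shape/dtype,
    # which is all the compiler needs to trace through the op.
    return torch.empty_like(x)


if __name__ == "__main__":
    out = torch.ops.demo.scale(torch.ones(4), 2.0)
    print(out)  # tensor([2., 2., 2., 2.])
```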
@@ -338,40 +338,33 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if not supports_custom_op():
-            return self._all_reduce(input_)
+            self._all_reduce_in_place(input_)
+            return input_
 
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
-            return self._all_reduce(input_)
+            return self.tpu_communicator.all_reduce(input_)
 
-        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+        if self.ca_comm is not None and \
+            not self.ca_comm.disabled and \
+            self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
         else:
             torch.ops.vllm.inplace_all_reduce(input_,
                                               group_name=self.unique_name)
             return input_
 
-    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        """
-        The actual all-reduce implementation.
-
-        NOTE: This operation will be applied in-place or out-of-place.
-        Always assume this function modifies its input, but use the return
-        value as the output.
-        """
+    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
+        assert ca_comm is not None
+        assert not ca_comm.disabled
+        out = ca_comm.custom_all_reduce(input_)
+        assert out is not None
+        return out
 
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_)
-
-        if ca_comm is not None:
-            out = ca_comm.custom_all_reduce(input_)
-            if out is not None:
-                return out
+    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
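This hunk splits the old _all_reduce into _all_reduce_out_place (custom allreduce only, asserting a non-None result) and _all_reduce_in_place (the mutating pynccl / ipex / torch.distributed fallbacks); the final hunk below then drops the trailing return input_ so the in-place helper genuinely returns None. A minimal sketch of why that split matters for torch.compile, with a hypothetical demo::bump_in_place op (PyTorch 2.4+ assumed, not vLLM source): the in-place op declares what it mutates and returns None, and the Python wrapper hands the mutated input back, mirroring the torch.ops.vllm.inplace_all_reduce branch of all_reduce:

```python
import torch
from torch.library import custom_op


@custom_op("demo::bump_in_place", mutates_args=("x",))
def bump_in_place(x: torch.Tensor) -> None:
    # Mutating stand-in for group._all_reduce_in_place.
    x.add_(1)


@bump_in_place.register_fake
def _(x: torch.Tensor) -> None:
    # Nothing to describe: the op produces no output tensor.
    return None


def reduce_like_wrapper(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the in-place branch of all_reduce: call the mutating op,
    # then return the (now updated) input tensor itself.
    torch.ops.demo.bump_in_place(x)
    return x


if __name__ == "__main__":
    compiled = torch.compile(reduce_like_wrapper, fullgraph=True)
    print(compiled(torch.zeros(4)))  # tensor([1., 1., 1., 1.])
```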
@@ -380,7 +373,6 @@ def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
