@@ -105,7 +105,7 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce(tensor)
+        group._all_reduce_in_place(tensor)
 
     @inplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ def outplace_all_reduce(tensor: torch.Tensor,
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce(tensor)
+        return group._all_reduce_out_place(tensor)
 
     @outplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
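For readers unfamiliar with the `register_fake` hooks above: a fake (meta) implementation only describes the output's shape and dtype, which is what lets torch.compile trace a graph containing the op without executing the real collective. A minimal, self-contained sketch of the pattern (PyTorch >= 2.4), using an illustrative `demo::` op rather than the actual vLLM registration code, which sits outside these hunks:

```python
import torch

# Hypothetical op for illustration only; it mirrors the custom_op +
# register_fake pattern used for the vllm all-reduce ops, but is not vLLM code.
@torch.library.custom_op("demo::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # real (eager) implementation


@scaled_copy.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake/meta kernel: describe the output, perform no computation.
    return torch.empty_like(x)


@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.demo.scaled_copy(x) + 1


print(fn(torch.ones(2)))  # tensor([3., 3.])
```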
@@ -338,40 +338,33 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if not supports_custom_op():
-            return self._all_reduce(input_)
+            self._all_reduce_in_place(input_)
+            return input_
 
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
-            return self._all_reduce(input_)
+            return self.tpu_communicator.all_reduce(input_)
 
-        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+        if self.ca_comm is not None and \
+            not self.ca_comm.disabled and \
+            self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
         else:
             torch.ops.vllm.inplace_all_reduce(input_,
                                               group_name=self.unique_name)
             return input_
 
-    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        """
-        The actual all-reduce implementation.
-
-        NOTE: This operation will be applied in-place or out-of-place.
-        Always assume this function modifies its input, but use the return
-        value as the output.
-        """
+    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
+        assert ca_comm is not None
+        assert not ca_comm.disabled
+        out = ca_comm.custom_all_reduce(input_)
+        assert out is not None
+        return out
 
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_)
-
-        if ca_comm is not None:
-            out = ca_comm.custom_all_reduce(input_)
-            if out is not None:
-                return out
+    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
@@ -380,7 +373,6 @@ def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
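The split of the old `_all_reduce` (which, per its removed docstring, could act either in place or out of place) into `_all_reduce_in_place` and `_all_reduce_out_place` presumably exists because a torch custom op needs a single, fixed mutation contract: either it mutates its argument and returns nothing, or it leaves the input untouched and returns a fresh tensor. A self-contained sketch of those two contracts under torch.compile; the `demo::` ops below are illustrative, not vLLM's, and no distributed backend is needed to run it:

```python
import torch

# In-place flavor: declares the mutation via mutates_args and returns None,
# analogous to the contract of the in-place all-reduce op.
@torch.library.custom_op("demo::add_one_", mutates_args=["x"])
def add_one_(x: torch.Tensor) -> None:
    x.add_(1.0)


@add_one_.register_fake
def _(x: torch.Tensor) -> None:
    return None


# Out-of-place flavor: pure op that returns a new tensor, analogous to the
# contract of the out-of-place all-reduce op.
@torch.library.custom_op("demo::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1.0


@add_one.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    torch.ops.demo.add_one_(x)        # traced as a mutation of x
    return torch.ops.demo.add_one(x)  # traced as a pure op


print(fn(torch.zeros(3)))  # tensor([2., 2., 2.])
```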