@@ -229,14 +229,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         if self.tp_size > 1:
             # pass
-            num_tokens, hidden_size = hidden_states.shape
-            if num_tokens < self.tp_size:
-                target_size = self.tp_size
-                new_hidden_states = torch.empty([target_size, hidden_size],
-                                                dtype=hidden_states.dtype,
-                                                device=hidden_states.device)
-                new_hidden_states[:num_tokens] = hidden_states
-                hidden_states = new_hidden_states
+            num_tokens, _ = hidden_states.shape
+            padded_num_tokens = (self.tp_size -
+                                 num_tokens % self.tp_size) % self.tp_size
+            # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
+            if padded_num_tokens > 0:
+                hidden_states = nn.functional.pad(hidden_states,
+                                                  (0, 0, 0, padded_num_tokens))
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      self.tp_size,
                                                      dim=0)
@@ -259,8 +258,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
+            if padded_num_tokens > 0:
+                final_hidden_states = final_hidden_states[:-padded_num_tokens]
         else:
             final_hidden_states = router_hidden_states
 
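For reference, here is a minimal standalone sketch of the pad → all_gather → unpad pattern this change adopts. The function name, the explicit `tp_size`/`tp_group` arguments, and the placeholder "routing" step are illustrative assumptions, not the module's actual API; the real class derives these from its tensor-parallel configuration.

```python
import torch
import torch.nn as nn
import torch.distributed as dist


def gather_with_padding(hidden_states: torch.Tensor, tp_size: int,
                        tp_group) -> torch.Tensor:
    """Sketch (assumed helper): pad tokens so they split evenly across TP ranks."""
    num_tokens, _ = hidden_states.shape
    # Number of extra rows needed so num_tokens becomes a multiple of tp_size.
    padded_num_tokens = (tp_size - num_tokens % tp_size) % tp_size
    if padded_num_tokens > 0:
        # (0, 0, 0, n) pads the hidden dim by (0, 0) and the token dim by
        # (0, n), i.e. appends n zero rows at the end.
        hidden_states = nn.functional.pad(hidden_states,
                                          (0, 0, 0, padded_num_tokens))

    # With the padding, every chunk has the same shape, which is what
    # dist.all_gather requires (a plain AllGather instead of AllGatherV).
    chunk_hidden_states = list(
        torch.tensor_split(hidden_states, tp_size, dim=0))
    local_chunk = chunk_hidden_states[dist.get_rank(tp_group)]
    # ... in the real module each rank processes its own chunk here; the
    # gather below reassembles the full (padded) batch ...
    dist.all_gather(chunk_hidden_states, local_chunk, group=tp_group)

    final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
    if padded_num_tokens > 0:
        # Drop the zero rows that were added only for alignment.
        final_hidden_states = final_hidden_states[:-padded_num_tokens]
    return final_hidden_states
```

Worked example: with `num_tokens = 5` and `tp_size = 4`, `padded_num_tokens = (4 - 5 % 4) % 4 = 3`, so the batch is padded to 8 rows, split into four chunks of 2, and the last 3 rows are sliced off after the gather. When `num_tokens` is already divisible by `tp_size`, the trailing `% tp_size` keeps the padding at 0, so no extra work is done.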