@@ -228,12 +228,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self.shared_experts(hidden_states)
 
         if self.tp_size > 1:
-            padded_num_tokens = (self.tp_size -
-                                 num_tokens % self.tp_size) % self.tp_size
+            num_padding_tokens = (self.tp_size -
+                                  num_tokens % self.tp_size) % self.tp_size
             # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-            if padded_num_tokens > 0:
-                hidden_states = nn.functional.pad(hidden_states,
-                                                  (0, 0, 0, padded_num_tokens))
+            if num_padding_tokens > 0:
+                hidden_states = nn.functional.pad(
+                    hidden_states, (0, 0, 0, num_padding_tokens))
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      self.tp_size,
                                                      dim=0)
@@ -256,8 +256,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if padded_num_tokens > 0:
-                final_hidden_states = final_hidden_states[:-padded_num_tokens]
+            if num_padding_tokens > 0:
+                final_hidden_states = final_hidden_states[:-num_padding_tokens]
         else:
             final_hidden_states = router_hidden_states
 
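For reference, a minimal self-contained sketch of the padding arithmetic this diff renames. The tensor shapes, the `tp_size` value, and the single-process stand-in for the `all_gather` step are illustrative assumptions, not the actual module's code path:

```python
import torch
import torch.nn as nn

# Illustrative values only; the real forward() derives these from the model.
tp_size = 4
hidden_states = torch.randn(10, 16)  # 10 tokens, hidden size 16
num_tokens = hidden_states.shape[0]

# Rows to append so num_tokens becomes divisible by tp_size (here: 2).
num_padding_tokens = (tp_size - num_tokens % tp_size) % tp_size

if num_padding_tokens > 0:
    # (0, 0, 0, n) leaves the hidden dim untouched and appends n zero rows
    # to the token dim.
    hidden_states = nn.functional.pad(hidden_states,
                                      (0, 0, 0, num_padding_tokens))

# With the padding, tensor_split yields tp_size equal-sized chunks, so a
# fixed-size all_gather suffices instead of a variable-length AllGatherV.
chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
assert all(c.shape[0] == hidden_states.shape[0] // tp_size
           for c in chunk_hidden_states)

# Stand-in for the gather-and-concat step; the padded rows are stripped
# afterwards, as in the second hunk above.
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_padding_tokens > 0:
    final_hidden_states = final_hidden_states[:-num_padding_tokens]
assert final_hidden_states.shape[0] == num_tokens
```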