
Commit f03a0ae

[Bugfix] Pad hidden_states to avoid cross-ring AllGatherV
Signed-off-by: ApsarasX <apsarax@outlook.com>
Parent: a0c3e9b
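
The new padding amount uses the standard round-up-to-multiple identity, so the token dimension entering the all-gather is always divisible by the tensor-parallel world size. A minimal sketch of the arithmetic (plain Python; the helper name is mine, not from the patch):

    # Smallest pad that makes num_tokens a multiple of tp_size (hypothetical helper).
    def pad_amount(num_tokens: int, tp_size: int) -> int:
        return (tp_size - num_tokens % tp_size) % tp_size

    assert pad_amount(3, 4) == 1   # short batch: the old code also padded this (3 -> 4)
    assert pad_amount(5, 4) == 3   # the old code left this unpadded -> uneven chunks
    assert pad_amount(8, 4) == 0   # already divisible: no padding needed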


vllm_ascend/models/deepseek_v2.py

Lines changed: 9 additions & 10 deletions
@@ -229,14 +229,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         if self.tp_size > 1:
             # pass
-            num_tokens, hidden_size = hidden_states.shape
-            if num_tokens < self.tp_size:
-                target_size = self.tp_size
-                new_hidden_states = torch.empty([target_size, hidden_size],
-                                                dtype=hidden_states.dtype,
-                                                device=hidden_states.device)
-                new_hidden_states[:num_tokens] = hidden_states
-                hidden_states = new_hidden_states
+            num_tokens, _ = hidden_states.shape
+            padded_num_tokens = (self.tp_size -
+                                 num_tokens % self.tp_size) % self.tp_size
+            # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
+            if padded_num_tokens > 0:
+                hidden_states = nn.functional.pad(hidden_states,
+                                                  (0, 0, 0, padded_num_tokens))
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      self.tp_size,
                                                      dim=0)
@@ -259,8 +258,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             dist.all_gather(list(chunk_hidden_states), router_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
+            if padded_num_tokens > 0:
+                final_hidden_states = final_hidden_states[:-padded_num_tokens]
         else:
             final_hidden_states = router_hidden_states
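
Why divisibility matters here: torch.tensor_split hands back unequal chunks when the row count is not a multiple of tp_size, and unequal chunks turn the collective into the variable-length AllGatherV the commit message is avoiding. Below is a minimal single-process sketch of the pad / split / concat / unpad round trip, torch only, with the dist.all_gather step replaced by a plain concatenation and the shapes chosen for illustration:

    import torch
    import torch.nn as nn

    tp_size = 4
    hidden_states = torch.randn(5, 16)           # 5 tokens, hidden size 16

    num_tokens, _ = hidden_states.shape
    padded_num_tokens = (tp_size - num_tokens % tp_size) % tp_size  # -> 3

    # Pad rows with zeros so every rank gets an equal chunk (uneven chunks
    # would force an AllGatherV instead of a plain AllGather).
    if padded_num_tokens > 0:
        hidden_states = nn.functional.pad(hidden_states,
                                          (0, 0, 0, padded_num_tokens))

    chunks = torch.tensor_split(hidden_states, tp_size, dim=0)
    assert all(c.shape[0] == 2 for c in chunks)  # equal 2-row chunks from 8 rows

    # In the real model each rank computes on its own chunk and the results
    # are reassembled with dist.all_gather; here we just concatenate directly.
    final_hidden_states = torch.cat(chunks, dim=0)
    if padded_num_tokens > 0:
        final_hidden_states = final_hidden_states[:-padded_num_tokens]
    assert final_hidden_states.shape[0] == num_tokens  # padding stripped

Under the old condition (num_tokens < self.tp_size), the 5-token case above stayed unpadded, and tensor_split returned chunks of 2, 1, 1, 1 rows, exactly the uneven layout the fix removes.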
