Skip to content

Commit 8f44a92

Browse files
authored
[BugFix] fix group_topk (#8430)
1 parent 360ddbd commit 8f44a92

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -410,6 +410,7 @@ def fused_topk(
410 410

411 411
if renormalize:
412 412
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
413+
413 414
return topk_weights, topk_ids
414 415

415 416

@@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor,
443 444

444 445
if renormalize:
445 446
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
446-
return topk_weights, topk_ids
447+
448+
return topk_weights, topk_ids.to(torch.int32)
447 449

448 450

449 451
def get_config_dtype_str(dtype: torch.dtype,

0 commit comments

Comments
 (0)