We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 360ddbd, commit 8f44a92. Copy full SHA for 8f44a92.
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -410,6 +410,7 @@ def fused_topk(
410
411
if renormalize:
412
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
413
+
414
return topk_weights, topk_ids
415
416
@@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor,
443
444
445
446
- return topk_weights, topk_ids
447
448
+ return topk_weights, topk_ids.to(torch.int32)
449
450
451
def get_config_dtype_str(dtype: torch.dtype,
0 commit comments