Skip to content

[Bugfix] Fixes for new marlin moe usage #18017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def get_moe_method(
"input_activations")

if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# group_size=None means channelwise
group_size = weight_quant.group_size or -1
# Prefer to use the MarlinMoE kernel when it is supported.
if not check_moe_marlin_supports_layer(layer,
weight_quant.group_size):
if not check_moe_marlin_supports_layer(layer, group_size):
if (weight_quant.strategy in QuantizationStrategy.GROUP and
weight_quant.actorder in (ActivationOrdering.GROUP,
ActivationOrdering.DYNAMIC)):
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ def apply(
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input is not None:
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")

topk_weights, topk_ids = FusedMoE.select_experts(
Expand Down