
[Bugfix] [ROCm]: Remove assertion logic when using AITER fused moe in unquantizedMethod to reenable LLama4 BF16 #18205


Merged · 1 commit merged into vllm-project:main on May 15, 2025

Conversation

tjtanaa (Contributor) commented May 15, 2025

Remove the assert introduced in PR #15956 to unblock running the Llama4 BF16 model with AITER fused MoE.

@bnellnm May we know the intention behind adding the assertions on the rocm_aiter_fused_moe code path?
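For readers unfamiliar with the flag, here is a minimal, hypothetical sketch of what apply_router_weight_on_input means and why an unconditional assert against it breaks Llama4 (the function and tensors below are illustrative only, not the actual vLLM code or the exact diff):

```python
# Hypothetical sketch, not the actual vLLM implementation: illustrates why
# `assert not apply_router_weight_on_input` on the AITER fused-MoE path
# aborts Llama4, and what honoring the flag conceptually looks like.
import torch


def fused_moe_forward(hidden_states: torch.Tensor,
                      topk_weights: torch.Tensor,
                      apply_router_weight_on_input: bool = False) -> torch.Tensor:
    # Old behavior on this path (simplified):
    #     assert not apply_router_weight_on_input
    # Llama4 sets the flag to True, so profiling fails with AssertionError.
    if apply_router_weight_on_input:
        # Scale the activations by the router weights *before* the expert
        # GEMMs instead of weighting the expert outputs afterwards.
        hidden_states = hidden_states * topk_weights
    # ... expert computation would follow here; omitted in this sketch ...
    return hidden_states


# Tiny smoke test with dummy BF16 tensors.
x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.rand(4, 1, dtype=torch.bfloat16)
out = fused_moe_forward(x, w, apply_router_weight_on_input=True)
print(out.shape)  # torch.Size([4, 8])
```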

Error when running Llama4:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/app/upstreamcheckrocm64/vllmtests/fast_check_acc/test_accuracy.py", line 92, in <module>
[rank0]:     main(args)
[rank0]:   File "/app/upstreamcheckrocm64/vllmtests/fast_check_acc/test_accuracy.py", line 18, in main
[rank0]:     llm = LLM(**dataclasses.asdict(engine_args))
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/utils.py", line 1176, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/entrypoints/llm.py", line 250, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/engine/llm_engine.py", line 511, in from_engine_args
[rank0]:     return engine_cls.from_vllm_config(
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/engine/llm_engine.py", line 487, in from_vllm_config
[rank0]:     return cls(
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/engine/llm_engine.py", line 278, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/engine/llm_engine.py", line 423, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/executor/executor_base.py", line 103, in determine_num_available_blocks
[rank0]:     results = self.collective_rpc("determine_num_available_blocks")
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/executor/executor_base.py", line 331, in collective_rpc
[rank0]:     return self._run_workers(method, *args, **(kwargs or {}))
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/executor/mp_distributed_executor.py", line 185, in _run_workers
[rank0]:     driver_worker_output = run_method(self.driver_worker, sent_method,
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/utils.py", line 2552, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/worker/worker.py", line 254, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/worker/model_runner.py", line 1302, in profile_run
[rank0]:     self._dummy_run(max_num_batched_tokens, max_num_seqs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/worker/model_runner.py", line 1428, in _dummy_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/worker/model_runner.py", line 1846, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/models/mllama4.py", line 768, in forward
[rank0]:     return self.language_model(input_ids, positions, intermediate_tensors,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/models/llama.py", line 573, in forward
[rank0]:     model_output = self.model(input_ids, positions, intermediate_tensors,
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/compilation/decorators.py", line 172, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/models/llama.py", line 384, in forward
[rank0]:     hidden_states, residual = layer(positions, hidden_states, residual)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/models/llama4.py", line 318, in forward
[rank0]:     hidden_states = self.feed_forward(hidden_states)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/models/llama4.py", line 98, in forward
[rank0]:     routed_out = self.experts(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/layers/fused_moe/layer.py", line 1262, in forward
[rank0]:     return self.forward_impl(hidden_states, router_logits)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/layers/fused_moe/layer.py", line 1328, in forward_impl
[rank0]:     final_hidden_states = self.quant_method.apply(
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/layers/fused_moe/layer.py", line 416, in apply
[rank0]:     return self.forward(
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/custom_op.py", line 23, in forward
[rank0]:     return self._forward_method(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/custom_op.py", line 38, in forward_hip
[rank0]:     return self.forward_cuda(*args, **kwargs)
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/layers/fused_moe/layer.py", line 506, in forward_cuda
[rank0]:     assert not apply_router_weight_on_input
[rank0]: AssertionError

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

LucasWilkinson self-requested a review May 15, 2025 13:46
LucasWilkinson (Collaborator)

Resolved by: #18161

tjtanaa (Contributor, Author) commented May 15, 2025

@LucasWilkinson PR #18161 addresses a performance regression; it does not resolve the issue addressed in this PR.

tjtanaa changed the title from "[Bugfix] [ROCm]: Remove wrong assertion logic when using AITER fused moe in unquantizedMethod" to "[Bugfix] [ROCm]: Remove assertion logic when using AITER fused moe in unquantizedMethod to reenable LLama4 BF16" May 15, 2025
tjtanaa (Contributor, Author) commented May 15, 2025

@LucasWilkinson Please take a look at this error log. I am using the latest vLLM commit that contains the PR that you shared.

In Llama4, apply_router_weight_on_input is set to True, so this assertion causes the failure.

[rank0]: Traceback (most recent call last):
[rank0]:   ... (same traceback as above) ...
[rank0]:   File "/app/upstreamcheckrocm64/vllmmain/vllm/model_executor/layers/fused_moe/layer.py", line 506, in forward_cuda
[rank0]:     assert not apply_router_weight_on_input
[rank0]: AssertionError

LucasWilkinson (Collaborator)

Apologies! Misread the PR, that is my bad!

LucasWilkinson (Collaborator) left a comment


LGTM, sorry for the confusion! Got my threads crossed

LucasWilkinson enabled auto-merge (squash) May 15, 2025 15:10
github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) May 15, 2025
tjtanaa (Contributor, Author) commented May 15, 2025

> Apologies! Misread the PR, that is my bad!

I will make sure to share error traces up front to avoid confusion in future PRs. 😄

LucasWilkinson (Collaborator)

Ideally, can you please add the gsm8k (or, even better, MMLU) accuracy results using the AITER backend here, just to confirm the weight is being applied correctly?

DarkLight1337 added this to the v0.9.0 milestone May 15, 2025
tjtanaa (Contributor, Author) commented May 15, 2025

@LucasWilkinson
After applying this PR together with the fix in #18093:

vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9227 | ± 0.0074 |
|       |         | strict-match     | 5      | exact_match | 0.9067 | ± 0.0080 |
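For anyone wanting to reproduce a result in this format, here is a hedged sketch using the lm-evaluation-harness Python API. It assumes lm_eval.simple_evaluate and its vllm backend are available, and that the ROCm AITER path is enabled in the environment (e.g. via VLLM_ROCM_USE_AITER=1); this is not the exact command used for the numbers above.

```python
# Sketch only: runs gsm8k through lm-evaluation-harness with the vLLM backend,
# matching the configuration shown in the results header above.
# Assumes lm-eval-harness (>= 0.4) with vLLM support is installed, and that the
# ROCm AITER path is enabled (e.g. VLLM_ROCM_USE_AITER=1) before launch.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,"
        "tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
# exact_match scores under both the flexible-extract and strict-match filters.
print(results["results"]["gsm8k"])
```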

vllm-bot merged commit 9254052 into vllm-project:main May 15, 2025
77 of 83 checks passed
gshtras pushed a commit to ROCm/vllm that referenced this pull request May 15, 2025
… unquantizedMethod to reenable LLama4 BF16 (vllm-project#18205)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
tjtanaa deleted the fix-fused-moe-assertion branch May 16, 2025 16:29
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
4 participants