
[Bugfix][ROCm] Use chunked_prefill_paged_decode as fallback for V1 attention on ROCm #18093


Merged
2 commits merged into vllm-project:main on May 16, 2025

Conversation

@kliuae (Contributor) commented May 13, 2025

On ROCm, vLLM’s V1 engine uses the unified attention kernel as its sole attention backend. At the moment, however, this kernel fails for models where the ratio of query heads to key-value heads is not a power of two. This makes models like Llama-4-Scout, whose num_queries_per_kv evaluates to an odd number, fail to run on ROCm with the following error:

offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)
         ^
ValueError: arange's range must be a power of 2

This PR addresses this issue by adding back the chunked_prefill_paged_decode kernel as a fallback for cases where the input tensor shapes are incompatible with the unified attention kernel.
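To make the dispatch concrete, here is a minimal sketch of the fallback selection described above. The helper and the returned kernel names are illustrative rather than vLLM's actual backend-selection API, and the sketch assumes BLOCK_Q is itself a power of two, so the tl.arange range is a power of two exactly when num_queries_per_kv is.

```python
def _is_power_of_two(n: int) -> bool:
    # Power-of-two check via the classic bit trick.
    return n > 0 and (n & (n - 1)) == 0


def select_v1_attention_kernel(num_query_heads: int, num_kv_heads: int) -> str:
    """Illustrative only: pick a kernel based on the query/KV head ratio."""
    num_queries_per_kv = num_query_heads // num_kv_heads
    if _is_power_of_two(num_queries_per_kv):
        # tl.arange(0, BLOCK_Q * num_queries_per_kv) is legal here (BLOCK_Q is a power of two).
        return "unified_attention"
    # Non-power-of-two ratios (e.g. 40 query heads / 8 KV heads -> 5) hit the
    # ValueError above, so fall back to chunked_prefill_paged_decode.
    return "chunked_prefill_paged_decode"


assert select_v1_attention_kernel(40, 8) == "chunked_prefill_paged_decode"
assert select_v1_attention_kernel(32, 8) == "unified_attention"
```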

kliuae added 2 commits May 13, 2025 16:24
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label May 13, 2025
@houseroad houseroad added the rocm label May 14, 2025
@tjtanaa (Contributor) commented May 14, 2025

@hongxiayang

This is another quick solution to issue #18088. Falling back to chunked_prefill_paged_decode could be a quick fix, as it has been used extensively in previous Llama4 use cases.

@hongxiayang (Collaborator) commented:

cc @tdoublep

@hongxiayang (Collaborator) commented:

> @hongxiayang
>
> This is another quick solution to issue #18088. Falling back to chunked_prefill_paged_decode could be a quick fix, as it has been used extensively in previous Llama4 use cases.

Thanks. If chunked_prefill_paged_decode performs better than the unified Triton attention kernel on the ROCm side, we can go with this.

@tjtanaa (Contributor) commented May 15, 2025

@hongxiayang @tdoublep

Benchmark configuration: meta-llama/Llama-4-Scout-17B-16E-Instruct -tp 4 --max-model-len 32768 --max_seq_len_to_capture 32768 --no-enable-prefix-caching --max-num-batched-tokens 32768

[Image: benchmark results comparing the three approaches]

| # | Approach | Reference |
|---|----------|-----------|
| 1 | Pad BLOCK_Q * num_queries_per_kv with an offset mask | https://github.com/EmbeddedLLM/vllm/tree/fix-unified-attention-triton |
| 2 | Pad BLOCK_Q * num_queries_per_kv without an offset mask | PR #18100 |
| 3 | Fall back to the previously used kernel chunked_prefill_paged_decode | PR #18093 (this PR) |

The best solution is to fall back (approach 3).

The correctness of all three approaches has been validated by running lm_eval on GSM8K with both Llama4 and Mixtral models.
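For context on approaches 1 and 2 in the table above, the host-side padding idea can be sketched roughly as below; the function name is hypothetical, and only triton.next_power_of_2 is a real Triton helper. The kernel would then discard the padded query rows, either with an explicit offset mask (approach 1) or via its existing bounds masks (approach 2).

```python
import triton


def queries_per_kv_with_padding(num_query_heads: int, num_kv_heads: int) -> tuple[int, int]:
    """Return (actual, padded) num_queries_per_kv so tl.arange gets a power-of-two range."""
    num_queries_per_kv = num_query_heads // num_kv_heads
    # Pad up so that tl.arange(0, BLOCK_Q * padded) is legal inside the kernel.
    padded = triton.next_power_of_2(num_queries_per_kv)
    return num_queries_per_kv, padded


print(queries_per_kv_with_padding(40, 8))  # (5, 8): 3 of the 8 padded rows must be masked out
```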

@tjtanaa (Contributor) commented May 15, 2025

lm_eval results for this branch:

[2025-05-15 08:15:06] INFO evaluation_tracker.py:272: Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9189 | ± 0.0075 |
| gsm8k | 3 | strict-match | 5 | exact_match | 0.9014 | ± 0.0082 |
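For reference, the numbers above could roughly be reproduced with lm-evaluation-harness; the sketch below uses its Python simple_evaluate entry point with the same model_args string shown in this comment (exact argument names may vary across lm_eval versions).

```python
import lm_eval

# Evaluate GSM8K (5-shot) on the vLLM backend, mirroring the configuration above.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,"
        "tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
print(results["results"]["gsm8k"])
```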

@hongxiayang (Collaborator) commented:

I am fine with having both PRs address the original issue with the relatively new unified Triton attention: (1) one for the completeness of the unified Triton attention kernel, addressing the power-of-two edge cases (#18100), and (2) another to address the performance regression on ROCm (this PR).

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) May 15, 2025 16:05
@github-actions github-actions bot added the ready label May 15, 2025
@DarkLight1337 DarkLight1337 added this to the v0.9.0 milestone May 15, 2025
@ProExpertProg (Contributor) commented:

Can we disable auto-merge until we confirm that performance is better on Llama4 after the unified_attention fix that landed this morning? #18161

@hongxiayang (Collaborator) commented:

@ProExpertProg The updated benchmarking results were posted to the Slack chat (in summary, the fallback option still performs better than the updated unified Triton attention fix); here is the screenshot:
[Image: updated benchmark results]

https://files.slack.com/files-pri/T07QH46AC91-F08SXGHN0TT/image.png

@tdoublep (Member) commented:

I'm OK with merging this. Will try to figure out why the unified kernel is not performant in this case.

@vllm-bot vllm-bot merged commit ee659e3 into vllm-project:main May 16, 2025
86 of 90 checks passed
Labels
ready, rocm, v1

8 participants