[Attention][V1] Toggle for v1 attention backend #18275

Open · wants to merge 3 commits into base: main

Conversation

@gshtras (Collaborator) commented on May 16, 2025

Expanding on #18093.
Adds a toggle to force a fallback to the two-stage (separate prefill/decode) attention kernel in V1.

Also includes a small fix for the FP8 KV cache on ROCm in the two-stage kernel path.

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
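For reference, assuming the toggle lands as the VLLM_V1_USE_PREFILL_DECODE_ATTENTION environment variable shown in the diffs below, opting into the two-stage path from Python would look roughly like this (illustrative sketch, not part of the PR):

```python
import os

# Hypothetical usage sketch (not part of this PR): force the separate
# prefill/decode Triton kernels instead of the unified kernel. The variable
# must be set before vLLM evaluates its environment configuration.
os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "1"

from vllm import LLM  # import after setting the variable

llm = LLM(model="Qwen/Qwen3-235B-A22B-FP8")  # model taken from the benchmark below
```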

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the v1 label on May 16, 2025
@@ -166,8 +167,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.

         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
@tjtanaa (Contributor) commented on May 18, 2025

Please avoid reading envs.* in the forward path, which is called at runtime. RFC #17067 revealed that the overhead is very high.

In def __init__() we can pre-evaluate the environment variable: self.use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
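A minimal sketch of the suggested pattern, using an illustrative stand-in class rather than the actual TritonAttentionImpl: read the variable once at construction so the hot path only touches a cached attribute.

```python
import os

class AttnImplSketch:
    """Illustrative stand-in; not the real TritonAttentionImpl."""

    def __init__(self) -> None:
        # Evaluate the environment variable once, at construction time.
        self.use_prefill_decode_attn = (
            os.environ.get("VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
                           "False").lower() in ("true", "1"))

    def wants_prefill_decode(self, num_queries_per_kv: int) -> bool:
        # Hot path: only an attribute lookup, no os.environ access.
        return self.use_prefill_decode_attn or (
            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
```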

@tjtanaa (Contributor) commented on May 18, 2025

Should we try to automatically set this environment variable to True on the ROCm platform until the performance gap has been resolved? (A sketch of one possible default follows the benchmark results below.)

Model: Qwen/Qwen3-235B-A22B-FP8
Input:Output = 1000:1000

With PREFILL_DECODE_ATTENTION (separate prefill/decode kernels)

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  129.30    
Total input tokens:                      494223    
Total generated tokens:                  364067    
Request throughput (req/s):              3.87      
Request goodput (req/s):                 0.03      
Output token throughput (tok/s):         2815.74   
Total Token throughput (tok/s):          6638.12   
---------------Time to First Token----------------
Mean TTFT (ms):                          17056.34  
Median TTFT (ms):                        16655.24  
P99 TTFT (ms):                           31088.47  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          224.22    
Median TPOT (ms):                        83.60     
P99 TPOT (ms):                           2039.01   
---------------Inter-token Latency----------------
Mean ITL (ms):                           76.99     
Median ITL (ms):                         61.88     
P99 ITL (ms):                            270.91    
==================================================

Without PREFILL_DECODE_ATTENTION (with unified attention)

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  153.32    
Total input tokens:                      494223    
Total generated tokens:                  354124    
Request throughput (req/s):              3.26      
Request goodput (req/s):                 0.00      
Output token throughput (tok/s):         2309.75   
Total Token throughput (tok/s):          5533.28   
---------------Time to First Token----------------
Mean TTFT (ms):                          18589.72  
Median TTFT (ms):                        19132.91  
P99 TTFT (ms):                           35497.60  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          238.41    
Median TPOT (ms):                        98.48     
P99 TPOT (ms):                           2331.83   
---------------Inter-token Latency----------------
Mean ITL (ms):                           90.84     
Median ITL (ms):                         69.57     
P99 ITL (ms):                            470.61    
==================================================
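A self-contained sketch of what an ROCm-aware default could look like; the helper name and default logic here are hypothetical and not part of the PR:

```python
import os
import torch

def default_use_prefill_decode_attn() -> bool:
    # Hypothetical helper: honor an explicit override, otherwise default to
    # the two-stage kernels on ROCm builds until the unified kernel catches up.
    override = os.environ.get("VLLM_V1_USE_PREFILL_DECODE_ATTENTION")
    if override is not None:
        return override.lower() in ("true", "1")
    return torch.version.hip is not None  # True only on ROCm builds of PyTorch
```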

@gshtras (Collaborator, Author) commented

The unified kernel is not universally worse; it shows different results at different concurrency settings. Going forward we want to keep it as the default, with the toggle left to user discretion.

@ProExpertProg (Contributor) left a comment

A few minor comments; looks good otherwise.

-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)

Could you extract this into a bool (like is_num_q_pow2)?
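A sketch of the suggested refactor; the function and variable names are illustrative rather than a proposed final form:

```python
def use_prefill_decode_attn(force_flag: bool, num_queries_per_kv: int) -> bool:
    # Give the power-of-two check its own named bool, per the review comment.
    num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
    return force_flag or not num_q_is_pow2
```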

@@ -289,6 +290,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),

+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":

Nit: VLLM_V1_TRITON_ATTN_FORCE_PREFILL_DECODE sounds slightly more accurate to me, but feel free to use whichever name works best for you.

@@ -322,8 +330,8 @@ def get_vllm_port() -> Optional[int]:

     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"

Nit: accidental change?

@@ -126,6 +127,7 @@ def __init__(
             "TritonAttentionImpl")

         self.fp8_dtype = current_platform.fp8_dtype()
+        self.use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

Suggested change:
-        self.use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
+        self.force_prefill_decode_attn = envs.VLLM_V1_TRITON_ATTN_FORCE_PREFILL_DECODE


3 participants