[Attention][V1] Toggle for v1 attention backend #18275

Open · wants to merge 4 commits into main · Changes from 1 commit
4 changes: 2 additions & 2 deletions vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -264,8 +264,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
             target_dtype = current_platform.fp8_dtype()
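The code comment above describes reinterpreting the uint8-backed KV cache as the platform's FP8 dtype before Triton reads it. A minimal standalone sketch of that kind of zero-copy reinterpretation, assuming torch.float8_e4m3fn as the target (the real code resolves the dtype via current_platform.fp8_dtype()):

import torch

# Illustrative sketch only: a KV cache stored as raw uint8 bytes is
# reinterpreted (not converted) as an FP8 dtype so the kernel sees the
# intended element type. The target dtype is assumed here.
target_dtype = torch.float8_e4m3fn

key_cache = torch.zeros(2, 16, dtype=torch.uint8)  # uint8 storage
if key_cache.dtype == torch.uint8:
    key_cache = key_cache.view(target_dtype)  # zero-copy view, same bytes

assert key_cache.dtype == target_dtype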
4 changes: 2 additions & 2 deletions vllm/attention/ops/prefix_prefill.py
@@ -744,8 +744,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
             target_dtype = current_platform.fp8_dtype()
12 changes: 10 additions & 2 deletions vllm/envs.py
@@ -15,6 +15,7 @@
 VLLM_NCCL_SO_PATH: Optional[str] = None
 LD_LIBRARY_PATH: Optional[str] = None
 VLLM_USE_TRITON_FLASH_ATTN: bool = False
+VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
 VLLM_FLASH_ATTN_VERSION: Optional[int] = None
 LOCAL_RANK: int = 0
 CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -289,6 +290,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
Contributor:
Nit: VLLM_V1_TRITON_ATTN_FORCE_PREFILL_DECODE sounds slightly more accurate to me, but feel free to use the name that works best in your eyes

+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
 
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -322,8 +330,8 @@ def get_vllm_port() -> Optional[int]:

     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
+            ).lower() == "true",

Contributor:
Nit: accidental change?

Collaborator (Author):
ruff or yapf change; reformatting the whole file now results in this.

# S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID":
5 changes: 3 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -166,8 +167,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
Contributor (@tjtanaa, May 18, 2025):
Please avoid using envs.ENVIRON in the forward path, which is called at runtime; RFC #17067 has revealed that the overhead is very high. In def __init__() we can pre-evaluate the env var: self.use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
Contributor (@tjtanaa, May 18, 2025):
Should we automatically set this environment variable to True on the ROCm platform until the performance gap has been resolved?

Model: Qwen/Qwen3-235B-A22B-FP8
Input:Output = 1000:1000

With PREFILL_DECODE_ATTENTION:

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  129.30    
Total input tokens:                      494223    
Total generated tokens:                  364067    
Request throughput (req/s):              3.87      
Request goodput (req/s):                 0.03      
Output token throughput (tok/s):         2815.74   
Total Token throughput (tok/s):          6638.12   
---------------Time to First Token----------------
Mean TTFT (ms):                          17056.34  
Median TTFT (ms):                        16655.24  
P99 TTFT (ms):                           31088.47  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          224.22    
Median TPOT (ms):                        83.60     
P99 TPOT (ms):                           2039.01   
---------------Inter-token Latency----------------
Mean ITL (ms):                           76.99     
Median ITL (ms):                         61.88     
P99 ITL (ms):                            270.91    
==================================================

Without PREFILL_DECODE_ATTENTION (with unified attention)

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  153.32    
Total input tokens:                      494223    
Total generated tokens:                  354124    
Request throughput (req/s):              3.26      
Request goodput (req/s):                 0.00      
Output token throughput (tok/s):         2309.75   
Total Token throughput (tok/s):          5533.28   
---------------Time to First Token----------------
Mean TTFT (ms):                          18589.72  
Median TTFT (ms):                        19132.91  
P99 TTFT (ms):                           35497.60  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          238.41    
Median TPOT (ms):                        98.48     
P99 TPOT (ms):                           2331.83   
---------------Inter-token Latency----------------
Mean ITL (ms):                           90.84     
Median ITL (ms):                         69.57     
P99 ITL (ms):                            470.61    
==================================================

Collaborator (Author):
It is not universally worse; it shows different results under different concurrency settings, so going forward we want to keep it as the default, with the toggle left for user discretion.
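If the ROCm suggestion were adopted, a platform-conditional default might look like the sketch below. This is hypothetical and, per the reply above, not what the PR does; current_platform.is_rocm() is an existing vLLM helper, the defaulting function itself is illustrative.

from vllm import envs
from vllm.platforms import current_platform


def default_use_prefill_decode_attn() -> bool:
    # Hypothetical defaulting logic illustrating the suggestion only:
    # an explicit user opt-in always wins, otherwise auto-enable on ROCm
    # until the unified-kernel performance gap is resolved.
    if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
        return True
    return current_platform.is_rocm()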

+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
Contributor:
Could you extract this into a bool (like is_num_q_pow2)?
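A minimal sketch of that nit, using the reviewer's proposed name (the wrapper function here is illustrative, not code from the PR):

from vllm import envs


def _use_prefill_decode_attn(num_queries_per_kv: int) -> bool:
    # is_num_q_pow2 is the reviewer's suggested name: True when the
    # query-to-KV-head ratio is a power of two.
    is_num_q_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
    # Use the separate prefill/decode kernels when forced via the env var
    # or when the ratio is not a power of two.
    return envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or not is_num_q_pow2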


         num_actual_tokens = attn_metadata.num_actual_tokens
