[Attention][V1] Toggle for v1 attention backend #18275
base: main
Changes from 1 commit
File 1 (environment variable definitions):
@@ -15,6 +15,7 @@
 VLLM_NCCL_SO_PATH: Optional[str] = None
 LD_LIBRARY_PATH: Optional[str] = None
 VLLM_USE_TRITON_FLASH_ATTN: bool = False
+VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
 VLLM_FLASH_ATTN_VERSION: Optional[int] = None
 LOCAL_RANK: int = 0
 CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -289,6 +290,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),

+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
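For clarity, a small standalone sketch of the parsing rule used by the lambda above: only a case-insensitive "true" or "1" enables the separate prefill/decode kernels, and anything else leaves the default in place. The helper name `_prefill_decode_toggle_enabled` is hypothetical and not part of the PR.

```python
import os


def _prefill_decode_toggle_enabled() -> bool:
    # Mirrors the lambda registered for VLLM_V1_USE_PREFILL_DECODE_ATTENTION.
    return os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
                     "False").lower() in ("true", "1")


os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "1"
assert _prefill_decode_toggle_enabled()

os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "no"
assert not _prefill_decode_toggle_enabled()
```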
@@ -322,8 +330,8 @@ def get_vllm_port() -> Optional[int]:

     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
+                           ).lower() == "true",

     # S3 access information, used for tensorizer to load model from S3
     "S3_ACCESS_KEY_ID":

Review comment: Nit: accidental change?

Reply: ruff or yapf change; reformatting the whole file now results in this.
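One point that matters for the review discussion on the next file: assuming the usual vLLM envs pattern, these entries are lambdas in a dict that a module-level `__getattr__` evaluates on every attribute access, so reading `envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION` repeatedly re-parses the environment each time. A minimal sketch of that pattern (names simplified, not the actual module):

```python
import os
from typing import Callable

# Simplified stand-in for the environment_variables dict in the envs module.
_ENV_VARS: dict[str, Callable[[], object]] = {
    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
    lambda: os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
                      "False").lower() in ("true", "1"),
}


def __getattr__(name: str):
    # Module-level __getattr__ (PEP 562): every `envs.NAME` access lands here
    # and re-runs the parsing lambda, which is why reading it inside a hot
    # forward() path adds per-call overhead.
    if name in _ENV_VARS:
        return _ENV_VARS[name]()
    raise AttributeError(name)
```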
File 2 (V1 triton attention backend):
@@ -5,6 +5,7 @@
 import torch

 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -166,8 +167,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.

         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)

         num_actual_tokens = attn_metadata.num_actual_tokens

Review comment: Please avoid using envs.ENVIRON in a forward path that is going to be called at runtime. RFC #17067 has revealed that the overhead is very high.

Review comment: Should we try to automatically set this environment variable to True on the ROCm platform, until the performance gap has been resolved? (Benchmark results for Qwen/Qwen3-235B-A22B-FP8 with PREFILL_DECODE_ATTENTION versus the unified attention path are not reproduced here.)

Reply: It is not universally worse; it shows different results at different concurrency settings, so going forward we want to keep it as the default, with the toggle left for user discretion.

Review comment: Could you extract this into a bool?
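A minimal sketch of what the two suggestions above could look like in combination: read the environment variable once at construction time and fold it into a single bool, so forward() only branches on a cached attribute. The class and attribute names are illustrative rather than the actual vLLM implementation, and the sketch assumes the query/KV head counts are already known at init time.

```python
import os


class TritonAttentionImplSketch:
    """Illustrative stand-in for the V1 triton attention impl."""

    def __init__(self, num_query_heads: int, num_kv_heads: int) -> None:
        num_queries_per_kv = num_query_heads // num_kv_heads
        force_split = os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
                                "False").lower() in ("true", "1")
        # `n & (n - 1) != 0` is the usual "not a power of two" check: the
        # unified kernel is kept only when the head ratio is a power of two,
        # unless the toggle forces the split prefill/decode path.
        self.use_prefill_decode_attn = force_split or (
            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)

    def forward(self) -> str:
        # No environment lookup in the hot path; just read the cached bool.
        return ("prefill+decode kernels"
                if self.use_prefill_decode_attn else "unified kernel")


if __name__ == "__main__":
    # 32 // 8 = 4 (power of two) -> unified kernel; 12 // 4 = 3 -> split path.
    print(TritonAttentionImplSketch(num_query_heads=32, num_kv_heads=8).forward())
    print(TritonAttentionImplSketch(num_query_heads=12, num_kv_heads=4).forward())
```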
Review comment: Nit: VLLM_V1_TRITON_ATTN_FORCE_PREFILL_DECODE sounds slightly more accurate to me, but feel free to use the name that works best in your eyes.