[Attention][V1] Toggle for v1 attention backend #18275
base: main
Changes from 3 commits
File: vllm/envs.py

@@ -15,6 +15,7 @@
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -289,6 +290,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),

+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
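For context, each entry in this dict maps a variable name to a lazily evaluated lambda, and only a case-insensitive "true" or "1" enables the toggle. A minimal standalone sketch of the same parsing convention (the _bool_env helper is illustrative, not part of vLLM):

import os

def _bool_env(name: str, default: str = "False") -> bool:
    # Same truthy-string convention as the lambda above: only "true"/"1"
    # (case-insensitive) turn the flag on; anything else leaves it off.
    return os.getenv(name, default).lower() in ("true", "1")

# Opt in before starting vLLM, e.g.
# `VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 vllm serve ...` from a shell.
os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "1"
print(_bool_env("VLLM_V1_USE_PREFILL_DECODE_ATTENTION"))  # True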
@@ -322,8 +330,8 @@ def get_vllm_port() -> Optional[int]:

     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
+                           ).lower() == "true",

     # S3 access information, used for tensorizer to load model from S3
     "S3_ACCESS_KEY_ID":

Review comment: Nit: accidental change?

Reply: This is a ruff or yapf change; reformatting the whole file now results in this.
File: vllm/v1/attention/backends/triton_attn.py
@@ -5,6 +5,7 @@
 import torch

 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -126,6 +127,7 @@ def __init__(
             "TritonAttentionImpl")

         self.fp8_dtype = current_platform.fp8_dtype()
+        self.use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

     def forward(
         self,
@@ -166,8 +168,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.

         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = self.use_prefill_decode_attn or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)

         num_actual_tokens = attn_metadata.num_actual_tokens

Review comment: Could you extract this into a bool (like …)?
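The heuristic being extended here is the classic power-of-two bit trick: for positive n, n & (n - 1) clears the lowest set bit and is therefore zero exactly when n is a power of two. The new toggle is simply ORed in front of that check. A hedged sketch of the combined condition, folded into the kind of named bool the reviewer asks for (should_use_prefill_decode and num_q_per_kv_is_pow2 are illustrative names, not from the PR):

def should_use_prefill_decode(force: bool, num_queries_per_kv: int) -> bool:
    # n & (n - 1) == 0 iff n > 0 is a power of two; the unified triton
    # kernel path is only taken for power-of-two query-to-KV-head ratios.
    num_q_per_kv_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
    return force or not num_q_per_kv_is_pow2

for n in (1, 2, 3, 4, 6, 8):
    print(n, should_use_prefill_decode(force=False, num_queries_per_kv=n))
# 1, 2, 4, 8 print False (unified kernel); 3 and 6 print True. Passing
# force=True (the env toggle) makes every ratio take the split path.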
Review comment: Nit: VLLM_V1_TRITON_ATTN_FORCE_PREFILL_DECODE sounds slightly more accurate to me, but feel free to use the name that works best in your eyes.