From ac2b3818a8b0e2c928841a2c5392340dfbb1406a Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Fri, 16 May 2025 18:13:53 +0000
Subject: [PATCH 1/3] Toggle for v1 attention

Signed-off-by: Gregory Shtrasberg
---
 vllm/attention/ops/chunked_prefill_paged_decode.py |  4 ++--
 vllm/attention/ops/prefix_prefill.py               |  4 ++--
 vllm/envs.py                                       | 12 ++++++++++--
 vllm/v1/attention/backends/triton_attn.py          |  5 +++--
 4 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 217db3bf965..370ffd57c25 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -264,8 +264,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
             target_dtype = current_platform.fp8_dtype()
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 86d256b630b..729b61b0290 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -744,8 +744,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
             target_dtype = current_platform.fp8_dtype()
diff --git a/vllm/envs.py b/vllm/envs.py
index dc23c8ea531..3b683c7423d 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -15,6 +15,7 @@
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -289,6 +290,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -322,8 +330,8 @@ def get_vllm_port() -> Optional[int]:
 
     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
+                           ).lower() == "true",
 
     # S3 access information, used for tensorizer to load model from S3
     "S3_ACCESS_KEY_ID":
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 4000f93984d..3fd5531d4c4 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -166,8 +167,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 

From 9540f728b0f4d7327a92e00025fa470d9d8c7814 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Mon, 19 May 2025 15:25:17 +0000
Subject: [PATCH 2/3] Caching the env variable in the __init__

Signed-off-by: Gregory Shtrasberg
---
 vllm/v1/attention/backends/triton_attn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 3fd5531d4c4..e5fef2239af 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -127,6 +127,7 @@ def __init__(
                                       "TritonAttentionImpl")
 
         self.fp8_dtype = current_platform.fp8_dtype()
+        self.use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
 
     def forward(
         self,
@@ -167,7 +168,7 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+        use_prefill_decode_attn = self.use_prefill_decode_attn or (
            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
 
         num_actual_tokens = attn_metadata.num_actual_tokens

From f82da97337b6bad6611c2a540e9b4c69da439251 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Thu, 22 May 2025 18:29:54 +0000
Subject: [PATCH 3/3] Better naming and logic extracted to a variable

Signed-off-by: Gregory Shtrasberg
---
 vllm/v1/attention/backends/triton_attn.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index e5fef2239af..a97bb85004f 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -127,7 +127,8 @@ def __init__(
                                       "TritonAttentionImpl")
 
         self.fp8_dtype = current_platform.fp8_dtype()
-        self.use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
+        self.force_prefill_decode_attn = \
+            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
 
     def forward(
         self,
@@ -168,9 +169,9 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = self.use_prefill_decode_attn or (
-            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
-
+        num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
+        use_prefill_decode_attn = (self.force_prefill_decode_attn
+                                   or not num_q_is_pow2)
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         if use_prefill_decode_attn:
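Note: the kernel-selection logic introduced by this series reduces to the
standalone sketch below. The helper name and its arguments are illustrative
only and not part of the vLLM API; in the actual code the env toggle is read
once in __init__ (PATCH 2/3) and combined with the head-ratio check in
forward() (PATCH 3/3).

    import os

    def should_use_prefill_decode_attn(num_query_heads: int,
                                       num_kv_heads: int) -> bool:
        # Env toggle added in PATCH 1/3; "1" or "true" forces the separate
        # prefill/decode kernels regardless of the head ratio.
        force = (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
                           "False").lower() in ("true", "1"))
        num_queries_per_kv = num_query_heads // num_kv_heads
        # n & (n - 1) == 0 holds exactly when n (>= 1) is a power of two, so
        # the unified Triton kernel is only used for power-of-two GQA ratios.
        num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
        return force or not num_q_is_pow2

For example, should_use_prefill_decode_attn(32, 8) returns False (a ratio of 4
is a power of two), while running with VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
makes it return True for any ratio.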