
Commit 06c0922

[FP8][ROCm][Attention] Enable FP8 KV cache on ROCm for V1 (#17870)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent cd3edfc commit 06c0922

File tree

3 files changed: +17, -8 lines

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@
 import torch

 from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton

@@ -267,7 +268,7 @@ def chunked_prefill_paged_decode(
         assert value_cache.dtype == torch.uint8

         if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-            target_dtype = torch.float8_e4m3fn
+            target_dtype = current_platform.fp8_dtype()
         elif kv_cache_dtype == "fp8_e5m2":
             target_dtype = torch.float8_e5m2
         else:
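
The substantive change here swaps a hardcoded torch.float8_e4m3fn for current_platform.fp8_dtype() when reinterpreting the uint8 KV cache. As a hedged illustration of why the platform helper matters: MI300-class (gfx94x) ROCm GPUs typically use the fnuz FP8 variant rather than the OCP e4m3fn format used on NVIDIA GPUs, so the dtype has to be chosen per platform. The sketch below is not vLLM's implementation; fp8_dtype_sketch is a hypothetical stand-in for current_platform.fp8_dtype().

import torch

def fp8_dtype_sketch(is_rocm_mi300: bool) -> torch.dtype:
    # Hypothetical helper: MI300-class GPUs use the e4m3 fnuz variant
    # (no negative zero, single NaN encoding, shifted exponent bias);
    # NVIDIA GPUs use the OCP e4m3fn format.
    return torch.float8_e4m3fnuz if is_rocm_mi300 else torch.float8_e4m3fn

# Reinterpret a uint8 KV-cache buffer with the matching FP8 dtype
# (same element size, so .view() only changes the dtype metadata).
kv_cache_u8 = torch.zeros(16, dtype=torch.uint8)
kv_cache_fp8 = kv_cache_u8.view(fp8_dtype_sketch(is_rocm_mi300=False))
print(kv_cache_fp8.dtype)  # torch.float8_e4m3fn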

vllm/engine/arg_utils.py

Lines changed: 3 additions & 1 deletion

@@ -1205,7 +1205,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if fp8_attention and will_use_fa:
+            if current_platform.is_rocm():
+                supported = True
+            elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
                     flash_attn_supports_fp8)
                 supported = flash_attn_supports_fp8()
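
For context, _is_v1_supported_oracle decides whether a configuration can run on the V1 engine; before this change an FP8 KV cache only passed when FlashAttention with FP8 support would be used, which never held on ROCm. The sketch below mirrors the new branch order with plain booleans standing in for current_platform, envs, and the FlashAttention probe; it is an illustration, not the actual vLLM code.

def fp8_kv_cache_supported_sketch(is_rocm: bool, fp8_attention: bool,
                                  will_use_fa: bool,
                                  fa_supports_fp8: bool) -> bool:
    supported = False
    if is_rocm:
        # ROCm: the Triton / paged-attention path handles the FP8 KV cache.
        supported = True
    elif fp8_attention and will_use_fa:
        # CUDA: defer to whether the FlashAttention build supports FP8.
        supported = fa_supports_fp8
    return supported

# Previously the ROCm branch did not exist, so ROCm fell through to
# supported = False and V1 rejected FP8 KV caches.
assert fp8_kv_cache_supported_sketch(True, True, False, False)
assert not fp8_kv_cache_supported_sketch(False, True, False, False)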

vllm/v1/attention/backends/triton_attn.py

Lines changed: 12 additions & 6 deletions

@@ -9,6 +9,7 @@
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)

@@ -108,6 +109,8 @@ def __init__(
                                       "are not implemented for "
                                       "TritonAttentionImpl")

+        self.fp8_dtype = current_platform.fp8_dtype()
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -161,15 +164,18 @@ def forward(
         )

         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
+            if not current_platform.is_rocm():
+                # Skip Q quantization on ROCm, since dequantizing back to
+                # f32 in the attention kernel is not supported.
+                query, _ = ops.scaled_fp8_quant(
+                    query.reshape(
+                        (num_tokens, num_heads * head_size)).contiguous(),
+                    layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))

         use_local_attn = \
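
Putting the Triton backend changes together: the KV cache is viewed with the platform FP8 dtype cached in __init__, and query quantization is skipped on ROCm. Below is a hedged, self-contained sketch of that path; fake_scaled_fp8_quant is a hypothetical stand-in for vllm._custom_ops.scaled_fp8_quant, and the shapes are illustrative only.

import torch

def fake_scaled_fp8_quant(x: torch.Tensor, scale: float,
                          fp8_dtype: torch.dtype) -> torch.Tensor:
    # Stand-in: scale and cast to FP8 (the real op also returns the scale
    # and saturates to the representable range).
    return (x / scale).to(fp8_dtype)

def prepare_fp8_inputs(query, key_cache, value_cache,
                       fp8_dtype: torch.dtype, q_scale: float,
                       is_rocm: bool):
    # The KV cache is allocated as uint8 and reinterpreted with the platform
    # FP8 dtype (e4m3fn on CUDA, typically e4m3fnuz on MI300-class ROCm).
    key_cache = key_cache.view(fp8_dtype)
    value_cache = value_cache.view(fp8_dtype)

    if not is_rocm:
        # CUDA path: quantize Q to FP8 as well. ROCm keeps Q in its original
        # dtype because the Triton kernel cannot dequantize Q back to f32.
        num_tokens, num_heads, head_size = query.shape
        q2d = query.reshape((num_tokens, num_heads * head_size)).contiguous()
        q_fp8 = fake_scaled_fp8_quant(q2d, q_scale, fp8_dtype)
        query = q_fp8.reshape((num_tokens, num_heads, head_size))
    return query, key_cache, value_cache

# Tiny CPU usage example with the CUDA-style dtype:
q = torch.randn(2, 4, 8)
kc = torch.zeros(64, dtype=torch.uint8)
vc = torch.zeros(64, dtype=torch.uint8)
q_out, kc_fp8, vc_fp8 = prepare_fp8_inputs(
    q, kc, vc, torch.float8_e4m3fn, q_scale=1.0, is_rocm=False)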
