
Commit bbe931e

tjtanaa authored and LeiWang1999 committed
[Bugfix] Fix try-catch conditions to import correct Flash Attention Backend in Draft Model (vllm-project#9101)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 7ab7621 commit bbe931e

File tree

1 file changed: +10 -5 lines changed


vllm/spec_decode/draft_model_runner.py

Lines changed: 10 additions & 5 deletions
@@ -6,11 +6,16 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 
 try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+    except (ModuleNotFoundError, ImportError):
+        # vllm_flash_attn is not installed, try the ROCm FA metadata
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+except (ModuleNotFoundError, ImportError) as err:
+    raise RuntimeError(
+        "Draft model speculative decoding currently only supports"
+        "CUDA and ROCm flash attention backend.") from err
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
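
For readability, here is the import fallback as it stands after this patch, reconstructed from the + lines of the diff above (the leading comment is an annotation, not part of the patch). The inner try prefers the CUDA FlashAttentionMetadata and now falls back to the ROCm metadata on ImportError as well as ModuleNotFoundError; if the ROCm import also fails, the outer except raises a RuntimeError up front rather than letting the missing backend surface later.

```python
# Import block in vllm/spec_decode/draft_model_runner.py after vllm-project#9101.
try:
    try:
        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
    except (ModuleNotFoundError, ImportError):
        # vllm_flash_attn is not installed, try the ROCm FA metadata
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except (ModuleNotFoundError, ImportError) as err:
    raise RuntimeError(
        "Draft model speculative decoding currently only supports"
        "CUDA and ROCm flash attention backend.") from err
```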
