
Commit 1d5f2aa

Isotr0py authored and Yuqi Zhang committed
[Misc] Use apply_rotary_emb from vllm_flash_attn for Qwen2-VL vision RoPE (vllm-project#17726)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 9a768f8 commit 1d5f2aa


2 files changed: 6 additions, 12 deletions


vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 7 deletions
@@ -297,13 +297,8 @@ def forward(
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
-            q = apply_rotary_pos_emb_vision(q,
-                                            rotary_pos_emb,
-                                            use_flash_attn=use_flash_attn)
-            k = apply_rotary_pos_emb_vision(k,
-                                            rotary_pos_emb,
-                                            use_flash_attn=use_flash_attn)
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
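With the use_flash_attn flag gone, the two call sites above collapse into plain function calls and the platform check lives in one place. A minimal usage sketch of the new call path; the tensor shapes are illustrative assumptions, not taken from this diff:

import torch

from vllm.model_executor.models.qwen2_vl import apply_rotary_pos_emb_vision

# Assumed shapes: after the "s b ... -> b s ..." rearrange shown above,
# q is (batch, seqlen, heads, head_dim).
batch, seqlen, heads, head_dim = 1, 16, 4, 64
q = torch.randn(batch, seqlen, heads, head_dim)
# Per-position rotary frequencies covering half the head dimension.
rotary_pos_emb = torch.randn(seqlen, head_dim // 2)

# One call regardless of attention backend: on CUDA the vllm_flash_attn
# kernel is picked internally, elsewhere the pure-torch fallback runs.
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)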

vllm/model_executor/models/qwen2_vl.py

Lines changed: 4 additions & 5 deletions
@@ -64,7 +64,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import (
@@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor,


 def apply_rotary_pos_emb_vision(t: torch.Tensor,
-                                freqs: torch.Tensor,
-                                use_flash_attn=False) -> torch.Tensor:
+                                freqs: torch.Tensor) -> torch.Tensor:
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
     apply_rotary_emb = apply_rotary_emb_torch
-    if use_flash_attn:
-        from flash_attn.layers.rotary import apply_rotary_emb
+    if current_platform.is_cuda():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
     output = apply_rotary_emb(t_, cos, sin).type_as(t)
     return output
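On CUDA the function above merely swaps in the vllm_flash_attn kernel; on other platforms apply_rotary_emb_torch does the work. For context, a minimal self-contained sketch of what that pure-torch fallback computes, assuming the non-interleaved (half-rotation) rotary layout and freqs of shape (seqlen, rotary_dim / 2); this is an approximation, not the exact vLLM implementation:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the rotary dims in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_torch_sketch(x: torch.Tensor, cos: torch.Tensor,
                                  sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, heads, head_dim); cos/sin: (seqlen, rotary_dim / 2).
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Duplicate each frequency to cover the full rotary_dim, then insert a
    # singleton axis so cos/sin broadcast over the head dimension.
    cos = torch.cat([cos, cos], dim=-1).unsqueeze(-2)
    sin = torch.cat([sin, sin], dim=-1).unsqueeze(-2)
    # Rotate only the first ro_dim channels; pass the rest through untouched.
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    return torch.cat([x_rot * cos + rotate_half(x_rot) * sin, x_pass], dim=-1)

Since the kernel and the torch path implement the same rotation, they should agree up to floating-point error, which is what makes the platform check a safe replacement for threading a per-call use_flash_attn flag through every caller.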
