File tree Expand file tree Collapse file tree 2 files changed +6
-12
lines changed
vllm/model_executor/models Expand file tree Collapse file tree 2 files changed +6
-12
lines changed Original file line number Diff line number Diff line change @@ -297,13 +297,8 @@ def forward(
297
297
q , k , v = (rearrange (x , "s b ... -> b s ..." ).contiguous ()
298
298
for x in (q , k , v ))
299
299
if rotary_pos_emb is not None :
300
- use_flash_attn = self .attn_backend == _Backend .FLASH_ATTN
301
- q = apply_rotary_pos_emb_vision (q ,
302
- rotary_pos_emb ,
303
- use_flash_attn = use_flash_attn )
304
- k = apply_rotary_pos_emb_vision (k ,
305
- rotary_pos_emb ,
306
- use_flash_attn = use_flash_attn )
300
+ q = apply_rotary_pos_emb_vision (q , rotary_pos_emb )
301
+ k = apply_rotary_pos_emb_vision (k , rotary_pos_emb )
307
302
308
303
if self .attn_backend == _Backend .FLASH_ATTN :
309
304
# from vllm_flash_attn.flash_attn_interface import (
Original file line number Diff line number Diff line change 64
64
BaseProcessingInfo , PromptReplacement ,
65
65
PromptUpdate )
66
66
from vllm .multimodal .profiling import BaseDummyInputsBuilder
67
- from vllm .platforms import _Backend
67
+ from vllm .platforms import _Backend , current_platform
68
68
from vllm .sequence import IntermediateTensors
69
69
from vllm .transformers_utils .config import uses_mrope
70
70
from vllm .transformers_utils .processor import (
@@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor,
230
230
231
231
232
232
def apply_rotary_pos_emb_vision (t : torch .Tensor ,
233
- freqs : torch .Tensor ,
234
- use_flash_attn = False ) -> torch .Tensor :
233
+ freqs : torch .Tensor ) -> torch .Tensor :
235
234
t_ = t .float ()
236
235
cos = freqs .cos ()
237
236
sin = freqs .sin ()
238
237
apply_rotary_emb = apply_rotary_emb_torch
239
- if use_flash_attn :
240
- from flash_attn .layers .rotary import apply_rotary_emb
238
+ if current_platform . is_cuda () :
239
+ from vllm . vllm_flash_attn .layers .rotary import apply_rotary_emb
241
240
output = apply_rotary_emb (t_ , cos , sin ).type_as (t )
242
241
return output
243
242
You can’t perform that action at this time.
0 commit comments