@@ -284,11 +284,8 @@ def flashattn_forward_with_s2attn(
284
284
# [bsz, nh, q_len, hd]
285
285
# pylint: disable=duplicate-code
286
286
287
- kv_seq_len = key_states .shape [- 2 ]
288
- if past_key_value is not None :
289
- kv_seq_len += past_key_value [0 ].shape [- 2 ]
290
287
cos , sin = self .rotary_emb (
291
- value_states , seq_len = kv_seq_len , position_ids = position_ids
288
+ value_states , position_ids = position_ids
292
289
)
293
290
query_states , key_states = apply_rotary_pos_emb (
294
291
query_states , key_states , cos , sin , position_ids
@@ -435,12 +432,8 @@ def flashattn_forward(
435
432
# [bsz, q_len, nh, hd]
436
433
# [bsz, nh, q_len, hd]
437
434
438
- kv_seq_len = key_states .shape [- 2 ]
439
- if past_key_value is not None :
440
- kv_seq_len += past_key_value [0 ].shape [- 2 ]
441
-
442
435
cos , sin = self .rotary_emb (
443
- value_states , seq_len = kv_seq_len , position_ids = position_ids
436
+ value_states , position_ids = position_ids
444
437
)
445
438
query_states , key_states = apply_rotary_pos_emb (
446
439
query_states , key_states , cos , sin , position_ids
0 commit comments