Skip to content

Commit 9923f33

Browse files
committed
remove seq_len in llama rotary_emb
1 parent ff939d8 commit 9923f33

File tree

2 files changed, with 3 additions and 14 deletions.

src/axolotl/monkeypatch/llama_attn_hijack_flash.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,8 @@ def flashattn_forward_with_s2attn(
     # [bsz, nh, q_len, hd]
     # pylint: disable=duplicate-code
 
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(
-        value_states, seq_len=kv_seq_len, position_ids=position_ids
+        value_states, position_ids=position_ids
     )
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
@@ -435,12 +432,8 @@ def flashattn_forward(
     # [bsz, q_len, nh, hd]
     # [bsz, nh, q_len, hd]
 
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-
     cos, sin = self.rotary_emb(
-        value_states, seq_len=kv_seq_len, position_ids=position_ids
+        value_states, position_ids=position_ids
     )
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids

src/axolotl/monkeypatch/llama_attn_hijack_xformers.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,7 @@ def xformers_forward(
     # [bsz, q_len, nh, hd]
     # [bsz, nh, q_len, hd]
 
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin = self.rotary_emb(value_states)
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )

0 commit comments

Comments
 (0)