Commit a032c9f

fix sdp attention to use the flash/mem-efficient context manager
1 parent b06d3e3 commit a032c9f

1 file changed: +9 -8 lines

Diff for: src/axolotl/monkeypatch/llama_attn_hijack_xformers.py

@@ -184,14 +184,15 @@ def sdp_attention_forward(

     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)

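For context, torch.nn.functional.scaled_dot_product_attention dispatches to one of several backends (flash attention, memory-efficient attention, or the plain math fallback), and torch.backends.cuda.sdp_kernel is the context manager that controls which backends are eligible. The sketch below is illustrative only, not code from this commit: the tensor shapes, dtype, and backend flags are assumptions chosen to show how the flags restrict dispatch to the fused kernels.

# Illustrative sketch only -- not code from this commit. Shapes and dtype are
# assumptions picked so the flash/mem-efficient kernels are applicable.
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim), fp16 on CUDA
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# sdp_kernel() with no arguments, as used in this commit, leaves all backends
# enabled. Passing the flags below restricts dispatch to the fused kernels and
# raises a RuntimeError if neither of them supports the given inputs.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])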
0 commit comments