Commit 312a9fa

move flash-attn monkey patch alongside the others
1 parent 248bf90 commit 312a9fa

File tree

2 files changed: +3 −1 lines changed
File renamed without changes.

src/axolotl/utils/models.py (+3 −1)
@@ -92,7 +92,9 @@ def load_model(
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
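
For context, helpers like replace_llama_attn_with_flash_attn typically follow the standard transformers monkey-patching pattern: rebind the attention class's forward method on the class itself before the model is instantiated. The sketch below illustrates that pattern only; it is not the actual axolotl implementation, and flash_forward is a hypothetical stand-in.

    # Minimal sketch of the monkey-patch pattern, not axolotl's real code.
    import transformers.models.llama.modeling_llama as llama_module


    def flash_forward(self, hidden_states, *args, **kwargs):
        # Hypothetical replacement: a real patch would compute attention
        # with flash-attn kernels instead of the stock implementation.
        ...


    def replace_llama_attn_with_flash_attn():
        # Rebinding forward on the class means every LlamaAttention
        # instance, including ones created later, uses the patched method.
        llama_module.LlamaAttention.forward = flash_forward

Because the patch mutates the class globally, it must run before load_model builds the model, which is why the import and call sit inside the guarded block above rather than at module scope.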
