Commit 0d2e34f

Merge pull request axolotl-ai-cloud#336 from tmm1/flash-attn
Fix flash-attn + qlora not working with llama models
2 parents b56a6c0 + 2eda9e0 commit 0d2e34f

File tree

2 files changed (+13, -9 lines)

File renamed without changes.

src/axolotl/utils/models.py

+13 -9
@@ -92,7 +92,9 @@ def load_model(
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
@@ -331,6 +333,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+        # convert them back to fp16/bf16 for flash-attn compatibility.
+        if cfg.flash_attention and cfg.is_llama_derived_model:
+            for name, module in model.named_modules():
+                if "norm" in name:
+                    module.to(torch_dtype)
+                if "lm_head" in name or "embed_tokens" in name:
+                    if hasattr(module, "weight"):
+                        module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -407,14 +419,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()
 
     return model, peft_config
