Commit 78b9efb

scope flash-attn+qlora fix correctly, scope to llama, add comment
1 parent: 312a9fa

File tree

1 file changed (+8, -6 lines)

src/axolotl/utils/models.py (+8, -6)
@@ -333,13 +333,15 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )

-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
+        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+        # convert them back to fp16/bf16 for flash-attn compatibility.
+        if cfg.flash_attention and cfg.is_llama_derived_model:
+            for name, module in model.named_modules():
+                if "norm" in name:
                     module.to(torch_dtype)
+                if "lm_head" in name or "embed_tokens" in name:
+                    if hasattr(module, "weight"):
+                        module.to(torch_dtype)

     model, lora_config = load_adapter(model, cfg, adapter)

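For context, below is a minimal standalone sketch of the pattern this diff scopes, reproduced outside of axolotl's load_model. The checkpoint name, dtype, and 4-bit settings are illustrative assumptions, not values taken from the commit.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

torch_dtype = torch.bfloat16  # assumed compute dtype; fp16 also satisfies flash-attn

# Hypothetical Llama-derived checkpoint loaded in 4-bit, as in a QLoRA run.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch_dtype,
    ),
)

# prepare_model_for_kbit_training upcasts the non-quantized parameters
# (RMSNorm weights, embed_tokens, lm_head) to fp32 for numerical stability.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Cast those modules back down so flash-attn kernels, which require
# fp16/bf16 inputs, do not error out on the fp32 norm outputs.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(torch_dtype)

In axolotl itself this cast now runs only inside the kbit preparation branch and only when both cfg.flash_attention and cfg.is_llama_derived_model are set, as the diff above shows.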