Commit 77085ea

qlora w flash attention fixes (axolotl-ai-cloud#333)
1 parent: db2a358

File tree

1 file changed: +8 -0 lines

src/axolotl/utils/models.py (+8)
@@ -407,6 +407,14 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
+    if cfg.flash_attention:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch.float16)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch.float16)
+
     model.print_trainable_parameters()
 
     return model, peft_config
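
For context: flash attention kernels operate on half-precision tensors, while QLoRA-style k-bit preparation commonly upcasts norm layers (and can leave the embedding/output layers) in float32, producing a dtype mismatch at the attention boundary. Below is a minimal standalone sketch of the same casting pass; the helper name cast_norms_and_embeddings_to_fp16 is hypothetical (ours, not axolotl's), but the module-name matching mirrors the commit above.

import torch
import torch.nn as nn


def cast_norms_and_embeddings_to_fp16(model: nn.Module) -> nn.Module:
    """Cast norm, lm_head, and embed_tokens modules to float16.

    Flash attention expects half-precision inputs; casting these
    modules avoids a float32/float16 mismatch in the forward pass.
    """
    for name, module in model.named_modules():
        # LLaMA norm layers (e.g. model.layers.0.input_layernorm) match here.
        if "norm" in name:
            module.to(torch.float16)
        # The embedding and output head carry a weight tensor that feeds
        # the fp16 attention path directly.
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch.float16)
    return model

Note the placement in the commit: the pass runs right after get_peft_model(...) and only when cfg.flash_attention is set, so the cast applies to the fully wrapped PEFT model rather than the base model.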
