@@ -92,7 +92,9 @@ def load_model(
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
@@ -331,6 +333,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -407,14 +419,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()
 
     return model, peft_config
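The net effect of the second and third hunks is to move the post-quantization dtype fix out of load_llama_adapter and into load_model, so it also runs when no LoRA adapter is attached, and to cast to the configured torch_dtype instead of hard-coding torch.float16 (so bf16 training is handled too). The following is a minimal standalone sketch of that pattern; the helper name cast_mixed_precision_params and the usage lines are illustrative, not part of this patch:

```python
import torch
import torch.nn as nn


def cast_mixed_precision_params(model: nn.Module, torch_dtype: torch.dtype) -> None:
    """Re-cast layers that k-bit preparation upcasts to fp32.

    prepare_model_for_kbit_training leaves norm weights in float32 for
    numerical stability, but the flash-attn kernels expect fp16/bf16
    activations, so the patch casts the norm modules (plus the embedding
    and output head) back to the training dtype.
    """
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            # Only modules that actually hold weights need the cast.
            if hasattr(module, "weight"):
                module.to(torch_dtype)


# Hypothetical usage after k-bit preparation:
# torch_dtype = torch.bfloat16 if cfg.bf16 else torch.float16
# cast_mixed_precision_params(model, torch_dtype)
```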