1 file changed, +8 −6 lines changed
@@ -333,13 +333,15 @@ def load_model(
         model, use_gradient_checkpointing=cfg.gradient_checkpointing
     )
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
+    # LlamaRMSNorm layers are in fp32 after kit call, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
                 module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, adapter)
 
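The change gates the dtype cast on cfg.is_llama_derived_model and converts "norm", "lm_head", and "embed_tokens" modules back to the training dtype so that, per the comment in the diff, the fp32 LlamaRMSNorm layers become fp16/bf16 again for flash-attn compatibility. Below is a minimal standalone sketch of the same module-casting pattern with the cfg gating replaced by plain arguments; the function name and the commented example model id are hypothetical and not part of the diff.

import torch
import torch.nn as nn


def cast_norm_and_embedding_modules(model: nn.Module, torch_dtype: torch.dtype) -> nn.Module:
    # Mirrors the diff: norm layers (e.g. LlamaRMSNorm) may be left in fp32
    # after model preparation, so cast them back to the training dtype.
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
        # lm_head / embed_tokens carry a weight tensor; the hasattr check
        # skips any matching module that does not.
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch_dtype)
    return model


# Hypothetical usage with a Hugging Face causal LM (not part of the diff):
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
# )
# model = cast_norm_and_embedding_modules(model, torch.bfloat16)

Matching on substrings of the named_modules() keys is a deliberately loose filter; the hasattr(module, "weight") guard keeps the lm_head/embed_tokens branch from casting modules that merely contain those names but hold no weight tensor.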