1 parent 78b9efb commit 2eda9e0
src/axolotl/utils/models.py
@@ -333,7 +333,7 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-        # LlamaRMSNorm layers are in fp32 after kit call, so we need to
+        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
         # convert them back to fp16/bf16 for flash-attn compatibility.
         if cfg.flash_attention and cfg.is_llama_derived_model:
             for name, module in model.named_modules():
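The hunk cuts off at the start of the cast-back loop, so here is a minimal self-contained sketch of the idea it implements: after peft's prepare_model_for_kbit_training upcasts the RMSNorm layers to fp32, cast them back to the half-precision compute dtype so flash-attn kernels accept their outputs. The tiny random config and the name-based "norm" filter are illustrative assumptions, not axolotl's exact code:

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny randomly initialized Llama so the sketch runs offline; in axolotl the
# model instead comes out of load_model() after prepare_model_for_kbit_training().
model = LlamaForCausalLM(
    LlamaConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
    )
)

torch_dtype = torch.bfloat16  # the fp16/bf16 compute dtype chosen in the config

# Simulate the post-kbit_training state: norm layers upcast to fp32.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch.float32)

# Cast-back pass: flash-attn rejects fp32 inputs, so move every RMSNorm
# module back to the half-precision compute dtype.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)

# Sanity check: all norm weights now match the compute dtype.
assert all(
    p.dtype == torch_dtype
    for n, m in model.named_modules() if "norm" in n
    for p in m.parameters()
)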