Skip to content

Commit 176b888

Browse files
authored
ensure enable_input_require_grads is called on model before getting the peft model (axolotl-ai-cloud#345)
1 parent 3392270 commit 176b888

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/axolotl/utils/models.py

+2
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ def load_adapter(model, cfg, adapter):
391391

392392
if adapter is None:
393393
return model, None
394+
if hasattr(model, "enable_input_require_grads"):
395+
model.enable_input_require_grads()
394396
if adapter in ["lora", "qlora"]:
395397
return load_lora(model, cfg)
396398
if adapter == "llama-adapter":

0 commit comments

Comments
 (0)