
Commit fe9fee4

Allow DPO reference model to be loaded from LoadCheckpoint callback (#80)
This is not the cleanest solution of all time, but it unblocks this niche use case without significant rearchitecting of the code. The issue is that we sometimes use a callback to load the checkpoint in Composer (https://github.com/mosaicml/composer/blob/main/composer/callbacks/load_checkpoint.py). This is useful when the base model is saved in a Composer checkpoint and you want to save only LoRA checkpoints during training for autoresume. The callback loads a checkpoint on the `BEFORE_LOAD` event, so that any autoresume checkpoint will overwrite it. None of that really applies to the reference model loading here; we just want to grab the base checkpoint from the callback and load it as an additional step.

Testing:
- Before (fails with NaN loss because weights are not properly loaded): `daniel-matt-failure-1-tRAWIE`
- After, with the LoadCheckpoint callback (init device meta, pretrained false): `daniel-matt-callback-1-VevoUT`
- After, without the LoadCheckpoint callback (init device mixed, pretrained true): `daniel-matt-no-callback-1-G6P1po`

Screenshot: https://github.com/user-attachments/assets/898089ed-71cd-4874-8ae4-7d36c19addc2
1 parent 8ecb82d commit fe9fee4
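For context, here is a minimal sketch of the setup this change targets: the base weights come from a Composer checkpoint via the `LoadCheckpoint` callback, while training only saves LoRA weights for autoresume. The model builder, checkpoint paths, and the `LoadCheckpoint` constructor arguments shown are illustrative assumptions, not the exact compose-rl configuration.

```python
# Sketch only: paths, the builder, and LoadCheckpoint arguments are assumptions.
from composer import Trainer
from composer.callbacks import LoadCheckpoint


def build_policy_model():
    """Hypothetical stand-in for the llm-foundry model builder
    (init_device='meta', pretrained=False, with a PEFT/LoRA config)."""
    raise NotImplementedError


trainer = Trainer(
    model=build_policy_model(),
    callbacks=[
        # Loads the base weights on BEFORE_LOAD, so an autoresume checkpoint
        # (which contains only the LoRA weights saved during training)
        # can still overwrite whatever this callback loaded.
        LoadCheckpoint(load_path='s3://my-bucket/base-model/latest-rank0.pt'),
    ],
    run_name='dpo-lora-run',
    save_folder='s3://my-bucket/lora-checkpoints',
    autoresume=True,
)
```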

File tree

1 file changed: +41, -1 lines changed

compose_rl/dpo/callback.py

Lines changed: 41 additions & 1 deletion
@@ -8,8 +8,11 @@
 
 import torch
 from composer import Trainer
+from composer.callbacks import LoadCheckpoint
 from composer.core import State, get_precision_context
 from composer.loggers import Logger
+from composer.models.huggingface import HuggingFaceModel
+from composer.utils.checkpoint import load_checkpoint
 from llmfoundry.interfaces import CallbackWithConfig
 from llmfoundry.utils import build_composer_model
 # pyright does not recognize process_init_device though it is a declared export
@@ -47,9 +50,10 @@ def after_load(self, state: State, logger: Logger) -> None:
         )
 
         original_load_path = self.train_config.get('load_path', None)
+
         # For HF checkpoint, load_path is unset and should be handled in llmfoundry code.
         # Create a Trainer object to load model into FSDP
-        _ = Trainer(
+        fake_trainer = Trainer(
             model=self.reference_model,
             parallelism_config={'fsdp': state.fsdp_config},
             precision=state.precision,
@@ -58,6 +62,42 @@ def after_load(self, state: State, logger: Logger) -> None:
             load_path=original_load_path,
         )
 
+        # The base model checkpoint may have been supplied by a LoadCheckpoint callback,
+        # so we need to check and apply that checkpoint to the reference model.
+        load_checkpoint_callbacks = [
+            callback for callback in state.callbacks
+            if isinstance(callback, LoadCheckpoint)
+        ]
+
+        if original_load_path is not None and len(
+            load_checkpoint_callbacks,
+        ) > 0:
+            raise ValueError(
+                'Cannot use `load_path` in the train config when using `LoadCheckpoint` callback. '
+                + 'Please remove `load_path` from the train config.',
+            )
+
+        # For any LoadCheckpoint callbacks we found, we will load the checkpoint into the reference model.
+        # If none are found, this for loop is a no-op.
+        for load_checkpoint_callback in load_checkpoint_callbacks:
+            assert isinstance(self.reference_model, HuggingFaceModel)
+
+            # If using PEFT, we need to _not_ filter the state dict to only include the PEFT weights.
+            # This is so the checkpoint can load the base model weights. Since the reference model is
+            # not being updated, we don't need to respect the `should_save_peft_only` flag from the original model
+            # and can just hardcode it to False.
+            self.reference_model.should_save_peft_only = False
+            load_checkpoint(
+                path=load_checkpoint_callback.parsed_path,
+                state=fake_trainer.state,
+                logger=logger,
+                object_store=load_checkpoint_callback.load_object_store,
+                strict_model_weights=load_checkpoint_callback.
+                strict_model_weights,
+                ignore_keys=load_checkpoint_callback.ignore_keys,
+                load_weights_only=load_checkpoint_callback.load_weights_only,
+            )
+
     def before_forward(self, state: State, logger: Logger) -> Optional[int]:
         # Before every batch we need to do a forwards pass over the reference model
         with get_precision_context(state.precision):
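To summarize the behavior after this change, here is a simplified, hypothetical restatement of how `after_load` now decides where the reference model's weights come from. The helper name and return strings are illustrative only; this is not the actual compose-rl implementation.

```python
# Hypothetical summary of the precedence introduced above (illustration only).
from typing import Optional


def resolve_reference_weight_source(
    load_path: Optional[str],
    has_load_checkpoint_callback: bool,
) -> str:
    if load_path is not None and has_load_checkpoint_callback:
        # New guard: the two weight sources are ambiguous, so fail fast.
        raise ValueError('Remove `load_path` from the train config when using LoadCheckpoint.')
    if load_path is not None:
        # Existing path: the temporary Trainer loads the Composer checkpoint.
        return 'composer checkpoint via Trainer(load_path=...)'
    if has_load_checkpoint_callback:
        # New path: apply the callback checkpoint via composer.utils.checkpoint.load_checkpoint.
        return 'callback checkpoint via load_checkpoint(...)'
    # Fallback: HF pretrained weights are handled by llm-foundry when building the model.
    return 'HF pretrained weights handled by llm-foundry'
```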
