diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 0509f28acb07a..572e02829eeec 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -941,9 +941,9 @@ def _run( log.debug(f"{self.__class__.__name__}: preparing data") self._data_connector.prepare_data() - call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment log.debug(f"{self.__class__.__name__}: configuring model") call._call_configure_model(self) + call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 0c09ae5d5042a..dde5924fda394 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -431,3 +431,51 @@ def test_unsupported_strategies(tmp_path): trainer = Trainer(accelerator="cpu", strategy="deepspeed", callbacks=[callback]) with pytest.raises(NotImplementedError, match="does not support running with the DeepSpeed strategy"): callback.setup(trainer, model, stage=None) + + +def test_finetuning_with_configure_model(tmp_path): + """Test that BaseFinetuning works correctly with configure_model by ensuring freeze_before_training is called after + configure_model but before training starts.""" + + class TrackingFinetuningCallback(BaseFinetuning): + def __init__(self): + super().__init__() + + def freeze_before_training(self, pl_module): + assert hasattr(pl_module, "backbone"), "backbone should be configured before freezing" + self.freeze(pl_module.backbone) + + def finetune_function(self, pl_module, epoch, optimizer): + pass + + class TestModel(LightningModule): + def __init__(self): + super().__init__() + self.configure_model_called_count = 0 + + def configure_model(self): + self.backbone = nn.Linear(32, 32) + self.classifier = nn.Linear(32, 2) + self.configure_model_called_count += 1 + + def forward(self, x): + x = self.backbone(x) + return self.classifier(x) + + def training_step(self, batch, batch_idx): + return self.forward(batch).sum() + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.1) + + model = TestModel() + callback = TrackingFinetuningCallback() + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[callback], + max_epochs=1, + limit_train_batches=1, + ) + + trainer.fit(model, torch.randn(10, 32)) + assert model.configure_model_called_count == 1 diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 1a8aeb4b297a9..9ed5e69492d00 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -472,11 +472,11 @@ def training_step(self, batch, batch_idx): expected = [ {"name": "configure_callbacks"}, {"name": "prepare_data"}, + {"name": "configure_model"}, {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}}, {"name": "setup", "kwargs": {"stage": "fit"}}, # DeepSpeed needs the batch size to figure out throughput logging *([{"name": "train_dataloader"}] if using_deepspeed else []), - {"name": "configure_model"}, {"name": "configure_optimizers"}, {"name": "Callback.on_fit_start", "args": (trainer, model)}, {"name": "on_fit_start"}, @@ -571,9 +571,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path): expected = [ {"name": "configure_callbacks"}, {"name": "prepare_data"}, + {"name": "configure_model"}, {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}}, {"name": "setup", "kwargs": {"stage": "fit"}}, - {"name": "configure_model"}, {"name": "on_load_checkpoint", "args": (loaded_ckpt,)}, {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)}, {"name": "Callback.load_state_dict", "args": ({"foo": True},)}, @@ -651,9 +651,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): expected = [ {"name": "configure_callbacks"}, {"name": "prepare_data"}, + {"name": "configure_model"}, {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}}, {"name": "setup", "kwargs": {"stage": "fit"}}, - {"name": "configure_model"}, {"name": "on_load_checkpoint", "args": (loaded_ckpt,)}, {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)}, {"name": "Callback.load_state_dict", "args": ({"foo": True},)}, @@ -719,9 +719,9 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat expected = [ {"name": "configure_callbacks"}, {"name": "prepare_data"}, + {"name": "configure_model"}, {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}}, {"name": "setup", "kwargs": {"stage": verb}}, - {"name": "configure_model"}, {"name": "zero_grad"}, *(hooks if batches else []), {"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}}, @@ -746,9 +746,9 @@ def test_trainer_model_hook_system_predict(tmp_path): expected = [ {"name": "configure_callbacks"}, {"name": "prepare_data"}, + {"name": "configure_model"}, {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}}, {"name": "setup", "kwargs": {"stage": "predict"}}, - {"name": "configure_model"}, {"name": "zero_grad"}, {"name": "predict_dataloader"}, {"name": "train", "args": (False,)},