diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index 1d158f41b52bc..389558accb34f 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -143,7 +143,7 @@ def __init__(
         self.tags = tags
         self._log_model = log_model
         self._logged_model_time: dict[str, float] = {}
-        self._checkpoint_callback: Optional[ModelCheckpoint] = None
+        self._checkpoint_callbacks: list[ModelCheckpoint] = []
         self._prefix = prefix
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
@@ -285,8 +285,9 @@ def finalize(self, status: str = "success") -> None:
             status = "FINISHED"

         # log checkpoints as artifacts
-        if self._checkpoint_callback:
-            self._scan_and_log_checkpoints(self._checkpoint_callback)
+        if self._checkpoint_callbacks:
+            for callback in self._checkpoint_callbacks:
+                self._scan_and_log_checkpoints(callback)

         if self.experiment.get_run(self.run_id):
             self.experiment.set_terminated(self.run_id, status)
@@ -332,8 +333,8 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         # log checkpoints as artifacts
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
-        elif self._log_model is True:
-            self._checkpoint_callback = checkpoint_callback
+        elif self._log_model is True and checkpoint_callback not in self._checkpoint_callbacks:
+            self._checkpoint_callbacks.append(checkpoint_callback)

     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         # get checkpoints to be saved with associated score
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index 8118349ea6721..751c110e15502 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Any
 from unittest import mock
 from unittest.mock import MagicMock, Mock

 import pytest

 from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loggers.mlflow import (
     _MLFLOW_AVAILABLE,
     MLFlowLogger,
     _get_resolve_tags,
 )
+from lightning.pytorch.utilities.types import STEP_OUTPUT


 def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -457,3 +460,75 @@ def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):
     for call in client.return_value.log_artifact.call_args_list:
         args, _ = call
         assert str(args[2]).startswith("my_prefix")
+
+
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_multiple_checkpoints_top_k(mlflow_mock, tmp_path):
+    """Test that multiple ModelCheckpoint callbacks with save_top_k work correctly with MLFlowLogger.
+
+    When two ModelCheckpoint callbacks with save_top_k are used together with MLFlowLogger and log_model=True,
+    both callbacks should save checkpoints and all of them should be logged as MLflow artifacts.
+
+    """
+
+    class CustomBoringModel(BoringModel):
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            loss = self.step(batch)
+            self.log("train_loss", loss)
+            return {"loss": loss}
+
+        def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            loss = self.step(batch)
+            self.log("val_loss", loss)
+            return {"loss": loss}
+
+    client = mlflow_mock.tracking.MlflowClient
+
+    model = CustomBoringModel()
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True)
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    # Create two ModelCheckpoint callbacks monitoring different metrics
+    train_ckpt = ModelCheckpoint(
+        dirpath=str(tmp_path / "train_checkpoints"),
+        monitor="train_loss",
+        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
+        save_top_k=2,
+        mode="min",
+    )
+    val_ckpt = ModelCheckpoint(
+        dirpath=str(tmp_path / "val_checkpoints"),
+        monitor="val_loss",
+        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
+        save_top_k=2,
+        mode="min",
+    )
+
+    # Create trainer with both callbacks
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        callbacks=[train_ckpt, val_ckpt],
+        max_epochs=5,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    # Verify both callbacks saved their checkpoints
+    assert len(train_ckpt.best_k_models) > 0, "Train checkpoint callback did not save any models"
+    assert len(val_ckpt.best_k_models) > 0, "Validation checkpoint callback did not save any models"
+
+    # Get all artifact paths that were logged
+    logged_artifacts = [call_args[0][1] for call_args in client.return_value.log_artifact.call_args_list]
+
+    # Verify MLFlow logged artifacts from both callbacks
+    train_artifacts = [path for path in logged_artifacts if "train_checkpoints" in path]
+    val_artifacts = [path for path in logged_artifacts if "val_checkpoints" in path]
+
+    assert len(train_artifacts) > 0, "MLFlow did not log any train checkpoint artifacts"
+    assert len(val_artifacts) > 0, "MLFlow did not log any validation checkpoint artifacts"
+
+    # Verify the number of logged artifacts matches the save_top_k for each callback
+    assert len(train_artifacts) == train_ckpt.save_top_k, "Number of logged train artifacts doesn't match save_top_k"
+    assert len(val_artifacts) == val_ckpt.save_top_k, "Number of logged val artifacts doesn't match save_top_k"