fix(mlflow): Enable multiple callbacks for checkpoint reporting #20585

Open · wants to merge 10 commits into base: master
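Summary: MLFlowLogger previously remembered only a single ModelCheckpoint in `self._checkpoint_callback`, so when a Trainer had several checkpoint callbacks, only the last one registered had its checkpoints logged at `finalize()`. This change tracks every callback in a list and scans them all. A minimal sketch of the scenario this enables — names and settings here are illustrative, mirroring the test added below:

    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.loggers import MLFlowLogger

    # log_model=True defers top-k checkpoint upload until the logger's finalize()
    logger = MLFlowLogger("demo-experiment", log_model=True)

    # Two callbacks monitoring different metrics; before this fix, only one of
    # them had its checkpoints uploaded to MLflow as artifacts.
    train_ckpt = ModelCheckpoint(monitor="train_loss", save_top_k=2, mode="min")
    val_ckpt = ModelCheckpoint(monitor="val_loss", save_top_k=2, mode="min")

    trainer = pl.Trainer(logger=logger, callbacks=[train_ckpt, val_ckpt], max_epochs=5)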
15 changes: 10 additions & 5 deletions src/lightning/pytorch/loggers/mlflow.py

@@ -142,7 +142,7 @@ def __init__(
         self.tags = tags
         self._log_model = log_model
         self._logged_model_time: dict[str, float] = {}
-        self._checkpoint_callback: Optional[ModelCheckpoint] = None
+        self._checkpoint_callbacks: list[ModelCheckpoint] = []
         self._prefix = prefix
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
@@ -283,8 +283,9 @@ def finalize(self, status: str = "success") -> None:
             status = "FINISHED"

         # log checkpoints as artifacts
-        if self._checkpoint_callback:
-            self._scan_and_log_checkpoints(self._checkpoint_callback)
+        if self._checkpoint_callbacks:
+            for callback in self._checkpoint_callbacks:
+                self._scan_and_log_checkpoints(callback)

         if self.experiment.get_run(self.run_id):
             self.experiment.set_terminated(self.run_id, status)
@@ -330,8 +331,8 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         # log checkpoints as artifacts
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
-        elif self._log_model is True:
-            self._checkpoint_callback = checkpoint_callback
+        elif self._log_model is True and checkpoint_callback not in self._checkpoint_callbacks:
+            self._checkpoint_callbacks.append(checkpoint_callback)

     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         # get checkpoints to be saved with associated score
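For context on what finalize() now triggers per callback: _scan_and_log_checkpoints walks the checkpoint files a callback has kept. The relevant Lightning API is ModelCheckpoint.best_k_models, a dict mapping each retained checkpoint's filepath to its monitored score, and _logged_model_time (initialized above) is used to avoid re-uploading files that were already logged. A condensed, approximate sketch — scan_sketch is a hypothetical stand-in, not the method's actual body:

    from mlflow.tracking import MlflowClient
    from lightning.pytorch.callbacks import ModelCheckpoint

    def scan_sketch(client: MlflowClient, run_id: str, ckpt: ModelCheckpoint) -> None:
        # best_k_models: {filepath: monitored score} for the retained top-k files
        for path, score in ckpt.best_k_models.items():
            client.log_artifact(run_id, path)  # upload the checkpoint file to the run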
79 changes: 79 additions & 0 deletions tests/tests_pytorch/loggers/test_mlflow.py

@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Any
 from unittest import mock
 from unittest.mock import MagicMock, Mock

 import pytest

 from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loggers.mlflow import (
     _MLFLOW_AVAILABLE,
     MLFlowLogger,
     _get_resolve_tags,
 )
+from lightning.pytorch.utilities.types import STEP_OUTPUT


 def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -427,3 +430,79 @@ def test_set_tracking_uri(mlflow_mock):
     mlflow_mock.set_tracking_uri.assert_not_called()
     _ = logger.experiment
     mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")


+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_multiple_checkpoints_top_k(mlflow_mock, tmp_path):
+    """Test that multiple ModelCheckpoint callbacks with save_top_k work correctly with MLFlowLogger.
+
+    Verifies that when two ModelCheckpoint callbacks monitor different metrics, both function correctly
+    and MLFlowLogger with log_model=True logs the expected number of checkpoints for each.
+    """
+
+    class CustomBoringModel(BoringModel):
+        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            loss = self.step(batch)
+            self.log("train_loss", loss)
+            return {"loss": loss}
+
+        def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
+            loss = self.step(batch)
+            self.log("val_loss", loss)
+            return {"loss": loss}
+
+    client = mlflow_mock.tracking.MlflowClient
+
+    model = CustomBoringModel()
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True)
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    # Create two ModelCheckpoint callbacks monitoring different metrics
+    train_ckpt = ModelCheckpoint(
+        dirpath=str(tmp_path / "train_checkpoints"),
+        monitor="train_loss",
+        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
+        save_top_k=2,
+        mode="min",
+    )
+    val_ckpt = ModelCheckpoint(
+        dirpath=str(tmp_path / "val_checkpoints"),
+        monitor="val_loss",
+        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
+        save_top_k=2,
+        mode="min",
+    )
+
+    # Create trainer with both callbacks
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        callbacks=[train_ckpt, val_ckpt],
+        max_epochs=5,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    # Verify both callbacks saved their checkpoints
+    assert len(train_ckpt.best_k_models) > 0, "Train checkpoint callback did not save any models"
+    assert len(val_ckpt.best_k_models) > 0, "Validation checkpoint callback did not save any models"
+
+    # Get all artifact paths that were logged
+    logged_artifacts = [call_args[0][1] for call_args in client.return_value.log_artifact.call_args_list]
+
+    # Verify MLFlow logged artifacts from both callbacks
+    train_artifacts = [path for path in logged_artifacts if "train_checkpoints" in path]
+    val_artifacts = [path for path in logged_artifacts if "val_checkpoints" in path]
+
+    assert len(train_artifacts) > 0, "MLFlow did not log any train checkpoint artifacts"
+    assert len(val_artifacts) > 0, "MLFlow did not log any validation checkpoint artifacts"
+
+    # Verify the number of logged artifacts matches save_top_k for each callback
+    assert len(train_artifacts) == train_ckpt.save_top_k, "Number of logged train artifacts doesn't match save_top_k"
+    assert len(val_artifacts) == val_ckpt.save_top_k, "Number of logged val artifacts doesn't match save_top_k"
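A note on the final assertions: with save_top_k=2 over five epochs, each callback retains exactly its two best checkpoint files on disk, so once finalize() has scanned every tracked callback, the per-directory log_artifact call count should equal save_top_k.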