
Commit 3da042f

Merge branch 'master' into feature/option-disable-loghparams

2 parents: 72f4706 + e8d70bc

19 files changed: +245 -47 lines

.azure/gpu-benchmarks.yml (+1)

@@ -108,5 +108,6 @@ jobs:
     condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'fabric'))
     env:
       PL_RUN_CUDA_TESTS: "1"
+      PL_RUN_STANDALONE_TESTS: "1"
     displayName: "Testing: fabric standalone tasks"
     timeoutInMinutes: "10"

.azure/gpu-tests-fabric.yml (+1)

@@ -144,6 +144,7 @@ jobs:
     workingDirectory: tests/
     env:
       PL_STANDALONE_TESTS_SOURCE: $(COVERAGE_SOURCE)
+      PL_RUN_STANDALONE_TESTS: "1"
     displayName: "Testing: fabric standalone"
     timeoutInMinutes: "10"

.azure/gpu-tests-pytorch.yml (+1)

@@ -166,6 +166,7 @@ jobs:
     env:
       PL_USE_MOCKED_MNIST: "1"
       PL_STANDALONE_TESTS_SOURCE: $(COVERAGE_SOURCE)
+      PL_RUN_STANDALONE_TESTS: "1"
     displayName: "Testing: PyTorch standalone tests"
     timeoutInMinutes: "35"

.github/workflows/call-clear-cache.yml (+4 -4)

@@ -23,18 +23,18 @@ on:
 jobs:
   cron-clear:
     if: github.event_name == 'schedule' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
     with:
-      scripts-ref: v0.11.8
+      scripts-ref: v0.14.0
       dry-run: ${{ github.event_name == 'pull_request' }}
       pattern: "latest|docs"
       age-days: 7

   direct-clear:
     if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
     with:
-      scripts-ref: v0.11.8
+      scripts-ref: v0.14.0
       dry-run: ${{ github.event_name == 'pull_request' }}
       pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging
       age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging

.github/workflows/ci-check-md-links.yml (+1 -1)

@@ -14,7 +14,7 @@ on:

 jobs:
   check-md-links:
-    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.14.0
     with:
       config-file: ".github/markdown-links-config.json"
       base-branch: "master"

.github/workflows/ci-schema.yml (+1 -1)

@@ -8,7 +8,7 @@ on:

 jobs:
   check:
-    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.14.0
     with:
       # skip azure due to the wrong schema file by MSFT
       # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607

docs/source-pytorch/visualize/loggers.rst (+34)

@@ -54,3 +54,37 @@ Track and Visualize Experiments

 </div>
 </div>
+
+.. _mlflow_logger:
+
+MLflow Logger
+-------------
+
+The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
+
+Example usage:
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.pytorch.loggers import MLFlowLogger
+
+    mlf_logger = MLFlowLogger(
+        experiment_name="lightning_logs",
+        tracking_uri="file:./ml-runs",
+        checkpoint_path_prefix="my_prefix"
+    )
+    trainer = L.Trainer(logger=mlf_logger)
+
+    # Your LightningModule definition
+    class LitModel(L.LightningModule):
+        def training_step(self, batch, batch_idx):
+            # example
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+        def any_lightning_module_function_or_hook(self):
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+    # Train your model
+    model = LitModel()
+    trainer.fit(model)

src/lightning/pytorch/CHANGELOG.md (+17)

@@ -10,15 +10,32 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))

+
+- Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))
+
+
 ### Changed

+- Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))
+
+
+- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
+
+
+
 ### Removed

+-
+
+
 ### Fixed

 - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))


+- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
+
+
 ## [2.5.0] - 2024-12-19

 ### Added

src/lightning/pytorch/cli.py (+3 -1)

@@ -314,6 +314,7 @@ def __init__(
         trainer_defaults: Optional[dict[str, Any]] = None,
         seed_everything_default: Union[bool, int] = True,
         parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None,
+        parser_class: type[LightningArgumentParser] = LightningArgumentParser,
         subclass_mode_model: bool = False,
         subclass_mode_data: bool = False,
         args: ArgsType = None,
@@ -367,6 +368,7 @@ def __init__(
         self.trainer_defaults = trainer_defaults or {}
         self.seed_everything_default = seed_everything_default
         self.parser_kwargs = parser_kwargs or {}
+        self.parser_class = parser_class
         self.auto_configure_optimizers = auto_configure_optimizers

         self.model_class = model_class
@@ -404,7 +406,7 @@ def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str,
     def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
         """Method that instantiates the argument parser."""
         kwargs.setdefault("dump_header", [f"lightning.pytorch=={pl.__version__}"])
-        parser = LightningArgumentParser(**kwargs)
+        parser = self.parser_class(**kwargs)
         parser.add_argument(
             "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
         )
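
For orientation, a minimal sketch of how the new `parser_class` hook could be used. The `MyParser` subclass and its description tweak are hypothetical illustrations, not part of this commit; the only mechanism shown by the diff is that `init_parser` now builds parsers via `self.parser_class(**kwargs)`.

    from lightning.pytorch.cli import LightningArgumentParser, LightningCLI


    class MyParser(LightningArgumentParser):
        # Hypothetical subclass: change the default program description in one place.
        def __init__(self, *args, **kwargs):
            kwargs.setdefault("description", "My project's training CLI")
            super().__init__(*args, **kwargs)


    def main():
        # Every parser (main and subcommand) is now created from MyParser.
        LightningCLI(parser_class=MyParser)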

src/lightning/pytorch/loggers/mlflow.py (+4 -2)

@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
              :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
            * if ``log_model == False`` (default), no checkpoint is logged.
-
+        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.
@@ -121,6 +121,7 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
+        checkpoint_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -147,6 +148,7 @@ def __init__(
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
+        self._checkpoint_path_prefix = checkpoint_path_prefix

         from mlflow.tracking import MlflowClient

@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

             # Artifact path on mlflow
-            artifact_path = Path(p).stem
+            artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem

             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
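
As a rough illustration of the change above (the checkpoint filename below is made up), the prefix is simply joined in front of the checkpoint's stem when building the MLflow artifact path:

    from pathlib import Path

    checkpoint_path_prefix = "my_prefix"
    p = "lightning_logs/version_0/checkpoints/epoch=3-step=100.ckpt"  # hypothetical checkpoint file

    # Mirrors the updated line in _scan_and_log_checkpoints
    artifact_path = Path(checkpoint_path_prefix) / Path(p).stem
    print(artifact_path)  # my_prefix/epoch=3-step=100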

src/lightning/pytorch/loggers/wandb.py (+6 -3)

@@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
                 if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
                     self._experiment, "define_metric", None
                 ):
-                    self._experiment.define_metric("trainer/global_step")
-                    self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
+                    if self._wandb_init.get("sync_tensorboard"):
+                        self._experiment.define_metric("*", step_metric="global_step")
+                    else:
+                        self._experiment.define_metric("trainer/global_step")
+                        self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

         return self._experiment

@@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
-        if step is not None:
+        if step is not None and not self._wandb_init.get("sync_tensorboard"):
             self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
         else:
             self.experiment.log(metrics)
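
A hedged usage sketch of what this enables (the project name is illustrative): with `sync_tensorboard=True`, wandb mirrors the TensorBoard event files, metrics are plotted against TensorBoard's `global_step`, and `log_metrics` no longer injects its own `trainer/global_step` key.

    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    # sync_tensorboard is forwarded to wandb.init via WandbLogger's **kwargs
    wandb_logger = WandbLogger(project="my-project", sync_tensorboard=True)
    trainer = L.Trainer(logger=wandb_logger)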

src/lightning/pytorch/trainer/call.py (+6 -1)

@@ -21,6 +21,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
 from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
 from lightning.pytorch.trainer.states import TrainerStatus
@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
         if isinstance(module, _DeviceDtypeModuleMixin):
             module._device = trainer.strategy.root_device

+    # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
+    # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
+    loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
+
     # Trigger lazy creation of experiment in loggers so loggers have their metadata available
-    for logger in trainer.loggers:
+    for logger in loggers:
         if hasattr(logger, "experiment"):
            _ = logger.experiment
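
For example (logger arguments below are illustrative), a run that combines a TensorBoardLogger with a WandbLogger now has `wandb.init` called before the TensorBoard writer is created, regardless of the order of the logger list, which is what lets the TensorBoard events sync to wandb:

    import lightning as L
    from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

    trainer = L.Trainer(
        logger=[
            TensorBoardLogger("tb_logs"),  # listed first, but its experiment is created after wandb's
            WandbLogger(project="my-project", sync_tensorboard=True),
        ]
    )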
