
Commit a9125c2

mauvilsa, Borda, and lantiga authored
Fix LightningCLI failing when both module and data module save hyperparameters (#20221)
* Fix LightningCLI failing when both module and data module save hyperparameters due to conflicting internal parameter
* Update changelog pull link
* Only skip logging internal LightningCLI params
* Only skip _class_path

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
1 parent 60289d7 commit a9125c2
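
For context, a minimal sketch of the failure mode this commit addresses, assuming a `LightningCLI` run where both the model and the datamodule call `save_hyperparameters()` (the class and parameter names below are illustrative, not from the repository):

from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI


class MyModel(LightningModule):
    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        # Under LightningCLI, save_hyperparameters() also records an internal
        # "_class_path" entry pointing at this class.
        self.save_hyperparameters()


class MyData(LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()  # records a different "_class_path" here


if __name__ == "__main__":
    # Before this fix, fitting with a logger attached raised an error while
    # merging hparams, because the two "_class_path" values necessarily
    # differ; after the fix, that internal key is skipped during the merge.
    LightningCLI(MyModel, MyData)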

File tree

src/lightning/pytorch/CHANGELOG.md
src/lightning/pytorch/loggers/utilities.py
tests/tests_pytorch/test_cli.py

3 files changed: +45 −0


src/lightning/pytorch/CHANGELOG.md (+15)
@@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+### Changed
+
+- Merging of hparams when logging now ignores parameter names that begin with underscore `_` ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))
+
+### Removed
+
+### Fixed
+
+- Fix LightningCLI failing when both module and data module save hyperparameters due to conflicting internal `_class_path` parameter ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))
+
+
 ## [2.4.0] - 2024-08-06
 
 ### Added

src/lightning/pytorch/loggers/utilities.py (+7)
@@ -69,6 +69,9 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
         lightning_hparams = pl_module.hparams_initial
         inconsistent_keys = []
         for key in lightning_hparams.keys() & datamodule_hparams.keys():
+            if key == "_class_path":
+                # Skip LightningCLI's internal hparam
+                continue
             lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
             if (
                 type(lm_val) != type(dm_val)
@@ -88,6 +91,10 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
     elif datamodule_log_hyperparams:
         hparams_initial = trainer.datamodule.hparams_initial
 
+    # Don't log LightningCLI's internal hparam
+    if hparams_initial is not None:
+        hparams_initial = {k: v for k, v in hparams_initial.items() if k != "_class_path"}
+
     for logger in trainer.loggers:
         if hparams_initial is not None:
             logger.log_hyperparams(hparams_initial)
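
For reference, a standalone sketch of the consistency check after this change, simplified from `_log_hyperparams` above (the real code also compares value types and handles tensors; the inputs below are made up):

def find_inconsistent_keys(lightning_hparams: dict, datamodule_hparams: dict) -> list:
    # Simplified re-statement of the loop above: "_class_path" is an internal
    # LightningCLI entry that legitimately differs between the module and the
    # datamodule, so it is excluded from the comparison.
    inconsistent = []
    for key in lightning_hparams.keys() & datamodule_hparams.keys():
        if key == "_class_path":
            continue  # skip LightningCLI's internal hparam
        if lightning_hparams[key] != datamodule_hparams[key]:
            inconsistent.append(key)
    return inconsistent


# The differing class paths no longer register as a conflict:
assert find_inconsistent_keys(
    {"lr": 0.1, "_class_path": "pkg.MyModel"},
    {"lr": 0.1, "_class_path": "pkg.MyData"},
) == []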

tests/tests_pytorch/test_cli.py (+23)
@@ -973,6 +973,29 @@ def test_lightning_cli_save_hyperparameters_untyped_module(cleandir):
     assert model.kwargs == {"x": 1}
 
 
+class TestDataSaveHparams(BoringDataModule):
+    def __init__(self, batch_size: int = 32, num_workers: int = 4):
+        super().__init__()
+        self.save_hyperparameters()
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+
+def test_lightning_cli_save_hyperparameters_merge(cleandir):
+    config = {
+        "model": {
+            "class_path": f"{__name__}.TestModelSaveHparams",
+        },
+        "data": {
+            "class_path": f"{__name__}.TestDataSaveHparams",
+        },
+    }
+    with mock.patch("sys.argv", ["any.py", "fit", f"--config={json.dumps(config)}", "--trainer.max_epochs=1"]):
+        cli = LightningCLI(auto_configure_optimizers=False)
+    assert set(cli.model.hparams) == {"optimizer", "scheduler", "activation", "_instantiator", "_class_path"}
+    assert set(cli.datamodule.hparams) == {"batch_size", "num_workers", "_instantiator", "_class_path"}
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):
