
Commit 26bfe56

Add enable_autolog_hparams argument to Trainer. (#20593)
* Make hyperparam logging optional
* Modify docs
* Add to CHANGELOG.md
* Fix typos

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: a0134d2 · commit: 26bfe56

File tree

4 files changed (+39 −2 lines)


docs/source-pytorch/common/trainer.rst

+26
@@ -1077,6 +1077,32 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your mode
     trainer = Trainer(inference_mode=False)
     trainer.validate(model)
 
+enable_autolog_hparams
+^^^^^^^^^^^^^^^^^^^^^^
+
+Whether to log hyperparameters at the start of a run. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_autolog_hparams=True)
+
+    # disable logging hyperparams
+    trainer = Trainer(enable_autolog_hparams=False)
+
+With the parameter set to false, you can add custom code to log hyperparameters.
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer = Trainer(enable_autolog_hparams=False)
+    for logger in trainer.loggers:
+        if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
+            logger.log_hyperparams(hparams_dict_1)
+        else:
+            logger.log_hyperparams(hparams_dict_2)
+
+You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log.
 
 -----
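
The new docs section closes by noting that `self.logger.log_hyperparams(...)` can also be called from inside the `LightningModule`. A minimal sketch of that pattern, assuming autologging is disabled on the Trainer; the module, the choice of the `on_fit_start` hook, and the logged dict are illustrative and not part of this commit:

import torch
import lightning.pytorch as pl


class LitModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.lr = lr

    def on_fit_start(self):
        # With enable_autolog_hparams=False the Trainer skips automatic logging,
        # so push a hand-picked dict through the attached logger instead.
        if self.logger is not None:
            self.logger.log_hyperparams({"lr": self.lr, "note": "logged from the module"})

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


trainer = pl.Trainer(enable_autolog_hparams=False, max_epochs=1)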

src/lightning/pytorch/CHANGELOG.md

+3
@@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
+
+
 - Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))
 

src/lightning/pytorch/trainer/trainer.py

+9-1
@@ -128,6 +128,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
+        enable_autolog_hparams: bool = True,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -290,6 +291,9 @@ def __init__(
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
 
+            enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
+                Default: ``True``.
+
         Raises:
             TypeError:
                 If ``gradient_clip_val`` is not an int or float.
@@ -496,6 +500,8 @@ def __init__(
             num_sanity_val_steps,
         )
 
+        self.enable_autolog_hparams = enable_autolog_hparams
+
     def fit(
         self,
         model: "pl.LightningModule",
@@ -962,7 +968,9 @@ def _run(
         call._call_callback_hooks(self, "on_fit_start")
         call._call_lightning_module_hook(self, "on_fit_start")
 
-        _log_hyperparams(self)
+        # only log hparams if enabled
+        if self.enable_autolog_hparams:
+            _log_hyperparams(self)
 
         if self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
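
The mechanical change above is small: the flag is stored on the Trainer and checked once in `_run` before `_log_hyperparams` is called. From the user's side, turning it off pairs naturally with explicit per-logger logging, as in this sketch; the `CSVLogger` setup and the hparam dict are assumptions for illustration, not code from this commit:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

# Hypothetical hyperparameters, standing in for whatever the experiment tracks.
hparams = {"lr": 1e-3, "batch_size": 32}

trainer = Trainer(
    logger=CSVLogger("logs"),
    enable_autolog_hparams=False,  # the automatic hyperparameter logging at fit start is skipped
    max_epochs=1,
)

# Log the hyperparameters explicitly, once per attached logger.
for logger in trainer.loggers:
    logger.log_hyperparams(hparams)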

tests/tests_pytorch/loggers/test_csv.py

+1-1
@@ -163,7 +163,7 @@ def test_metrics_reset_after_save(tmp_path):
 
 
 @mock.patch(
-    # Mock the existance check, so we can simulate appending to the metrics file
+    # Mock the existence check, so we can simulate appending to the metrics file
     "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
 )
 def test_append_metrics_file(_, tmp_path):
