Skip to content

Commit 1f5add3

Browse files
Fix CSVLogger hyperparameter is logged at every write which increase latency significantly. (#20594)
* Move save_hparams_to_yaml to log_hparams instead of auto save with metric * Fix params to be optional * Adjust test * Fix test_csv, test_no_name --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9ef355b commit 1f5add3

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

src/lightning/pytorch/CHANGELOG.md

+13
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

7+
## [unreleased] - YYYY-MM-DD
8+
9+
### Added
10+
11+
### Changed
12+
13+
### Removed
14+
15+
### Fixed
16+
17+
- Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
18+
19+
720
## [2.5.0] - 2024-12-19
821

922
### Added

src/lightning/pytorch/loggers/csv_logs.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,10 @@ def __init__(self, log_dir: str) -> None:
5555
self.hparams: dict[str, Any] = {}
5656

5757
def log_hparams(self, params: dict[str, Any]) -> None:
58-
"""Record hparams."""
58+
"""Record hparams and save into files."""
5959
self.hparams.update(params)
60-
61-
@override
62-
def save(self) -> None:
63-
"""Save recorded hparams and metrics into files."""
6460
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
6561
save_hparams_to_yaml(hparams_file, self.hparams)
66-
return super().save()
6762

6863

6964
class CSVLogger(Logger, FabricCSVLogger):
@@ -144,7 +139,7 @@ def save_dir(self) -> str:
144139

145140
@override
146141
@rank_zero_only
147-
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
142+
def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None:
148143
params = _convert_params(params)
149144
self.experiment.log_hparams(params)
150145

tests/tests_pytorch/loggers/test_csv.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def test_named_version(tmp_path):
7575

7676
logger = CSVLogger(save_dir=tmp_path, name=exp_name, version=expected_version)
7777
logger.log_hyperparams({"a": 1, "b": 2})
78-
logger.save()
7978
assert logger.version == expected_version
8079
assert os.listdir(tmp_path / exp_name) == [expected_version]
8180
assert os.listdir(tmp_path / exp_name / expected_version)
@@ -85,7 +84,7 @@ def test_named_version(tmp_path):
8584
def test_no_name(tmp_path, name):
8685
"""Verify that None or empty name works."""
8786
logger = CSVLogger(save_dir=tmp_path, name=name)
88-
logger.save()
87+
logger.log_hyperparams()
8988
assert os.path.normpath(logger.root_dir) == str(tmp_path) # use os.path.normpath to handle trailing /
9089
assert os.listdir(tmp_path / "version_0")
9190

@@ -116,7 +115,6 @@ def test_log_hyperparams(tmp_path):
116115
"layer": torch.nn.BatchNorm1d,
117116
}
118117
logger.log_hyperparams(hparams)
119-
logger.save()
120118

121119
path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
122120
params = load_hparams_from_yaml(path_yaml)

0 commit comments

Comments
 (0)