Skip to content

Commit 3fbc29b

Browse files
authored
Fix CSVLogger trying to append to file from previous run in same version folder (#19446)
1 parent aa6e085 commit 3fbc29b

File tree

7 files changed

+67
-15
lines changed

7 files changed

+67
-15
lines changed

.github/workflows/ci-tests-fabric.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ jobs:
4646
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
4747
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
4848
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
49-
- { os: "macOS-12", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
50-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
49+
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
50+
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
5151
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
5252
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
5353
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }

.github/workflows/ci-tests-pytorch.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ jobs:
5050
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
5151
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
5252
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
53-
- { os: "macOS-12", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
54-
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
53+
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
54+
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
5555
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" }
5656
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
5757
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }

src/lightning/fabric/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4242

4343
### Fixed
4444

45-
-
45+
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))
4646

4747
-
4848

src/lightning/fabric/loggers/csv_logs.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class CSVLogger(Logger):
4040
name: Experiment name. Defaults to ``'lightning_logs'``.
4141
version: Experiment version. If version is not specified the logger inspects the save
4242
directory for existing versions, then automatically assigns the next available version.
43+
If the version is specified, and the directory already contains a metrics file for that version, it will be
44+
overwritten.
4345
prefix: A string to put at the beginning of metric keys.
4446
flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
4547
@@ -203,15 +205,11 @@ def __init__(self, log_dir: str) -> None:
203205

204206
self._fs = get_filesystem(log_dir)
205207
self.log_dir = log_dir
206-
if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir):
207-
rank_zero_warn(
208-
f"Experiment logs directory {self.log_dir} exists and is not empty."
209-
" Previous log files in this directory will be deleted when the new ones are saved!"
210-
)
211-
self._fs.makedirs(self.log_dir, exist_ok=True)
212-
213208
self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
214209

210+
self._check_log_dir_exists()
211+
self._fs.makedirs(self.log_dir, exist_ok=True)
212+
215213
def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
216214
"""Record metrics."""
217215

@@ -264,3 +262,12 @@ def _rewrite_with_new_header(self, fieldnames: List[str]) -> None:
264262
writer = csv.DictWriter(file, fieldnames=fieldnames)
265263
writer.writeheader()
266264
writer.writerows(metrics)
265+
266+
def _check_log_dir_exists(self) -> None:
267+
if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir):
268+
rank_zero_warn(
269+
f"Experiment logs directory {self.log_dir} exists and is not empty."
270+
" Previous log files in this directory will be deleted when the new ones are saved!"
271+
)
272+
if self._fs.isfile(self.metrics_file_path):
273+
self._fs.rm_file(self.metrics_file_path)

src/lightning/pytorch/CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4040

4141
### Fixed
4242

43-
-
43+
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))
44+
4445

4546
-
4647

tests/tests_fabric/loggers/test_csv.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from unittest import mock
1516
from unittest.mock import MagicMock
1617

1718
import pytest
@@ -50,6 +51,21 @@ def test_manual_versioning(tmp_path):
5051
assert logger.version == 1
5152

5253

54+
def test_manual_versioning_file_exists(tmp_path):
55+
"""Test that a warning is emitted and existing files get overwritten."""
56+
57+
# Simulate an existing 'version_0' vrom a previous run
58+
(tmp_path / "exp" / "version_0").mkdir(parents=True)
59+
previous_metrics_file = tmp_path / "exp" / "version_0" / "metrics.csv"
60+
previous_metrics_file.touch()
61+
62+
logger = CSVLogger(root_dir=tmp_path, name="exp", version=0)
63+
assert previous_metrics_file.exists()
64+
with pytest.warns(UserWarning, match="Experiment logs directory .* exists and is not empty"):
65+
_ = logger.experiment
66+
assert not previous_metrics_file.exists()
67+
68+
5369
def test_named_version(tmp_path):
5470
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
5571
exp_name = "exp"
@@ -130,7 +146,11 @@ def test_automatic_step_tracking(tmp_path):
130146
assert logger.experiment.metrics[2]["step"] == 2
131147

132148

133-
def test_append_metrics_file(tmp_path):
149+
@mock.patch(
150+
# Mock the existance check, so we can simulate appending to the metrics file
151+
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
152+
)
153+
def test_append_metrics_file(_, tmp_path):
134154
"""Test that the logger appends to the file instead of rewriting it on every save."""
135155
logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1)
136156

@@ -167,7 +187,11 @@ def test_append_columns(tmp_path):
167187
assert set(header.split(",")) == {"step", "a", "b", "c"}
168188

169189

170-
def test_rewrite_with_new_header(tmp_path):
190+
@mock.patch(
191+
# Mock the existance check, so we can simulate appending to the metrics file
192+
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
193+
)
194+
def test_rewrite_with_new_header(_, tmp_path):
171195
# write a csv file manually
172196
with open(tmp_path / "metrics.csv", "w") as file:
173197
file.write("step,metric1,metric2\n")

tests/tests_pytorch/loggers/test_csv.py

+20
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from unittest import mock
1516
from unittest.mock import MagicMock
1617

1718
import fsspec
@@ -51,6 +52,21 @@ def test_manual_versioning(tmp_path):
5152
assert logger.version == 1
5253

5354

55+
def test_manual_versioning_file_exists(tmp_path):
56+
"""Test that a warning is emitted and existing files get overwritten."""
57+
58+
# Simulate an existing 'version_0' vrom a previous run
59+
(tmp_path / "exp" / "version_0").mkdir(parents=True)
60+
previous_metrics_file = tmp_path / "exp" / "version_0" / "metrics.csv"
61+
previous_metrics_file.touch()
62+
63+
logger = CSVLogger(save_dir=tmp_path, name="exp", version=0)
64+
assert previous_metrics_file.exists()
65+
with pytest.warns(UserWarning, match="Experiment logs directory .* exists and is not empty"):
66+
_ = logger.experiment
67+
assert not previous_metrics_file.exists()
68+
69+
5470
def test_named_version(tmp_path):
5571
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
5672
exp_name = "exp"
@@ -148,6 +164,10 @@ def test_metrics_reset_after_save(tmp_path):
148164
assert not logger.experiment.metrics
149165

150166

167+
@mock.patch(
168+
# Mock the existance check, so we can simulate appending to the metrics file
169+
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
170+
)
151171
def test_append_metrics_file(tmp_path):
152172
"""Test that the logger appends to the file instead of rewriting it on every save."""
153173
logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1)

0 commit comments

Comments
 (0)