Skip to content

Commit b6769b0

Browse files
authored
Merge branch 'master' into delete_data_stuffs
2 parents 3442da9 + 63188f9 commit b6769b0

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

src/lightning/pytorch/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212

1313
- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))
1414

15-
-
15+
- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))
1616

1717
-
1818

src/lightning/pytorch/callbacks/timer.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
import logging
20+
import re
2021
import time
2122
from datetime import timedelta
2223
from typing import Any, Dict, Optional, Union
@@ -50,6 +51,8 @@ class Timer(Callback):
5051
verbose: Set this to ``False`` to suppress logging messages.
5152
5253
Raises:
54+
MisconfigurationException:
55+
If ``duration`` is not in the expected format.
5356
MisconfigurationException:
5457
If ``interval`` is not one of the supported choices.
5558
@@ -86,10 +89,19 @@ def __init__(
8689
) -> None:
8790
super().__init__()
8891
if isinstance(duration, str):
89-
dhms = duration.strip().split(":")
90-
dhms = [int(i) for i in dhms]
91-
duration = timedelta(days=dhms[0], hours=dhms[1], minutes=dhms[2], seconds=dhms[3])
92-
if isinstance(duration, dict):
92+
duration_match = re.fullmatch(r"(\d+):(\d\d):(\d\d):(\d\d)", duration.strip())
93+
if not duration_match:
94+
raise MisconfigurationException(
95+
f"`Timer(duration={duration!r})` is not a valid duration. "
96+
"Expected a string in the format DD:HH:MM:SS."
97+
)
98+
duration = timedelta(
99+
days=int(duration_match.group(1)),
100+
hours=int(duration_match.group(2)),
101+
minutes=int(duration_match.group(3)),
102+
seconds=int(duration_match.group(4)),
103+
)
104+
elif isinstance(duration, dict):
93105
duration = timedelta(**duration)
94106
if interval not in set(Interval):
95107
raise MisconfigurationException(

tests/tests_pytorch/callbacks/test_timer.py

+6
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ def test_timer_parse_duration(duration, expected):
6767
assert (timer.time_remaining() == expected is None) or (timer.time_remaining() == expected.total_seconds())
6868

6969

70+
@pytest.mark.parametrize("duration", ["6:00:00", "60 minutes"])
71+
def test_timer_parse_duration_misconfiguration(duration):
72+
with pytest.raises(MisconfigurationException, match="format DD:HH:MM:SS"):
73+
Timer(duration=duration)
74+
75+
7076
def test_timer_interval_choice():
7177
Timer(duration=timedelta(), interval="step")
7278
Timer(duration=timedelta(), interval="epoch")

0 commit comments

Comments
 (0)