Skip to content

Commit 63c9349

Browse files
authored
Merge branch 'master' into fix/MLFlowLogger
2 parents e1baacb + 5756c81 commit 63c9349

File tree

16 files changed

+680
-73
lines changed

16 files changed

+680
-73
lines changed

.github/checkgroup.yml

+12-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ subprojects:
1919
- "!*.md"
2020
- "!**/*.md"
2121
checks:
22-
- "pl-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
22+
- "pl-cpu (macOS-14, lightning, 3.9, 2.1, oldest)"
2323
- "pl-cpu (macOS-14, lightning, 3.10, 2.1)"
2424
- "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)"
2525
- "pl-cpu (macOS-14, lightning, 3.11, 2.3)"
@@ -40,7 +40,7 @@ subprojects:
4040
- "pl-cpu (macOS-14, pytorch, 3.9, 2.1)"
4141
- "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)"
4242
- "pl-cpu (windows-2022, pytorch, 3.9, 2.1)"
43-
- "pl-cpu (macOS-13, pytorch, 3.10, 2.1)"
43+
- "pl-cpu (macOS-14, pytorch, 3.10, 2.1)"
4444
- "pl-cpu (ubuntu-22.04, pytorch, 3.10, 2.1)"
4545
- "pl-cpu (windows-2022, pytorch, 3.10, 2.1)"
4646

@@ -171,7 +171,7 @@ subprojects:
171171
- "!*.md"
172172
- "!**/*.md"
173173
checks:
174-
- "fabric-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
174+
- "fabric-cpu (macOS-14, lightning, 3.9, 2.1, oldest)"
175175
- "fabric-cpu (macOS-14, lightning, 3.10, 2.1)"
176176
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)"
177177
- "fabric-cpu (macOS-14, lightning, 3.11, 2.3)"
@@ -192,7 +192,7 @@ subprojects:
192192
- "fabric-cpu (macOS-14, fabric, 3.9, 2.1)"
193193
- "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)"
194194
- "fabric-cpu (windows-2022, fabric, 3.9, 2.1)"
195-
- "fabric-cpu (macOS-13, fabric, 3.10, 2.1)"
195+
- "fabric-cpu (macOS-14, fabric, 3.10, 2.1)"
196196
- "fabric-cpu (ubuntu-22.04, fabric, 3.10, 2.1)"
197197
- "fabric-cpu (windows-2022, fabric, 3.10, 2.1)"
198198

@@ -266,14 +266,14 @@ subprojects:
266266
- "install-pkg (ubuntu-22.04, lightning, 3.11)"
267267
- "install-pkg (ubuntu-22.04, notset, 3.9)"
268268
- "install-pkg (ubuntu-22.04, notset, 3.11)"
269-
- "install-pkg (macOS-13, fabric, 3.9)"
270-
- "install-pkg (macOS-13, fabric, 3.11)"
271-
- "install-pkg (macOS-13, pytorch, 3.9)"
272-
- "install-pkg (macOS-13, pytorch, 3.11)"
273-
- "install-pkg (macOS-13, lightning, 3.9)"
274-
- "install-pkg (macOS-13, lightning, 3.11)"
275-
- "install-pkg (macOS-13, notset, 3.9)"
276-
- "install-pkg (macOS-13, notset, 3.11)"
269+
- "install-pkg (macOS-14, fabric, 3.9)"
270+
- "install-pkg (macOS-14, fabric, 3.11)"
271+
- "install-pkg (macOS-14, pytorch, 3.9)"
272+
- "install-pkg (macOS-14, pytorch, 3.11)"
273+
- "install-pkg (macOS-14, lightning, 3.9)"
274+
- "install-pkg (macOS-14, lightning, 3.11)"
275+
- "install-pkg (macOS-14, notset, 3.9)"
276+
- "install-pkg (macOS-14, notset, 3.11)"
277277
- "install-pkg (windows-2022, fabric, 3.9)"
278278
- "install-pkg (windows-2022, fabric, 3.11)"
279279
- "install-pkg (windows-2022, pytorch, 3.9)"

.github/workflows/ci-pkg-install.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
strategy:
4343
fail-fast: false
4444
matrix:
45-
os: ["ubuntu-22.04", "macOS-13", "windows-2022"]
45+
os: ["ubuntu-22.04", "macOS-14", "windows-2022"]
4646
pkg-name: ["fabric", "pytorch", "lightning", "notset"]
4747
python-version: ["3.9", "3.11"]
4848
steps:

.github/workflows/ci-tests-fabric.yml

+6-3
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ jobs:
5656
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
5757
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
5858
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
59-
- { os: "macOS-13", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
59+
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
6060
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
6161
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
6262
# "oldest" versions tests, only on minimum Python
63-
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
63+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
6464
- {
6565
os: "ubuntu-20.04",
6666
pkg-name: "lightning",
@@ -101,7 +101,10 @@ jobs:
101101

102102
- name: Set min. dependencies
103103
if: ${{ matrix.requires == 'oldest' }}
104-
run: python .actions/assistant.py replace_oldest_ver
104+
run: |
105+
python .actions/assistant.py replace_oldest_ver
106+
pip install "cython<3.0" wheel
107+
pip install "pyyaml==5.4" --no-build-isolation
105108
106109
- name: Adjust PyTorch versions in requirements files
107110
if: ${{ matrix.requires != 'oldest' }}

.github/workflows/ci-tests-pytorch.yml

+6-3
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ jobs:
6060
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
6161
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
6262
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
63-
- { os: "macOS-13", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
63+
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
6464
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
6565
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
6666
# "oldest" versions tests, only on minimum Python
67-
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
67+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" }
6868
- {
6969
os: "ubuntu-20.04",
7070
pkg-name: "lightning",
@@ -106,7 +106,10 @@ jobs:
106106

107107
- name: Set min. dependencies
108108
if: ${{ matrix.requires == 'oldest' }}
109-
run: python .actions/assistant.py replace_oldest_ver
109+
run: |
110+
python .actions/assistant.py replace_oldest_ver
111+
pip install "cython<3.0" wheel
112+
pip install "pyyaml==5.4" --no-build-isolation
110113
111114
- name: Adjust PyTorch versions in requirements files
112115
if: ${{ matrix.requires != 'oldest' }}

src/lightning/fabric/utilities/throughput.py

+8
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,14 @@ def measure_flops(
347347
torch.int8: 389.9e12,
348348
"int4": 779.8e12,
349349
},
350+
"rtx 4080 super": {
351+
torch.float32: 52.2e12,
352+
"tfloat32": 52.2e12,
353+
torch.bfloat16: 52.2e12,
354+
torch.float16: 52.2e12,
355+
torch.int8: 417.6e12,
356+
"int4": 835.2e12,
357+
},
350358
"l4": {
351359
torch.float32: 30.3e12,
352360
"tfloat32": 60e12,

src/lightning/pytorch/demos/transformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def forward(self, x: Tensor) -> Tensor:
8888
# TODO: Could make this a `nn.Parameter` with `requires_grad=False`
8989
self.pe = self._init_pos_encoding(device=x.device)
9090

91-
x = x + self.pe[: x.size(0), :]
91+
x = x + self.pe[:, x.size(1)]
9292
return self.dropout(x)
9393

9494
def _init_pos_encoding(self, device: torch.device) -> Tensor:
@@ -97,7 +97,7 @@ def _init_pos_encoding(self, device: torch.device) -> Tensor:
9797
div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
9898
pe[:, 0::2] = torch.sin(position * div_term)
9999
pe[:, 1::2] = torch.cos(position * div_term)
100-
pe = pe.unsqueeze(0).transpose(0, 1)
100+
pe = pe.unsqueeze(0)
101101
return pe
102102

103103

src/lightning/pytorch/loops/evaluation_loop.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import shutil
1616
import sys
1717
from collections import ChainMap, OrderedDict, defaultdict
18+
from dataclasses import dataclass
1819
from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
1920

2021
from lightning_utilities.core.apply_func import apply_to_collection
@@ -45,6 +46,12 @@
4546
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
4647

4748

49+
@dataclass
class RestartStage:
    """String markers naming the restart situation tracked by the evaluation loop."""

    # no restart situation is currently recorded
    NONE = "none"
    # the loop was restarted part-way through an evaluation run
    RESTARTED_MID_EVALUATION = "restarted_mid_evaluation"
53+
54+
4855
class _EvaluationLoop(_Loop):
4956
"""Top-level loop where validation/testing starts."""
5057

@@ -73,6 +80,7 @@ def __init__(
7380
self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
7481
self._last_val_dl_reload_epoch = float("-inf")
7582
self._module_mode = _ModuleMode()
83+
self._restart_stage = RestartStage.NONE
7684

7785
@property
7886
def num_dataloaders(self) -> int:
@@ -137,7 +145,7 @@ def run(self) -> List[_OUT_DICT]:
137145
# this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
138146
break
139147
finally:
140-
self._restarting = False
148+
self.on_iteration_done()
141149
self._store_dataloader_outputs()
142150
return self.on_run_end()
143151

@@ -197,6 +205,24 @@ def setup_data(self) -> None:
197205
# this depends on the data used, so reset it too
198206
self._seen_batches_per_dataloader = defaultdict(int)
199207

208+
@property
def restarted_mid_evaluation(self) -> bool:
    """Whether the recorded restart stage is ``RestartStage.RESTARTED_MID_EVALUATION``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_MID_EVALUATION
211+
212+
def update_restart_stage(self) -> None:
    """Recompute ``self._restart_stage`` from the restart flag and the total batch counters.

    The stage becomes ``RESTARTED_MID_EVALUATION`` only while ``self.restarting`` is set and the totals satisfy
    ``started == ready``, ``processed == started - 1`` and ``completed == processed``. In every other case it is
    set back to ``NONE``.

    """
    stage = RestartStage.NONE
    if self.restarting:
        totals = self.batch_progress.total
        if (
            totals.started == totals.ready
            and totals.processed == totals.started - 1
            and totals.completed == totals.processed
        ):
            stage = RestartStage.RESTARTED_MID_EVALUATION
    self._restart_stage = stage
222+
223+
def reset_restart_stage(self) -> None:
    """Set the recorded restart stage back to ``RestartStage.NONE``."""
    cleared = RestartStage.NONE
    self._restart_stage = cleared
225+
200226
def reset(self) -> None:
201227
"""Resets the internal state of the loop."""
202228
trainer = self.trainer
@@ -236,6 +262,16 @@ def reset(self) -> None:
236262
data_fetcher._stop_profiler = self._on_after_fetch
237263
self._data_fetcher = data_fetcher
238264

265+
def increment_progress_to_evaluation_end(self) -> None:
    """Advance the batch progress as if a full evaluation run had finished.

    Calls ``setup_data()`` first. Unless ``self.skip`` is set, the loop is then reset and ``batch_progress`` is
    incremented by the largest entry of ``self.max_batches``, with the last-batch flag set to ``True``.

    Nothing is incremented when the largest entry is ``-1`` or is not a finite number.

    """
    import math

    self.setup_data()
    if self.skip:
        return
    self.reset()

    largest = max(self.max_batches)
    # ``int(float("inf"))`` raises ``OverflowError``; a non-finite limit gives no end position to jump to, so the
    # progress is left untouched (assumes such entries stand for an unbounded dataloader - TODO confirm)
    if isinstance(largest, float) and not math.isfinite(largest):
        return

    max_batch = int(largest)
    if max_batch == -1:
        return
    self.batch_progress.increment_by(max_batch, True)
274+
239275
def on_run_start(self) -> None:
240276
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
241277
hooks."""

src/lightning/pytorch/loops/fit_loop.py

+99-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
from dataclasses import dataclass
1516
from typing import Any, Dict, List, Optional, Union
1617

1718
import torch
@@ -45,6 +46,15 @@
4546
log = logging.getLogger(__name__)
4647

4748

49+
@dataclass
class RestartStage:
    """String markers naming the restart situation tracked by the fit loop."""

    # no restart situation is currently recorded
    NONE = "none"
    # restarted after an epoch became ready but before it was started
    RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
    # restarted after an epoch was started but before it was processed
    RESTARTED_MID_EPOCH = "restarted_mid_epoch"
    # restarted after an epoch was processed but before it was completed
    RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
    # same counter state as above, reached via a loaded state dict rather than the restarting flag
    RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"
56+
57+
4858
class _FitLoop(_Loop):
4959
"""This loop is the top-level loop where training starts.
5060
@@ -97,6 +107,7 @@ def __init__(
97107
self._combined_loader_states_to_load: List[Dict[str, Any]] = []
98108
self._data_fetcher: Optional[_DataFetcher] = None
99109
self._last_train_dl_reload_epoch = float("-inf")
110+
self._restart_stage = RestartStage.NONE
100111

101112
@property
102113
def total_batch_idx(self) -> int:
@@ -204,9 +215,10 @@ def run(self) -> None:
204215
self.on_advance_start()
205216
self.advance()
206217
self.on_advance_end()
207-
self._restarting = False
208218
except StopIteration:
209219
break
220+
finally:
221+
self.on_iteration_done()
210222
self._restarting = False
211223
self.on_run_end()
212224

@@ -302,14 +314,92 @@ def setup_data(self) -> None:
302314
category=PossibleUserWarning,
303315
)
304316

317+
@property
def restarted_on_epoch_start(self) -> bool:
    """Whether the recorded restart stage is ``RestartStage.RESTARTED_ON_EPOCH_START``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_ON_EPOCH_START
320+
321+
@property
def restarted_mid_epoch(self) -> bool:
    """Whether the recorded restart stage is ``RestartStage.RESTARTED_MID_EPOCH``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_MID_EPOCH
324+
325+
@property
def restarted_on_epoch_end(self) -> bool:
    """Whether the recorded restart stage is ``RestartStage.RESTARTED_ON_EPOCH_END``."""
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESTARTED_ON_EPOCH_END
328+
329+
@property
def resumed_on_epoch_end(self) -> bool:
    """Whether the recorded restart stage is ``RestartStage.RESUMED_ON_EPOCH_END``.

    This stage covers resuming from a "last" checkpoint saved at the end of an epoch without validation; in that
    case ``self.restarting`` is ``False``.

    """
    current_stage = self._restart_stage
    return current_stage == RestartStage.RESUMED_ON_EPOCH_END
334+
335+
def update_restart_stage(self) -> None:
    """Recompute ``self._restart_stage`` from the total epoch counters, then update the epoch loop's stage.

    The stage is selected by which single total counter is exactly one behind its predecessor in the chain
    ``ready -> started -> processed -> completed`` while the other two pairs are equal:

    - ``started`` behind ``ready``: ``RESTARTED_ON_EPOCH_START`` (requires ``self.restarting``)
    - ``processed`` behind ``started``: ``RESTARTED_MID_EPOCH`` (requires ``self.restarting``)
    - ``completed`` behind ``processed``: ``RESTARTED_ON_EPOCH_END`` when ``self.restarting`` is set, otherwise
      ``RESUMED_ON_EPOCH_END`` when ``self._loaded_from_state_dict`` is set

    Any other combination yields ``NONE``.

    """
    totals = self.epoch_progress.total
    # distance of each counter from the one before it (assumes the counters are integers)
    lags = (
        totals.ready - totals.started,
        totals.started - totals.processed,
        totals.processed - totals.completed,
    )
    stage_when_restarting = {
        (1, 0, 0): RestartStage.RESTARTED_ON_EPOCH_START,
        (0, 1, 0): RestartStage.RESTARTED_MID_EPOCH,
        (0, 0, 1): RestartStage.RESTARTED_ON_EPOCH_END,
    }

    if self.restarting:
        stage = stage_when_restarting.get(lags, RestartStage.NONE)
    elif self._loaded_from_state_dict and lags == (0, 0, 1):
        stage = RestartStage.RESUMED_ON_EPOCH_END
    else:
        stage = RestartStage.NONE
    self._restart_stage = stage

    # the nested epoch loop derives its own stage
    self.epoch_loop.update_restart_stage()
368+
369+
def reset_restart_stage(self) -> None:
    """Set the recorded restart stage back to ``RestartStage.NONE``."""
    cleared = RestartStage.NONE
    self._restart_stage = cleared
371+
305372
def reset(self) -> None:
306373
"""Resets the internal state of this loop."""
307374
assert self.trainer.model is not None
308375
torch.set_grad_enabled(True)
309376

310-
if self.restarting:
377+
self.update_restart_stage()
378+
379+
if self.restarted_on_epoch_start:
311380
self.epoch_progress.reset_on_restart()
312381

382+
if self.resumed_on_epoch_end:
383+
# when restarting from last without validation at end of epoch,
384+
# self.restarting is False but it's still resuming
385+
self.epoch_progress.increment_completed()
386+
387+
if (
388+
self.epoch_loop.restarted_on_train_batch_end
389+
and self.restarted_mid_epoch
390+
and self.epoch_loop.batch_progress.is_last_batch
391+
):
392+
self.epoch_progress.increment_processed()
393+
self.epoch_progress.increment_completed()
394+
395+
if (
396+
self.epoch_loop.restarted_on_train_batch_end
397+
and self.epoch_loop.batch_progress.is_last_batch
398+
and not self.restarted_mid_epoch
399+
and not self.epoch_loop.val_loop.batch_progress.is_last_batch
400+
):
401+
self.epoch_progress.increment_completed()
402+
313403
def on_run_start(self) -> None:
314404
"""Calls the ``on_train_start`` hook."""
315405
# update the current_epoch in-case of checkpoint reload
@@ -340,12 +430,14 @@ def on_advance_start(self) -> None:
340430
for i, dl in enumerate(self._combined_loader.flattened):
341431
_set_sampler_epoch(dl, self.epoch_progress.current.processed)
342432

343-
self.epoch_progress.increment_ready()
433+
if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
434+
if not self.restarted_on_epoch_start:
435+
self.epoch_progress.increment_ready()
344436

345-
call._call_callback_hooks(trainer, "on_train_epoch_start")
346-
call._call_lightning_module_hook(trainer, "on_train_epoch_start")
437+
call._call_callback_hooks(trainer, "on_train_epoch_start")
438+
call._call_lightning_module_hook(trainer, "on_train_epoch_start")
347439

348-
self.epoch_progress.increment_started()
440+
self.epoch_progress.increment_started()
349441

350442
def advance(self) -> None:
351443
"""Runs one whole epoch."""
@@ -379,8 +471,7 @@ def on_advance_end(self) -> None:
379471

380472
trainer._logger_connector.on_epoch_end()
381473

382-
if self.epoch_loop._num_ready_batches_reached():
383-
# if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint.
474+
if not self.restarting and self.epoch_loop._num_ready_batches_reached():
384475
# since metric-based schedulers require access to metrics and those are not currently saved in the
385476
# checkpoint, the plateau schedulers shouldn't be updated
386477
self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)

0 commit comments

Comments
 (0)