|
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

 import torch
@@ -45,6 +46,15 @@
 log = logging.getLogger(__name__)


+@dataclass
+class RestartStage:
+    NONE = "none"
+    RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
+    RESTARTED_MID_EPOCH = "restarted_mid_epoch"
+    RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
+    RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"
+
+
 class _FitLoop(_Loop):
     """This loop is the top-level loop where training starts.

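Note on the new RestartStage container above: its members are plain class-level string constants, so a stage check is just a string comparison and the current stage is cheap to keep on the loop instance. A minimal illustrative sketch of how the constants behave (not part of the diff):

    # Illustrative only: the markers are class-level string constants.
    stage = RestartStage.RESTARTED_MID_EPOCH
    assert stage == "restarted_mid_epoch"
    assert stage != RestartStage.NONE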
|
@@ -97,6 +107,7 @@ def __init__(
         self._combined_loader_states_to_load: List[Dict[str, Any]] = []
         self._data_fetcher: Optional[_DataFetcher] = None
         self._last_train_dl_reload_epoch = float("-inf")
+        self._restart_stage = RestartStage.NONE

     @property
     def total_batch_idx(self) -> int:
|
@@ -204,9 +215,10 @@ def run(self) -> None:
                 self.on_advance_start()
                 self.advance()
                 self.on_advance_end()
-                self._restarting = False
             except StopIteration:
                 break
+            finally:
+                self.on_iteration_done()
         self._restarting = False
         self.on_run_end()

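The run() change moves the per-iteration cleanup into a finally clause: on_iteration_done() now runs on every pass through the loop, including the pass where StopIteration triggers the break, while _restarting is cleared only once after the loop exits. A minimal, self-contained sketch of that control flow (illustrative, not Lightning code):

    # The finally body runs for every pass through the loop, including the one
    # where StopIteration is raised and handled by the break.
    def run_iterations(stop_at: int) -> list:
        finished = []
        i = 0
        while True:
            try:
                if i == stop_at:
                    raise StopIteration
            except StopIteration:
                break
            finally:
                finished.append(i)  # stands in for on_iteration_done()
            i += 1
        return finished

    assert run_iterations(3) == [0, 1, 2, 3]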
|
@@ -302,14 +314,92 @@ def setup_data(self) -> None:
                 category=PossibleUserWarning,
             )

+    @property
+    def restarted_on_epoch_start(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START
+
+    @property
+    def restarted_mid_epoch(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH
+
+    @property
+    def restarted_on_epoch_end(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END
+
+    @property
+    def resumed_on_epoch_end(self) -> bool:
+        # This case happens when restarting from last without validation at
+        # the end of epoch. In this case self.restarting is False.
+        return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END
+
+    def update_restart_stage(self) -> None:
+        if (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START
+        elif (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_MID_EPOCH
+        elif (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END
+        elif (
+            self._loaded_from_state_dict
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
+        ):
+            self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END
+        else:
+            self._restart_stage = RestartStage.NONE
+
+        self.epoch_loop.update_restart_stage()
+
+    def reset_restart_stage(self) -> None:
+        self._restart_stage = RestartStage.NONE
+
     def reset(self) -> None:
         """Resets the internal state of this loop."""
         assert self.trainer.model is not None
         torch.set_grad_enabled(True)

-        if self.restarting:
+        self.update_restart_stage()
+
+        if self.restarted_on_epoch_start:
             self.epoch_progress.reset_on_restart()

+        if self.resumed_on_epoch_end:
+            # when restarting from last without validation at end of epoch,
+            # self.restarting is False but it's still resuming
+            self.epoch_progress.increment_completed()
+
+        if (
+            self.epoch_loop.restarted_on_train_batch_end
+            and self.restarted_mid_epoch
+            and self.epoch_loop.batch_progress.is_last_batch
+        ):
+            self.epoch_progress.increment_processed()
+            self.epoch_progress.increment_completed()
+
+        if (
+            self.epoch_loop.restarted_on_train_batch_end
+            and self.epoch_loop.batch_progress.is_last_batch
+            and not self.restarted_mid_epoch
+            and not self.epoch_loop.val_loop.batch_progress.is_last_batch
+        ):
+            self.epoch_progress.increment_completed()
+
     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
         # update the current_epoch in-case of checkpoint reload
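The four branches of update_restart_stage distinguish where the saved epoch_progress.total counters stopped relative to each other. A simplified, illustrative classifier over bare counter values (the classify helper below is an assumption for explanation, not part of the loop) makes the patterns easier to read:

    # ready/started/processed/completed stand for epoch_progress.total.* at restore time.
    def classify(ready, started, processed, completed, restarting, loaded_from_state_dict):
        if restarting and (started, processed, completed) == (ready - 1, ready - 1, ready - 1):
            return "restarted_on_epoch_start"  # epoch was ready but never started
        if restarting and (started, processed, completed) == (ready, ready - 1, ready - 1):
            return "restarted_mid_epoch"       # epoch started but not fully processed
        if restarting and (started, processed, completed) == (ready, ready, ready - 1):
            return "restarted_on_epoch_end"    # processed but not marked completed
        if loaded_from_state_dict and (started, processed, completed) == (ready, ready, ready - 1):
            return "resumed_on_epoch_end"      # same counters, but restarting is False
        return "none"

    assert classify(5, 4, 4, 4, True, True) == "restarted_on_epoch_start"
    assert classify(5, 5, 4, 4, True, True) == "restarted_mid_epoch"
    assert classify(5, 5, 5, 4, True, True) == "restarted_on_epoch_end"
    assert classify(5, 5, 5, 4, False, True) == "resumed_on_epoch_end"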
|
@@ -340,12 +430,14 @@ def on_advance_start(self) -> None:
         for i, dl in enumerate(self._combined_loader.flattened):
             _set_sampler_epoch(dl, self.epoch_progress.current.processed)

-        self.epoch_progress.increment_ready()
+        if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
+            if not self.restarted_on_epoch_start:
+                self.epoch_progress.increment_ready()

-        call._call_callback_hooks(trainer, "on_train_epoch_start")
-        call._call_lightning_module_hook(trainer, "on_train_epoch_start")
+            call._call_callback_hooks(trainer, "on_train_epoch_start")
+            call._call_lightning_module_hook(trainer, "on_train_epoch_start")

-        self.epoch_progress.increment_started()
+            self.epoch_progress.increment_started()

     def advance(self) -> None:
         """Runs one whole epoch."""
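The gating added to on_advance_start means: on a fresh epoch all three steps run; after a restart on epoch start the ready counter was already bumped before the checkpoint, so only the on_train_epoch_start hooks and the started counter fire; after a mid-epoch or epoch-end restart the epoch had already been started, so none of them are repeated. A small illustrative truth table of which steps run (the helper name is assumed, not from the diff):

    # Which of the three on_advance_start steps run for each restart stage.
    def steps_run(restarted_on_epoch_start, restarted_mid_epoch, restarted_on_epoch_end):
        steps = []
        if not restarted_mid_epoch and not restarted_on_epoch_end:
            if not restarted_on_epoch_start:
                steps.append("increment_ready")
            steps.append("on_train_epoch_start hooks")
            steps.append("increment_started")
        return steps

    assert steps_run(False, False, False) == ["increment_ready", "on_train_epoch_start hooks", "increment_started"]
    assert steps_run(True, False, False) == ["on_train_epoch_start hooks", "increment_started"]
    assert steps_run(False, True, False) == []
    assert steps_run(False, False, True) == []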
|
@@ -379,8 +471,7 @@ def on_advance_end(self) -> None:

         trainer._logger_connector.on_epoch_end()

-        if self.epoch_loop._num_ready_batches_reached():
-            # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint.
+        if not self.restarting and self.epoch_loop._num_ready_batches_reached():
             # since metric-based schedulers require access to metrics and those are not currently saved in the
             # checkpoint, the plateau schedulers shouldn't be updated
             self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
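The rewritten condition in on_advance_end skips the epoch-level scheduler update altogether when restarting, since the checkpoint does not carry the logged metrics that metric-based schedulers consume. As background, a plateau scheduler cannot be stepped without its monitored value; a minimal plain-PyTorch sketch (illustrative only):

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=0)

    scheduler.step(0.5)  # fine: the monitored metric is available
    try:
        scheduler.step()  # TypeError: ReduceLROnPlateau.step() requires the metric value
    except TypeError:
        pass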
|
|