-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
Copy pathtrainer.py
1689 lines (1378 loc) · 70.9 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# THIS FILE MUST READ EASILY, FOR UNDERSTANDING AND DEBUGGING PURPOSES.
# DO NOT OBSCURE THE TRAINING LOOP
# THIS IS A HARD REQUIREMENT TO CONTRIBUTING TO LIGHTNING
# WE FAVOR READABILITY OVER ENGINEERING-CONSTRUCTS BY DESIGN
# DO NOT REMOVE THIS NOTICE
# - WILLIAM FALCON
"""Trainer to automate the training."""
import logging
import math
import os
from collections.abc import Generator, Iterable
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Optional, Union
from weakref import proxy
import torch
from torch.optim import Optimizer
import lightning.pytorch as pl
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.loggers import Logger
from lightning.pytorch.loggers.csv_logs import CSVLogger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.utilities import _log_hyperparams
from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop
from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop
from lightning.pytorch.loops.fit_loop import _FitLoop
from lightning.pytorch.loops.utilities import _parse_loop_limits, _reset_progress
from lightning.pytorch.plugins import _PLUGIN_INPUT, Precision
from lightning.pytorch.profilers import Profiler
from lightning.pytorch.strategies import ParallelStrategy, Strategy
from lightning.pytorch.trainer import call, setup
from lightning.pytorch.trainer.configuration_validator import _verify_loop_configurations
from lightning.pytorch.trainer.connectors.accelerator_connector import (
_LITERAL_WARN,
_PRECISION_INPUT,
_PRECISION_INPUT_STR,
_AcceleratorConnector,
)
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
from lightning.pytorch.trainer.connectors.data_connector import _DataConnector
from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection
from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from lightning.pytorch.utilities import GradClipAlgorithmType, parsing
from lightning.pytorch.utilities.argparse import _defaults_from_env_vars
from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.seed import isolate_rng
from lightning.pytorch.utilities.types import (
_EVALUATE_OUTPUT,
_PREDICT_OUTPUT,
EVAL_DATALOADERS,
TRAIN_DATALOADERS,
LRSchedulerConfig,
)
from lightning.pytorch.utilities.warnings import PossibleUserWarning
# Module-level logger named after this module; the trainer methods below emit their debug messages through it.
log = logging.getLogger(__name__)
class Trainer:
@_defaults_from_env_vars
def __init__(
self,
*,
accelerator: Union[str, Accelerator] = "auto",
strategy: Union[str, Strategy] = "auto",
devices: Union[list[int], str, int] = "auto",
num_nodes: int = 1,
precision: Optional[_PRECISION_INPUT] = None,
logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
callbacks: Optional[Union[list[Callback], Callback]] = None,
fast_dev_run: Union[int, bool] = False,
max_epochs: Optional[int] = None,
min_epochs: Optional[int] = None,
max_steps: int = -1,
min_steps: Optional[int] = None,
max_time: Optional[Union[str, timedelta, dict[str, int]]] = None,
limit_train_batches: Optional[Union[int, float]] = None,
limit_val_batches: Optional[Union[int, float]] = None,
limit_test_batches: Optional[Union[int, float]] = None,
limit_predict_batches: Optional[Union[int, float]] = None,
overfit_batches: Union[int, float] = 0.0,
val_check_interval: Optional[Union[int, float]] = None,
check_val_every_n_epoch: Optional[int] = 1,
num_sanity_val_steps: Optional[int] = None,
log_every_n_steps: Optional[int] = None,
enable_checkpointing: Optional[bool] = None,
enable_progress_bar: Optional[bool] = None,
enable_model_summary: Optional[bool] = None,
accumulate_grad_batches: int = 1,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
benchmark: Optional[bool] = None,
inference_mode: bool = True,
use_distributed_sampler: bool = True,
profiler: Optional[Union[Profiler, str]] = None,
detect_anomaly: bool = False,
barebones: bool = False,
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None,
sync_batchnorm: bool = False,
reload_dataloaders_every_n_epochs: int = 0,
default_root_dir: Optional[_PATH] = None,
) -> None:
r"""Customize every aspect of training via flags.
Args:
accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto")
as well as custom accelerator instances.
strategy: Supports different training strategies with aliases as well as custom strategies.
Default: ``"auto"``.
devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices
(list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
automatic selection based on the chosen accelerator. Default: ``"auto"``.
num_nodes: Number of GPU nodes for distributed training.
Default: ``1``.
precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
Can be used on CPU, GPU, TPUs, or HPUs.
Default: ``'32-true'``.
logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
``False`` will disable logging. If multiple loggers are provided, local files
(checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
Default: ``True``.
callbacks: Add a callback or list of callbacks.
Default: ``None``.
fast_dev_run: Runs ``n`` batch(es) of train, val and test if set to ``n`` (int), or a single batch if
set to ``True``, to find any bugs (i.e. a sort of unit test).
Default: ``False``.
max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
To enable infinite training, set ``max_epochs = -1``.
min_epochs: Force training for at least these many epochs. Disabled by default (None).
max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
``max_epochs`` to ``-1``.
min_steps: Force training for at least these number of steps. Disabled by default (``None``).
max_time: Stop training after this amount of time has passed. Disabled by default (``None``).
The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
:class:`datetime.timedelta`, or a dictionary with keys that will be passed to
:class:`datetime.timedelta`.
limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).
Default: ``1.0``.
limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches).
Default: ``1.0``.
limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches).
Default: ``1.0``.
limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches).
Default: ``1.0``.
overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int).
Default: ``0.0``.
val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
batches. An ``int`` value can only be higher than the number of training batches when
``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
across epochs or during iteration-based training.
Default: ``1.0``.
check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
validation will be done solely based on the number of training batches, requiring ``val_check_interval``
to be an integer value.
Default: ``1``.
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders.
Default: ``2``.
log_every_n_steps: How often to log within steps.
Default: ``50``.
enable_checkpointing: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
Default: ``True``.
enable_progress_bar: Whether to enable the progress bar by default.
Default: ``True``.
enable_model_summary: Whether to enable model summarization by default.
Default: ``True``.
accumulate_grad_batches: Accumulates gradients over k batches before stepping the optimizer.
Default: ``1``.
gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
Default: ``None``.
gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
be set to ``"norm"``. The name is matched case-insensitively.
deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
(``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
is set to ``True``, this will default to ``False``. Override to manually set a different value.
Default: ``None``.
inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
evaluation (``validate``/``test``/``predict``).
Default: ``True``.
use_distributed_sampler: Whether to wrap the DataLoader's sampler with
:class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
we don't do this automatically.
Default: ``True``.
profiler: To profile individual steps during training and assist in identifying bottlenecks.
Default: ``None``.
detect_anomaly: Enable anomaly detection for the autograd engine.
Default: ``False``.
barebones: Whether to run in "barebones mode", where all features that may impact raw speed are
disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
runs. The following features are deactivated:
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
:meth:`~lightning.pytorch.core.LightningModule.log`,
:meth:`~lightning.pytorch.core.LightningModule.log_dict`.
Default: ``False``.
plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
Default: ``None``.
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
Default: ``False``.
reload_dataloaders_every_n_epochs: Set to a positive integer to reload dataloaders every n epochs.
Default: ``0``.
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
Default: ``os.getcwd()``.
Can be a remote file path such as ``s3://mybucket/path`` or ``hdfs://path/``.
Raises:
ValueError:
If ``barebones=True`` is combined with any feature that barebones mode disables (checkpointing,
a logger, the progress bar, ``log_every_n_steps``, model summary, sanity checking, ``fast_dev_run``,
``detect_anomaly`` or a profiler).
TypeError:
If ``gradient_clip_val`` is not an int or float.
MisconfigurationException:
If ``gradient_clip_algorithm`` is invalid.
"""
super().__init__()
log.debug(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
# normalize a path-like ``default_root_dir`` through ``os.fspath`` before handing it to the callback connector
if default_root_dir is not None:
default_root_dir = os.fspath(default_root_dir)
self.barebones = barebones
if barebones:
# opt-outs: features that are on by default. An explicit request to enable one raises; otherwise the
# ``None`` sentinel is replaced by the "off" value
if enable_checkpointing:
raise ValueError(
f"`Trainer(barebones=True, enable_checkpointing={enable_checkpointing!r})` was passed."
" Checkpointing can impact raw speed so it is disabled in barebones mode."
)
enable_checkpointing = False
if logger is not None and logger is not False:
raise ValueError(
f"`Trainer(barebones=True, logger={logger!r})` was passed."
" Logging can impact raw speed so it is disabled in barebones mode."
)
logger = False
if enable_progress_bar:
raise ValueError(
f"`Trainer(barebones=True, enable_progress_bar={enable_progress_bar!r})` was passed."
" The progress bar can impact raw speed so it is disabled in barebones mode."
)
enable_progress_bar = False
if log_every_n_steps is not None and log_every_n_steps != 0:
raise ValueError(
f"`Trainer(barebones=True, log_every_n_steps={log_every_n_steps!r})` was passed."
" Logging can impact raw speed so it is disabled in barebones mode."
)
log_every_n_steps = 0
if enable_model_summary:
raise ValueError(
f"`Trainer(barebones=True, enable_model_summary={enable_model_summary!r})` was passed."
" Model summary can impact raw speed so it is disabled in barebones mode."
)
enable_model_summary = False
if num_sanity_val_steps is not None and num_sanity_val_steps != 0:
raise ValueError(
f"`Trainer(barebones=True, num_sanity_val_steps={num_sanity_val_steps!r})` was passed."
" Sanity checking can impact raw speed so it is disabled in barebones mode."
)
num_sanity_val_steps = 0
# opt-ins: features that are off by default; asking for any of them in barebones mode raises
# (their values are left untouched because they are already "off" otherwise)
if fast_dev_run is not False and fast_dev_run != 0:
raise ValueError(
f"`Trainer(barebones=True, fast_dev_run={fast_dev_run!r})` was passed."
" Development run is not meant for raw speed evaluation so it is disabled in barebones mode."
)
if detect_anomaly:
raise ValueError(
f"`Trainer(barebones=True, detect_anomaly={detect_anomaly!r})` was passed."
" Anomaly detection can impact raw speed so it is disabled in barebones mode."
)
if profiler is not None:
raise ValueError(
f"`Trainer(barebones=True, profiler={profiler!r})` was passed."
" Profiling can impact raw speed so it is disabled in barebones mode."
)
deactivated = (
" - Checkpointing: `Trainer(enable_checkpointing=True)`",
" - Progress bar: `Trainer(enable_progress_bar=True)`",
" - Model summary: `Trainer(enable_model_summary=True)`",
" - Logging: `Trainer(logger=True)`, `Trainer(log_every_n_steps>0)`,"
" `LightningModule.log(...)`, `LightningModule.log_dict(...)`",
" - Sanity checking: `Trainer(num_sanity_val_steps>0)`",
" - Development run: `Trainer(fast_dev_run=True)`",
" - Anomaly detection: `Trainer(detect_anomaly=True)`",
" - Profiling: `Trainer(profiler=...)`",
)
rank_zero_info(
"You are running in `Trainer(barebones=True)` mode. All features that may impact raw speed have been"
" disabled to facilitate analyzing the Trainer overhead. Specifically, the following features are"
f" deactivated:{os.linesep}{os.linesep.join(deactivated)}"
)
else:
# set the opt-out defaults: resolve each ``None`` sentinel to its regular (non-barebones) default
if enable_checkpointing is None:
enable_checkpointing = True
if logger is None:
logger = True
if enable_progress_bar is None:
enable_progress_bar = True
if log_every_n_steps is None:
log_every_n_steps = 50
if enable_model_summary is None:
enable_model_summary = True
if num_sanity_val_steps is None:
num_sanity_val_steps = 2
# init connectors
self._data_connector = _DataConnector(self)
self._accelerator_connector = _AcceleratorConnector(
devices=devices,
accelerator=accelerator,
strategy=strategy,
num_nodes=num_nodes,
sync_batchnorm=sync_batchnorm,
benchmark=benchmark,
use_distributed_sampler=use_distributed_sampler,
deterministic=deterministic,
precision=precision,
plugins=plugins,
)
self._logger_connector = _LoggerConnector(self)
self._callback_connector = _CallbackConnector(self)
self._checkpoint_connector = _CheckpointConnector(self)
self._signal_connector = _SignalConnector(self)
# init loops
# the epoch-level limits go to the fit loop, the step-level limits to its inner training-epoch loop
self.fit_loop = _FitLoop(self, min_epochs=min_epochs, max_epochs=max_epochs)
self.fit_loop.epoch_loop = _TrainingEpochLoop(self, min_steps=min_steps, max_steps=max_steps)
self.validate_loop = _EvaluationLoop(
self, TrainerFn.VALIDATING, RunningStage.VALIDATING, inference_mode=inference_mode
)
self.test_loop = _EvaluationLoop(self, TrainerFn.TESTING, RunningStage.TESTING, inference_mode=inference_mode)
self.predict_loop = _PredictionLoop(self, inference_mode=inference_mode)
self.accumulate_grad_batches = accumulate_grad_batches
# init callbacks
# Declare attributes to be set in _callback_connector on_trainer_init
self._callback_connector.on_trainer_init(
callbacks,
enable_checkpointing,
enable_progress_bar,
default_root_dir,
enable_model_summary,
max_time,
)
# init data flags
# annotation-only declaration: no value is assigned here (presumably set by
# `_data_connector.on_trainer_init` below — TODO confirm)
self.check_val_every_n_epoch: Optional[int]
self._data_connector.on_trainer_init(
val_check_interval,
reload_dataloaders_every_n_epochs,
check_val_every_n_epoch,
)
# gradient clipping
if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
# the algorithm name is validated and converted in lower case, so e.g. "Norm" is accepted
if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
gradient_clip_algorithm.lower()
):
raise MisconfigurationException(
f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
)
self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val
self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
)
if detect_anomaly:
rank_zero_info(
"You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and"
" is recommended only for model debugging."
)
self._detect_anomaly: bool = detect_anomaly
setup._log_device_info(self)
self.should_stop = False
self.state = TrainerState()
# configure profiler
setup._init_profiler(self, profiler)
# init logger flags
# annotation-only declaration; the logger list is presumably populated by the logger connector call below
self._loggers: list[Logger]
self._logger_connector.on_trainer_init(logger, log_every_n_steps)
# init debugging flags
# annotation-only declarations: no values are assigned here (presumably set by
# `setup._init_debugging_flags` below — TODO confirm)
self.val_check_batch: Union[int, float]
self.val_check_interval: Union[int, float]
self.num_sanity_val_steps: Union[int, float]
self.limit_train_batches: Union[int, float]
self.limit_val_batches: Union[int, float]
self.limit_test_batches: Union[int, float]
self.limit_predict_batches: Union[int, float]
setup._init_debugging_flags(
self,
limit_train_batches,
limit_val_batches,
limit_test_batches,
limit_predict_batches,
fast_dev_run,
overfit_batches,
val_check_interval,
num_sanity_val_steps,
)
def fit(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    datamodule: Optional[LightningDataModule] = None,
    ckpt_path: Optional[_PATH] = None,
) -> None:
    r"""Run the full optimization routine.

    Args:
        model: The model to fit.
        train_dataloaders: An iterable or collection of iterables yielding training samples, or a
            :class:`~lightning.pytorch.core.datamodule.LightningDataModule` implementing the
            :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
        val_dataloaders: An iterable or collection of iterables yielding validation samples.
        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` implementing the
            :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
        ckpt_path: Path/URL of the checkpoint to resume training from, or one of the special keywords
            ``"last"`` and ``"hpc"``. An exception is raised if no checkpoint file exists at the path.

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    Raises:
        TypeError:
            If ``model`` is not a :class:`~lightning.pytorch.core.LightningModule` (torch < 2.0.0), or is
            neither a :class:`~lightning.pytorch.core.LightningModule` nor a
            :class:`torch._dynamo.OptimizedModule` (torch >= 2.0.0).
        RuntimeError:
            If a compiled ``model`` is passed and the strategy does not support it.
    """
    # Resolve the module that will actually be trained and hand it to the strategy.
    lightning_module = _maybe_unwrap_optimized(model)
    self.strategy._lightning_module = lightning_module
    _verify_strategy_supports_compile(lightning_module, self.strategy)

    # Flag the trainer as running the fitting stage.
    self.state.fn = TrainerFn.FITTING
    self.state.status = TrainerStatus.RUNNING
    self.training = True

    # The actual work happens in `_fit_impl`, wrapped by the interrupt/exception handler.
    fit_args = (lightning_module, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    call._call_and_handle_interrupt(self, self._fit_impl, *fit_args)
def _fit_impl(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    datamodule: Optional[LightningDataModule] = None,
    ckpt_path: Optional[_PATH] = None,
) -> None:
    """Body of :meth:`fit`, executed through ``call._call_and_handle_interrupt``.

    Resolves where the data comes from, attaches it to the trainer, selects the checkpoint to
    resume from and runs the trainer. On return the trainer state must be stopped and the
    ``training`` flag is cleared.

    Raises:
        MisconfigurationException: If dataloaders are passed together with a datamodule.
    """
    log.debug(f"{self.__class__.__name__}: trainer fit stage")

    # Accept a datamodule given positionally in the `train_dataloaders` slot.
    if isinstance(train_dataloaders, LightningDataModule):
        datamodule, train_dataloaders = train_dataloaders, None

    # Explicit dataloaders and a datamodule are mutually exclusive data sources.
    loaders_given = train_dataloaders is not None or val_dataloaders is not None
    if loaders_given and datamodule is not None:
        raise MisconfigurationException(
            "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
        )

    # Connect the data sources to the trainer.
    self._data_connector.attach_data(
        model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    )

    current_fn = self.state.fn
    assert current_fn is not None
    resolved_ckpt_path = self._checkpoint_connector._select_ckpt_path(
        current_fn,
        ckpt_path,
        model_provided=True,
        model_connected=self.lightning_module is not None,
    )
    self._run(model, ckpt_path=resolved_ckpt_path)

    assert self.state.stopped
    self.training = False
def validate(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
    r"""Run a single evaluation epoch over the validation set.

    Args:
        model: The model to validate. When omitted, the module attached by a previous call is reused.
        dataloaders: An iterable or collection of iterables yielding validation samples, or a
            :class:`~lightning.pytorch.core.datamodule.LightningDataModule` implementing the
            :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
        ckpt_path: ``"best"``, ``"last"``, ``"hpc"``, or the path of the checkpoint to validate.
            With ``None`` and a model instance passed, the current weights are used; otherwise the best
            checkpoint from the previous ``trainer.fit`` call is loaded if a checkpoint callback is configured.
        verbose: Whether to print the validation results.
        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` implementing the
            :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    Returns:
        A list of dictionaries with the metrics logged during the validation phase (e.g. from model hooks
        like :meth:`~lightning.pytorch.LightningModule.validation_step` or from callback hooks), with one
        entry per validation dataloader.

    Raises:
        TypeError:
            If no ``model`` is passed and no ``LightningModule`` was passed in a previous run, or if
            ``model`` is neither a ``LightningModule`` nor a ``torch._dynamo.OptimizedModule``.
        MisconfigurationException:
            If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these.
        RuntimeError:
            If a compiled ``model`` is passed and the strategy does not support it.
    """
    if model is not None:
        # Resolve the module to validate and hand it to the strategy.
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
    elif self.lightning_module is None:
        # Nothing was passed now and nothing is left over from a previous call.
        raise TypeError(
            "`Trainer.validate()` requires a `LightningModule` when it hasn't been passed in a previous run"
        )
    _verify_strategy_supports_compile(self.lightning_module, self.strategy)

    # Flag the trainer as running the validation stage.
    self.state.fn = TrainerFn.VALIDATING
    self.state.status = TrainerStatus.RUNNING
    self.validating = True

    validate_args = (model, dataloaders, ckpt_path, verbose, datamodule)
    return call._call_and_handle_interrupt(self, self._validate_impl, *validate_args)
def _validate_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
    """Body of :meth:`validate`, executed through ``call._call_and_handle_interrupt``.

    Resolves the data sources and the model (falling back to the one already attached to the
    trainer), selects the checkpoint, runs the trainer and returns its results with tensors
    converted to scalars.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed.
    """
    log.debug(f"{self.__class__.__name__}: trainer validate stage")

    # Accept a datamodule given positionally in the `dataloaders` slot.
    if isinstance(dataloaders, LightningDataModule):
        datamodule, dataloaders = dataloaders, None

    # Explicit dataloaders and a datamodule are mutually exclusive data sources.
    if dataloaders is not None and datamodule:
        raise MisconfigurationException("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`")

    model_provided = model is not None
    if not model_provided:
        # Reuse the module attached by a previous call.
        model = self.lightning_module

    self.validate_loop.verbose = verbose

    # Connect the data sources to the trainer.
    self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

    current_fn = self.state.fn
    assert current_fn is not None
    resolved_ckpt_path = self._checkpoint_connector._select_ckpt_path(
        current_fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
    )

    # Strip tensors from the returned metrics so callers receive plain Python values.
    results = convert_tensors_to_scalars(self._run(model, ckpt_path=resolved_ckpt_path))

    assert self.state.stopped
    self.validating = False
    return results
def test(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
    r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your
    test set until you want to.

    Args:
        model: The model to test. If omitted, the ``LightningModule`` from a previous run is reused.
        dataloaders: An iterable or collection of iterables specifying test samples.
            Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
        ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
            If ``None`` and the model instance was passed, use the current weights.
            Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
            if a checkpoint callback is configured.
        verbose: If True, prints the test results.
        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    Returns:
        List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks
        like :meth:`~lightning.pytorch.LightningModule.test_step` etc.
        The length of the list corresponds to the number of test dataloaders used.

    Raises:
        TypeError:
            If no ``model`` is passed and there was no ``LightningModule`` passed in the previous run.
            If ``model`` passed is not `LightningModule` or `torch._dynamo.OptimizedModule`.
        MisconfigurationException:
            If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these.
        RuntimeError:
            If a compiled ``model`` is passed and the strategy is not supported.
    """
    if model is not None:
        # Unwrap a compiled module (if any) and attach it to the strategy.
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
    elif self.lightning_module is None:
        # No model now and no reference left over from a previous call.
        raise TypeError(
            "`Trainer.test()` requires a `LightningModule` when it hasn't been passed in a previous run"
        )
    _verify_strategy_supports_compile(self.lightning_module, self.strategy)

    # Mark the trainer as running the test stage before dispatching.
    self.state.fn = TrainerFn.TESTING
    self.state.status = TrainerStatus.RUNNING
    self.testing = True
    return call._call_and_handle_interrupt(self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
def _test_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
    """Run the test stage; this is the body that ``test()`` hands to ``call._call_and_handle_interrupt``.

    Normalizes the ``dataloaders``/``datamodule`` arguments, links the test data to the trainer, picks the
    checkpoint to load (if any) and executes the stage through ``self._run``.

    Args:
        model: Module to test. When ``None``, ``self.lightning_module`` is used instead.
        dataloaders: Test dataloaders, or a ``LightningDataModule`` passed in this positional slot.
        ckpt_path: Checkpoint selector forwarded to ``_select_ckpt_path``.
        verbose: Stored on ``self.test_loop.verbose``.
        datamodule: Datamodule providing the test dataloader.

    Returns:
        The output of ``self._run`` after ``convert_tensors_to_scalars`` has been applied to it.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are supplied.
    """
    log.debug(f"{self.__class__.__name__}: trainer test stage")

    # A datamodule that arrived in the ``dataloaders`` slot is moved to the ``datamodule`` argument.
    if isinstance(dataloaders, LightningDataModule):
        dataloaders, datamodule = None, dataloaders

    # Explicit dataloaders and a datamodule are mutually exclusive sources of test data.
    if dataloaders is not None and datamodule:
        raise MisconfigurationException("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`")

    # Remember whether the caller handed us a model; checkpoint selection depends on it.
    model_provided = model is not None
    if not model_provided:
        model = self.lightning_module

    self.test_loop.verbose = verbose

    # Link the test data to the trainer.
    self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

    assert self.state.fn is not None
    ckpt_path = self._checkpoint_connector._select_ckpt_path(
        self.state.fn,
        ckpt_path,
        model_provided=model_provided,
        model_connected=self.lightning_module is not None,
    )

    # Strip tensors out of the returned metrics so callers receive plain Python values.
    results = convert_tensors_to_scalars(self._run(model, ckpt_path=ckpt_path))

    assert self.state.stopped
    self.testing = False
    return results
def predict(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    return_predictions: Optional[bool] = None,
    ckpt_path: Optional[_PATH] = None,
) -> Optional[_PREDICT_OUTPUT]:
    r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to
    perform distributed and batched predictions. Logging is disabled in the predict hooks.

    Args:
        model: The model to predict with. If omitted, the ``LightningModule`` from a previous run is reused.
        dataloaders: An iterable or collection of iterables specifying predict samples.
            Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook.
        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook.
        return_predictions: Whether to return predictions.
            ``True`` by default except when an accelerator that spawns processes is used (not supported).
        ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
            If ``None`` and the model instance was passed, use the current weights.
            Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
            if a checkpoint callback is configured.

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    Returns:
        Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.

    Raises:
        TypeError:
            If no ``model`` is passed and there was no ``LightningModule`` passed in the previous run.
            If ``model`` passed is not `LightningModule` or `torch._dynamo.OptimizedModule`.
        MisconfigurationException:
            If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these.
        RuntimeError:
            If a compiled ``model`` is passed and the strategy is not supported.

    See :ref:`Lightning inference section<deploy/production_basic:Predict step with your LightningModule>` for more.
    """
    if model is not None:
        # Unwrap a compiled module (if any) and attach it to the strategy.
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
    elif self.lightning_module is None:
        # No model now and no reference left over from a previous call.
        raise TypeError(
            "`Trainer.predict()` requires a `LightningModule` when it hasn't been passed in a previous run"
        )
    _verify_strategy_supports_compile(self.lightning_module, self.strategy)

    # Mark the trainer as running the predict stage before dispatching.
    self.state.fn = TrainerFn.PREDICTING
    self.state.status = TrainerStatus.RUNNING
    self.predicting = True
    return call._call_and_handle_interrupt(
        self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
    )
def _predict_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    return_predictions: Optional[bool] = None,
    ckpt_path: Optional[_PATH] = None,
) -> Optional[_PREDICT_OUTPUT]:
    """Run the predict stage; this is the body that ``predict()`` hands to ``call._call_and_handle_interrupt``.

    Configures ``return_predictions`` on the predict loop, normalizes the ``dataloaders``/``datamodule``
    arguments, links the predict data to the trainer, picks the checkpoint to load (if any) and executes the
    stage through ``self._run``.

    Args:
        model: Module to predict with. When ``None``, ``self.lightning_module`` is used instead.
        dataloaders: Predict dataloaders, or a ``LightningDataModule`` passed in this positional slot.
        datamodule: Datamodule providing the predict dataloader.
        return_predictions: Assigned to ``self.predict_loop.return_predictions``.
        ckpt_path: Checkpoint selector forwarded to ``_select_ckpt_path``.

    Returns:
        Whatever ``self._run`` returns (unlike validate/test, no tensor-to-scalar conversion is applied).

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are supplied.
    """
    log.debug(f"{self.__class__.__name__}: trainer predict stage")

    # Configured first, before any argument validation below.
    self.predict_loop.return_predictions = return_predictions  # type: ignore[assignment]

    # A datamodule that arrived in the ``dataloaders`` slot is moved to the ``datamodule`` argument.
    if isinstance(dataloaders, LightningDataModule):
        dataloaders, datamodule = None, dataloaders

    # Explicit dataloaders and a datamodule are mutually exclusive sources of predict data.
    if dataloaders is not None and datamodule:
        raise MisconfigurationException("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`")

    # Remember whether the caller handed us a model; checkpoint selection depends on it.
    model_provided = model is not None
    if not model_provided:
        model = self.lightning_module

    # Link the predict data to the trainer.
    self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

    assert self.state.fn is not None
    ckpt_path = self._checkpoint_connector._select_ckpt_path(
        self.state.fn,
        ckpt_path,
        model_provided=model_provided,
        model_connected=self.lightning_module is not None,
    )
    results = self._run(model, ckpt_path=ckpt_path)

    assert self.state.stopped
    self.predicting = False
    return results
def _run(
    self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
    """Shared driver behind fit/validate/test/predict: set up, run the current stage, then tear down.

    The stage to run is read from ``self.state.fn``; the caller sets it before calling this method.

    Args:
        model: The ``LightningModule`` to connect to the strategy and run.
        ckpt_path: Checkpoint to restore modules, callbacks and training state from (``None`` for no restore,
            as resolved by the caller through ``_select_ckpt_path``).

    Returns:
        The value returned by ``self._run_stage()``.
    """
    if self.state.fn == TrainerFn.FITTING:
        # Resolve the effective epoch limits from the step/epoch configuration.
        min_epochs, max_epochs = _parse_loop_limits(
            self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
        )
        self.fit_loop.min_epochs = min_epochs
        self.fit_loop.max_epochs = max_epochs

    if self.barebones:
        # no progress bar in barebones can make it look like the Trainer hung
        rank_zero_info(
            "`Trainer(barebones=True)` started running. The progress bar is disabled so you might want to"
            " manually print the progress in your model."
        )

    # clean hparams
    if hasattr(model, "hparams"):
        parsing.clean_namespace(model.hparams)

    # attach model to the strategy
    self.strategy.connect(model)

    self._callback_connector._attach_model_callbacks()
    self._callback_connector._attach_model_logging_functions()

    _verify_loop_configurations(self)

    # ----------------------------
    # SET UP THE TRAINER
    # ----------------------------
    log.debug(f"{self.__class__.__name__}: setting up strategy environment")
    self.strategy.setup_environment()
    self.__setup_profiler()

    log.debug(f"{self.__class__.__name__}: preparing data")
    self._data_connector.prepare_data()

    # The user's `setup(stage)` hook must run BEFORE `configure_model`: Lightning's documented hook order is
    # prepare_data -> setup -> configure_model, and `configure_model` commonly depends on state created in
    # `setup` (e.g. sizes derived from the datamodule).
    call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
    log.debug(f"{self.__class__.__name__}: configuring model")
    call._call_configure_model(self)

    # check if we should delay restoring checkpoint till later
    if not self.strategy.restore_checkpoint_after_setup:
        log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
        self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

    # reset logger connector
    self._logger_connector.reset_results()
    self._logger_connector.reset_metrics()

    # strategy will configure model and move it to the device
    self.strategy.setup(self)

    # hook
    if self.state.fn == TrainerFn.FITTING:
        call._call_callback_hooks(self, "on_fit_start")
        call._call_lightning_module_hook(self, "on_fit_start")

    _log_hyperparams(self)

    # Strategies that must be set up first restore the checkpoint only now.
    if self.strategy.restore_checkpoint_after_setup:
        log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
        self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

    # restore optimizers, etc.
    log.debug(f"{self.__class__.__name__}: restoring training state")
    self._checkpoint_connector.restore_training_state()

    self._checkpoint_connector.resume_end()

    self._signal_connector.register_signal_handlers()

    # ----------------------------
    # RUN THE TRAINER
    # ----------------------------
    results = self._run_stage()

    # ----------------------------
    # POST-Training CLEAN UP
    # ----------------------------
    log.debug(f"{self.__class__.__name__}: trainer tearing down")
    self._teardown()

    # hook
    if self.state.fn == TrainerFn.FITTING:
        call._call_callback_hooks(self, "on_fit_end")
        call._call_lightning_module_hook(self, "on_fit_end")

    log.debug(f"{self.__class__.__name__}: calling teardown hooks")
    call._call_teardown_hook(self)

    self.state.status = TrainerStatus.FINISHED
    self.state.stage = None

    return results