     has_iterable_dataset,
     suggested_max_num_workers,
 )

-from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -86,17 +86,18 @@ def prepare_data(self) -> None:
         datamodule = trainer.datamodule
         lightning_module = trainer.lightning_module
         # handle datamodule prepare data:
-        # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
-        if datamodule is not None:
-            dm_prepare_data_per_node = datamodule.prepare_data_per_node
-            if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_datamodule_hook(trainer, "prepare_data")
+        if datamodule is not None and is_overridden("prepare_data", datamodule):
+            prepare_data_per_node = datamodule.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_datamodule_hook(trainer, "prepare_data")
+
         # handle lightning module prepare data:
-        # check for prepare_data_per_node before calling lightning_module.prepare_data
-        if lightning_module is not None:
-            lm_prepare_data_per_node = lightning_module.prepare_data_per_node
-            if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_module_hook(trainer, "prepare_data")
+        if lightning_module is not None and is_overridden("prepare_data", lightning_module):
+            prepare_data_per_node = lightning_module.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_module_hook(trainer, "prepare_data")

     def attach_data(
         self,