Commit 3c5a465

Create barrier without timeout in prepare_data() (#19448)
1 parent f2f9978 commit 3c5a465

9 files changed: +110 −30

src/lightning/fabric/CHANGELOG.md (+3 −2)

@@ -17,9 +17,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Rename `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442))
+- Renamed `lightning run model` to `fabric run model` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442))
 
--
+
+- The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid interrupting long-running tasks ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
 -
 
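For context on why the timeout matters: collectives on a regular process group, including `Fabric.barrier()`, are bounded by the timeout passed at initialization (commonly 30 minutes by default), so ranks waiting on a slow rank 0 could error out before it finished. A minimal single-process sketch of where that timeout is configured; the values below are illustrative and are not something this commit changes:

import os
from datetime import timedelta

import torch.distributed

# Single-process group purely for illustration; in real jobs Fabric/Trainer set this up.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
torch.distributed.init_process_group("gloo", rank=0, world_size=1, timeout=timedelta(minutes=30))

# Collectives on this default group, including barriers, are expected to complete
# within the timeout above; a rank still waiting when it elapses errors out.
torch.distributed.barrier()
torch.distributed.destroy_process_group()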

src/lightning/fabric/fabric.py (+7 −7)

@@ -65,7 +65,7 @@
     has_iterable_dataset,
 )
 from lightning.fabric.utilities.device_dtype_mixin import _update_properties
-from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 from lightning.fabric.utilities.registry import _load_external_callbacks
@@ -636,12 +636,12 @@ def rank_zero_first(self, local: bool = False) -> Generator:
 
         """
         rank = self.local_rank if local else self.global_rank
-        if rank > 0:
-            self.barrier()
-        yield
-        if rank == 0:
-            self.barrier()
-        self.barrier()
+        with _InfiniteBarrier() as barrier:
+            if rank > 0:
+                barrier()
+            yield
+            if rank == 0:
+                barrier()
 
     def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
         r"""Skip gradient synchronization during backward to avoid redundant communication overhead.
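For reference, a hedged usage sketch of the changed context manager: rank 0 executes the body first while the other ranks wait in the new GLOO-backed barrier, so a slow download no longer races against the regular collective timeout. The file path and the toy "download" are illustrative only, not part of this commit:

from pathlib import Path

from lightning.fabric import Fabric


def run(fabric: Fabric) -> None:
    # Rank 0 enters the body first; the other ranks block on the infinite
    # barrier until rank 0 exits the context, then execute the body themselves.
    with fabric.rank_zero_first():
        path = Path("data/dataset.txt")
        if not path.exists():  # stands in for a long download performed by rank 0
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text("example payload")
    fabric.print(f"dataset ready (seen by global rank {fabric.global_rank})")


if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch(run)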

src/lightning/fabric/utilities/distributed.py (+31 −1)

@@ -3,6 +3,7 @@
 import os
 import time
 from contextlib import nullcontext
+from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union
 
@@ -11,7 +12,7 @@
 from lightning_utilities.core.imports import package_available
 from torch import Tensor
 from torch.utils.data import Dataset, DistributedSampler, Sampler
-from typing_extensions import override
+from typing_extensions import Self, override
 
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.data import _num_cpus_available
@@ -383,3 +384,32 @@ def _distributed_is_initialized() -> bool:
     # https://github.com/pytorch/pytorch/blob/v2.1.0/torch/distributed/__init__.py#L25
     # this might happen to MacOS builds from source (default) or any build from source that sets `USE_DISTRIBUTED=0`
     return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+
+class _InfiniteBarrier:
+    """A barrier with an infinite timeout.
+
+    Creates a new process group with the GLOO backend and a very high timeout that makes the barrier effectively
+    wait forever. This is useful in cases where you want to execute a long-running operation on a subset of ranks
+    that should not be subject to the regular collective timeout.
+
+    """
+
+    def __init__(self) -> None:
+        self.group = None
+        self.barrier = lambda: None
+
+    def __call__(self) -> None:
+        self.barrier()
+
+    def __enter__(self) -> Self:
+        if _distributed_is_initialized():
+            # Create a barrier with an 'infinite' timeout (only reliably possible over the GLOO backend)
+            self.group = torch.distributed.new_group(backend="gloo", timeout=timedelta(days=10000))
+            self.barrier = self.group.monitored_barrier
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.barrier()
+        if self.group is not None:
+            torch.distributed.destroy_process_group(self.group)
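A small sketch of how this private helper is intended to be driven, mirroring the `prepare_data()` usage further below: all ranks enter the context, one rank does the slow work, and the barrier issued in `__exit__` (with its roughly 27-year timeout) keeps the rest waiting however long that takes; when torch.distributed is not initialized, enter, call, and exit are no-ops. The rank lookup and the `sleep` standing in for a long task are illustrative:

import time

import torch.distributed
from lightning.fabric.utilities.distributed import _distributed_is_initialized, _InfiniteBarrier


def _global_rank() -> int:
    # Stand-in rank lookup; 0 when no default process group is initialized.
    return torch.distributed.get_rank() if _distributed_is_initialized() else 0


# All ranks enter; only rank 0 performs the work. The barrier issued by __exit__
# synchronizes everyone afterwards without the usual collective timeout.
with _InfiniteBarrier():
    if _global_rank() == 0:
        time.sleep(2)  # stands in for work longer than the regular timeout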

src/lightning/pytorch/CHANGELOG.md (+1 −1)

@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid interrupting long-running tasks ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
 -
 
src/lightning/pytorch/trainer/connectors/data_connector.py (+12 −11)

@@ -27,7 +27,7 @@
     has_iterable_dataset,
     suggested_max_num_workers,
 )
-from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
@@ -86,17 +86,18 @@ def prepare_data(self) -> None:
         datamodule = trainer.datamodule
         lightning_module = trainer.lightning_module
         # handle datamodule prepare data:
-        # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
-        if datamodule is not None:
-            dm_prepare_data_per_node = datamodule.prepare_data_per_node
-            if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_datamodule_hook(trainer, "prepare_data")
+        if datamodule is not None and is_overridden("prepare_data", datamodule):
+            prepare_data_per_node = datamodule.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_datamodule_hook(trainer, "prepare_data")
+
         # handle lightning module prepare data:
-        # check for prepare_data_per_node before calling lightning_module.prepare_data
-        if lightning_module is not None:
-            lm_prepare_data_per_node = lightning_module.prepare_data_per_node
-            if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
-                call._call_lightning_module_hook(trainer, "prepare_data")
+        if lightning_module is not None and is_overridden("prepare_data", lightning_module):
+            prepare_data_per_node = lightning_module.prepare_data_per_node
+            with _InfiniteBarrier():
+                if (prepare_data_per_node and local_rank_zero) or (not prepare_data_per_node and global_rank_zero):
+                    call._call_lightning_module_hook(trainer, "prepare_data")
 
     def attach_data(
         self,
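To illustrate the connector change from the user side: a `prepare_data()` like the sketch below (the sleep stands in for a real download) now runs on one rank while the other ranks wait inside `_InfiniteBarrier` instead of hitting the default distributed timeout; with the added `is_overridden` check, the barrier is only set up when the hook is actually overridden. This is a hypothetical datamodule, not code from this commit:

import time

import torch
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset


class SlowPrepareDataModule(LightningDataModule):
    def prepare_data(self) -> None:
        # Called on local/global rank 0 only (depending on prepare_data_per_node);
        # the remaining ranks now wait in a barrier without timeout.
        time.sleep(5)  # stands in for a long download or preprocessing step

    def setup(self, stage: str) -> None:
        self.dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset, batch_size=8)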

tests/tests_fabric/test_fabric.py (+6 −4)

@@ -16,6 +16,7 @@
 from unittest import mock
 from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
 
+import lightning.fabric
 import pytest
 import torch
 import torch.distributed
@@ -1129,24 +1130,25 @@ def test_all_reduce():
     fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
 
 
-def test_rank_zero_first():
+def test_rank_zero_first(monkeypatch):
     """Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""
 
     def record_calls_for_rank(rank):
         call_order = []
 
         fabric = Fabric()
         fabric._strategy = Mock(global_rank=rank)
-        fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
+        barrier_mock = MagicMock(side_effect=lambda *_: call_order.append("barrier"))
+        monkeypatch.setattr(lightning.fabric.utilities.distributed._InfiniteBarrier, "__call__", barrier_mock)
         target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))
 
         with fabric.rank_zero_first():
             target.run()
 
         return call_order
 
-    assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
-    assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
+    assert record_calls_for_rank(0) == ["run", "barrier"]
+    assert record_calls_for_rank(1) == ["barrier", "run"]
 
 
 @pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])

tests/tests_fabric/utilities/test_distributed.py (+28)

@@ -13,6 +13,7 @@
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from lightning.fabric.utilities.distributed import (
     _gather_all_tensors,
+    _InfiniteBarrier,
     _set_num_threads_if_needed,
     _suggested_max_num_threads,
     _sync_ddp,
@@ -196,3 +197,30 @@ def test_set_num_threads_if_needed(_, set_num_threads_mock, num_processes, expected):
     _set_num_threads_if_needed(1)
     set_num_threads_mock.assert_not_called()
     assert os.environ["OMP_NUM_THREADS"] == str(expected)
+
+
+def test_infinite_barrier():
+    # distributed not available
+    barrier = _InfiniteBarrier()
+    assert barrier.group is None
+    with mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=False):
+        barrier.__enter__()
+        assert barrier.group is None
+        barrier()
+        barrier.__exit__(None, None, None)
+        assert barrier.group is None
+
+    # distributed available
+    barrier = _InfiniteBarrier()
+    with mock.patch(
+        "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True
+    ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock:
+        barrier.__enter__()
+        dist_mock.new_group.assert_called_once()
+        assert barrier.barrier == barrier.group.monitored_barrier
+        assert barrier.barrier.call_count == 0
+        barrier()
+        assert barrier.barrier.call_count == 1
+        barrier.__exit__(None, None, None)
+        assert barrier.barrier.call_count == 2
+        dist_mock.destroy_process_group.assert_called_once()

tests/tests_pytorch/core/test_datamodules.py (+14 −4)

@@ -40,7 +40,13 @@
 @mock.patch("lightning.pytorch.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
 @mock.patch("lightning.pytorch.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
 def test_can_prepare_data(local_rank, node_rank):
-    dm = Mock(spec=LightningDataModule)
+    class MyDataModule(LightningDataModule):
+        def prepare_data(self):
+            pass
+
+    dm = MyDataModule()
+    dm.prepare_data = Mock(wraps=dm.prepare_data)
+
     dm.prepare_data_per_node = True
     trainer = Trainer()
     trainer.datamodule = dm
@@ -56,7 +62,7 @@ def test_can_prepare_data(local_rank, node_rank):
     dm.prepare_data.assert_called_once()
 
     # local rank = 1 (False)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     local_rank.return_value = 1
     assert trainer.local_rank == 1
 
@@ -65,7 +71,7 @@ def test_can_prepare_data(local_rank, node_rank):
 
     # prepare_data_per_node = False (prepare across all nodes)
     # global rank = 0 (True)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     dm.prepare_data_per_node = False
     node_rank.return_value = 0
     local_rank.return_value = 0
@@ -74,7 +80,7 @@ def test_can_prepare_data(local_rank, node_rank):
     dm.prepare_data.assert_called_once()
 
     # global rank = 1 (False)
-    dm.reset_mock()
+    dm.prepare_data.reset_mock()
     node_rank.return_value = 1
     local_rank.return_value = 0
 
@@ -465,6 +471,10 @@ class CustomBoringDataModule(BoringDataModule):
         def state_dict(self):
             return {"temp": 1}
 
+        # override so that it gets called
+        def prepare_data(self):
+            pass
+
     model = BoringModel()
     dm = CustomBoringDataModule()
     trainer = get_trainer()

tests/tests_pytorch/models/test_hooks.py (+8)

@@ -48,6 +48,10 @@ def call(hook, fn, *args, **kwargs):
             update_wrapper(partial_h, attr)
             setattr(self, h, partial_h)
 
+    # override so that it gets called
+    def prepare_data(self):
+        ...
+
 
 @pytest.mark.parametrize("max_steps", [1, 2, 3])
 def test_on_before_zero_grad_called(tmpdir, max_steps):
@@ -407,6 +411,10 @@ def on_test_model_train(self):
     def on_predict_model_train(self):
         ...
 
+    # override so that it gets called
+    def prepare_data(self):
+        ...
+
 
 @pytest.mark.parametrize(
     "kwargs",
