 18 |  18 | import platform
 19 |  19 | from collections.abc import Mapping
 20 |  20 | from contextlib import AbstractContextManager, ExitStack
    |  21 | +from datetime import timedelta
 21 |  22 | from itertools import chain
 22 |  23 | from pathlib import Path
 23 |  24 | from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 29 |  30 | from typing_extensions import override
 30 |  31 |
 31 |  32 | from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
    |  33 | +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 32 |  34 | from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 33 |  35 | from lightning.fabric.plugins.precision import Precision
 34 |  36 | from lightning.fabric.strategies.ddp import DDPStrategy

@@ -97,6 +99,7 @@ def __init__(
 97 |  99 |         load_full_weights: bool = False,
 98 | 100 |         precision: Optional[Precision] = None,
 99 | 101 |         process_group_backend: Optional[str] = None,
    | 102 | +        timeout: Optional[timedelta] = default_pg_timeout,
100 | 103 |     ) -> None:
101 | 104 |         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
102 | 105 |         billion parameter models. `For more information: https://pytorch-

@@ -241,6 +244,7 @@ def __init__(
241 | 244 |             process_group_backend=process_group_backend,
242 | 245 |         )
243 | 246 |         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
    | 247 | +        self._timeout: Optional[timedelta] = timeout
244 | 248 |
245 | 249 |         self.config = self._load_config(config)
246 | 250 |         if self.config is None:

@@ -648,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
648 | 652 |             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
649 | 653 |         )
650 | 654 |         self._process_group_backend = self._get_process_group_backend()
651 |     | -        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
    | 655 | +        deepspeed.init_distributed(
    | 656 | +            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
    | 657 | +        )
652 | 658 |
653 | 659 |     def _set_node_environment_variables(self) -> None:
654 | 660 |         assert self.cluster_environment is not None
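For reference, a minimal usage sketch of the new `timeout` argument, assuming a build that includes this change; the 90-minute value and the two-GPU setup are illustrative, and omitting `timeout` keeps the previous behavior via `default_pg_timeout`:

```python
from datetime import timedelta

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# Raise the collective timeout so that long checkpoint loads or slow first
# steps on some ranks do not trip the default process-group watchdog.
strategy = DeepSpeedStrategy(timeout=timedelta(minutes=90))

fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()  # the timeout is forwarded to deepspeed.init_distributed() when the process group is set up
```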