
Commit 9983f3a

Authored by jedyang97, pre-commit-ci[bot], and lantiga
Add timeout to DeepSpeedStrategy (#20474)
* allow user to pass kwargs to DeepSpeedStrategy
* Update deepspeed.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update deepspeed.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* make timeout explicit in DeepSpeedStrategy
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
1 parent 1c4612e commit 9983f3a
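
In practice, the change adds a single `timeout` keyword to both strategy classes. It defaults to `default_pg_timeout`, which Lightning re-exports from torch.distributed (normally 30 minutes). A quick way to check the effective default, assuming a recent lightning install:

    from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout

    # Re-export of torch.distributed.constants.default_pg_timeout;
    # normally timedelta(minutes=30).
    print(default_pg_timeout)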

2 files changed: +14 −2 lines


src/lightning/fabric/strategies/deepspeed.py (+7 −1)

@@ -18,6 +18,7 @@
 import platform
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
+from datetime import timedelta
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -29,6 +30,7 @@
 from typing_extensions import override
 
 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
@@ -97,6 +99,7 @@ def __init__(
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -241,6 +244,7 @@ def __init__(
             process_group_backend=process_group_backend,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
+        self._timeout: Optional[timedelta] = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -648,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
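
With this change, a Fabric user can raise the collective timeout when, for example, rank 0 spends a long time preparing data before the first collective call. A minimal usage sketch; the 90-minute value and other arguments are illustrative, not part of this commit:

    from datetime import timedelta

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import DeepSpeedStrategy

    # Raise the process-group timeout from the default (default_pg_timeout)
    # to 90 minutes to tolerate slow startup phases.
    strategy = DeepSpeedStrategy(stage=3, timeout=timedelta(minutes=90))
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()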

src/lightning/pytorch/strategies/deepspeed.py (+7 −1)

@@ -19,6 +19,7 @@
 from collections import OrderedDict
 from collections.abc import Generator, Mapping
 from contextlib import contextmanager
+from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union
 
@@ -30,6 +31,7 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
@@ -119,6 +121,7 @@ def __init__(
         load_full_weights: bool = False,
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -264,6 +267,7 @@ def __init__(
             precision_plugin=precision_plugin,
             process_group_backend=process_group_backend,
         )
+        self._timeout: Optional[timedelta] = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -364,7 +368,9 @@ def _init_deepspeed_distributed(self) -> None:
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
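
On the Trainer side the new argument flows through in the same way. A minimal sketch; the ZeRO stage and the 2-hour timeout are illustrative values, not part of this commit:

    from datetime import timedelta

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DeepSpeedStrategy

    # A generous timeout helps when large sharded checkpoints are loaded
    # or when ranks reach the first collective at very different times.
    trainer = Trainer(
        accelerator="cuda",
        devices=4,
        strategy=DeepSpeedStrategy(stage=2, timeout=timedelta(hours=2)),
    )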
