Skip to content

Commit 26b5e80

Browse files
committed
Change signature and make the equivalent changes to Fabric connector
1 parent a771423 commit 26b5e80

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/lightning/fabric/connector.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import os
1515
from collections import Counter
16+
from collections.abc import Iterable
1617
from typing import Any, Optional, Union, cast
1718

1819
import torch
@@ -102,7 +103,7 @@ def __init__(
102103
devices: Union[list[int], str, int] = "auto",
103104
num_nodes: int = 1,
104105
precision: Optional[_PRECISION_INPUT] = None,
105-
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None,
106+
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None,
106107
) -> None:
107108
# These arguments can be set through environment variables set by the CLI
108109
accelerator = self._argument_from_env("accelerator", accelerator, default="auto")
@@ -165,7 +166,7 @@ def _check_config_and_set_final_flags(
165166
strategy: Union[str, Strategy],
166167
accelerator: Union[str, Accelerator],
167168
precision: Optional[_PRECISION_INPUT],
168-
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]],
169+
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]],
169170
) -> None:
170171
"""This method checks:
171172
@@ -180,7 +181,7 @@ def _check_config_and_set_final_flags(
180181
181182
"""
182183
if plugins is not None:
183-
plugins = [plugins] if not isinstance(plugins, list) else plugins
184+
plugins = [plugins] if not isinstance(plugins, Iterable) else plugins
184185

185186
if isinstance(strategy, str):
186187
strategy = strategy.lower()

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(
7979
num_nodes: int = 1,
8080
accelerator: Union[str, Accelerator] = "auto",
8181
strategy: Union[str, Strategy] = "auto",
82-
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None,
82+
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None,
8383
precision: Optional[_PRECISION_INPUT] = None,
8484
sync_batchnorm: bool = False,
8585
benchmark: Optional[bool] = None,
@@ -167,7 +167,7 @@ def _check_config_and_set_final_flags(
167167
strategy: Union[str, Strategy],
168168
accelerator: Union[str, Accelerator],
169169
precision: Optional[_PRECISION_INPUT],
170-
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]],
170+
plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]],
171171
sync_batchnorm: bool,
172172
) -> None:
173173
"""This method checks:

0 commit comments

Comments
 (0)