13
13
# limitations under the License.
14
14
import os
15
15
from collections import Counter
16
+ from collections .abc import Iterable
16
17
from typing import Any , Optional , Union , cast
17
18
18
19
import torch
@@ -102,7 +103,7 @@ def __init__(
102
103
devices : Union [list [int ], str , int ] = "auto" ,
103
104
num_nodes : int = 1 ,
104
105
precision : Optional [_PRECISION_INPUT ] = None ,
105
- plugins : Optional [Union [_PLUGIN_INPUT , list [_PLUGIN_INPUT ]]] = None ,
106
+ plugins : Optional [Union [_PLUGIN_INPUT , Iterable [_PLUGIN_INPUT ]]] = None ,
106
107
) -> None :
107
108
# These arguments can be set through environment variables set by the CLI
108
109
accelerator = self ._argument_from_env ("accelerator" , accelerator , default = "auto" )
@@ -165,7 +166,7 @@ def _check_config_and_set_final_flags(
165
166
strategy : Union [str , Strategy ],
166
167
accelerator : Union [str , Accelerator ],
167
168
precision : Optional [_PRECISION_INPUT ],
168
- plugins : Optional [Union [_PLUGIN_INPUT , list [_PLUGIN_INPUT ]]],
169
+ plugins : Optional [Union [_PLUGIN_INPUT , Iterable [_PLUGIN_INPUT ]]],
169
170
) -> None :
170
171
"""This method checks:
171
172
@@ -180,7 +181,7 @@ def _check_config_and_set_final_flags(
180
181
181
182
"""
182
183
if plugins is not None :
183
- plugins = [plugins ] if not isinstance (plugins , list ) else plugins
184
+ plugins = [plugins ] if not isinstance (plugins , Iterable ) else plugins
184
185
185
186
if isinstance (strategy , str ):
186
187
strategy = strategy .lower ()
0 commit comments