@@ -126,22 +126,27 @@ def get_fp8_scale_and_amax(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return _get_fp8_scale_and_amax_impl(tensor)


-_distribute_within_nodes_pgs = {}
+_distributed_pgs = {}


-def create_distribute_within_nodes_pgs():
-    """Create process groups for distributing with nodes.
+def create_distributed_pgs(*, distributed_size: int) -> Dict:
+    """Create process groups for distributing within multiple devices.

     User can reuse this function to reorder communicators for SHArP.
+
+    Arguments:
+        distributed_size (int): the number of devices to distribute optimizer
+            state over.
+
     """
-    global _distribute_within_nodes_pgs
+    global _distributed_pgs
     assert torch.distributed.is_initialized()
-    if _distribute_within_nodes_pgs:
-        return _distribute_within_nodes_pgs
+    if _distributed_pgs:
+        return _distributed_pgs

     world_size = torch.distributed.get_world_size()
     rank = torch.distributed.get_rank()
-    devices = torch.cuda.device_count()
+    devices = distributed_size
     nodes = world_size // devices

     if nodes * devices != world_size:
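A short worked example of the decomposition this hunk sets up (a sketch only; the rank-to-group mapping itself is outside the hunk, and the node_id/device_id names below are simply the dictionary keys used later in the patch):

# Sketch: with 16 ranks and distributed_size=8, optimizer state is sharded
# over 8-rank groups, giving 2 such groups ("nodes" in the code).
world_size = 16
devices = 8                      # distributed_size
nodes = world_size // devices    # 2; 2 * 8 == 16, so the consistency check passes

# Assuming consecutive ranks form one shard group, a given rank presumably maps to:
rank = 11
node_id = rank // devices        # 1 -> selects distributed_process_group
device_id = rank % devices       # 3 -> selects redundant_process_group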
@@ -167,7 +172,7 @@ def create_distribute_within_nodes_pgs():
     # we have to expose redundant_process_group to user.
     # User has too invoke allreduce through redundant_process_group
     # before all other communicators to lock SHArP tree.
-    _distribute_within_nodes_pgs = {
+    _distributed_pgs = {
         'world_size': world_size,
         'rank': rank,
         'devices': devices,
@@ -177,7 +182,16 @@ def create_distribute_within_nodes_pgs():
177
182
'distributed_process_group' : distributed_pgs [node_id ],
178
183
'redundant_process_group' : redundant_pgs [device_id ],
179
184
}
180
- return _distribute_within_nodes_pgs
185
+ return _distributed_pgs
186
+
187
+
188
+ def create_distribute_within_nodes_pgs ():
189
+ """Create process groups for distributing within nodes.
190
+
191
+ User can reuse this function to reorder communicators for SHArP.
192
+ This funcion is kept for backward compatibility.
193
+ """
194
+ return create_distributed_pgs (distributed_size = torch .cuda .device_count ())
181
195
182
196
183
197
class MegatronDistributedFusedAdam (DistributedFusedAdam ):
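A minimal usage sketch of the new helper (assumptions: torch.distributed is already initialized, the world size is divisible by the chosen distributed_size, and a CUDA device is selected). It follows the SHArP note in the comments above: issue the first allreduce through redundant_process_group so that communicator claims the SHArP tree before the others are used:

import torch

pg_infos = create_distributed_pgs(distributed_size=8)

# Lock the SHArP tree on the redundant communicator first, then use the
# distributed (shard-group) communicator as usual.
buf = torch.zeros(1, device='cuda')
torch.distributed.all_reduce(buf, group=pg_infos['redundant_process_group'])
torch.distributed.all_reduce(buf, group=pg_infos['distributed_process_group'])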
@@ -209,10 +223,17 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        distributed_size: Optional[int] = None,
         lock_timeout: Optional[float] = None,
         **kwargs,
     ):

+        # Update distributed_size settings
+        if distribute_within_nodes:
+            if distributed_size is not None and distributed_size != torch.cuda.device_count():
+                raise ValueError("Inconsistent distributed_size value")
+            distributed_size = torch.cuda.device_count()
+
         # Initialize process groups
         if 'process_group' not in kwargs and parallel_state.is_initialized():
             kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True)
@@ -222,13 +243,13 @@ def __init__(
             self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
             kwargs['distributed_process_group'] = self_groups[rank]
             kwargs['redundant_process_group'] = kwargs['process_group']
-        elif distribute_within_nodes:
-            dist_pg_infos = create_distribute_within_nodes_pgs()
+        elif distributed_size is not None:
+            dist_pg_infos = create_distributed_pgs(distributed_size=distributed_size)
             if dist_pg_infos:
                 kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
                 kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
-                global _distribute_within_nodes_pgs
-                _distribute_within_nodes_pgs = {}
+                global _distributed_pgs
+                _distributed_pgs = {}

         # Make sure dtypes are in right type
         for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
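A minimal construction sketch showing the new argument next to the legacy flag (assumptions: torch.distributed and Megatron parallel_state are already initialized, and the import path below matches where this class actually lives):

import torch
from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam  # assumed module path

model = torch.nn.Linear(1024, 1024).cuda()

# New style: shard optimizer state across groups of 16 ranks.
optimizer = MegatronDistributedFusedAdam(model.parameters(), lr=1e-4, distributed_size=16)

# Legacy style, still supported via the backward-compatible wrapper:
# optimizer = MegatronDistributedFusedAdam(model.parameters(), lr=1e-4, distribute_within_nodes=True)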