
Commit 14bef51

Add option distributed_size to MegatronDistributedFusedAdam
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Parent: 8eab68c

1 file changed: +34 −13 lines

nemo/core/optim/distributed_adam.py

@@ -126,22 +126,27 @@ def get_fp8_scale_and_amax(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return _get_fp8_scale_and_amax_impl(tensor)
 
 
-_distribute_within_nodes_pgs = {}
+_distributed_pgs = {}
 
 
-def create_distribute_within_nodes_pgs():
-    """Create process groups for distributing with nodes.
+def create_distributed_pgs(*, distributed_size: int) -> Dict:
+    """Create process groups for distributing within multiple devices.
 
     User can reuse this function to reorder communicators for SHArP.
+
+    Arguments:
+        distributed_size (int): the number of devices to distribute optimizer
+            state over.
+
     """
-    global _distribute_within_nodes_pgs
+    global _distributed_pgs
     assert torch.distributed.is_initialized()
-    if _distribute_within_nodes_pgs:
-        return _distribute_within_nodes_pgs
+    if _distributed_pgs:
+        return _distributed_pgs
 
     world_size = torch.distributed.get_world_size()
     rank = torch.distributed.get_rank()
-    devices = torch.cuda.device_count()
+    devices = distributed_size
     nodes = world_size // devices
 
     if nodes * devices != world_size:
@@ -167,7 +172,7 @@ def create_distribute_within_nodes_pgs():
     # we have to expose redundant_process_group to user.
     # User has too invoke allreduce through redundant_process_group
     # before all other communicators to lock SHArP tree.
-    _distribute_within_nodes_pgs = {
+    _distributed_pgs = {
         'world_size': world_size,
         'rank': rank,
         'devices': devices,
@@ -177,7 +182,16 @@ def create_distribute_within_nodes_pgs():
         'distributed_process_group': distributed_pgs[node_id],
         'redundant_process_group': redundant_pgs[device_id],
     }
-    return _distribute_within_nodes_pgs
+    return _distributed_pgs
+
+
+def create_distribute_within_nodes_pgs():
+    """Create process groups for distributing within nodes.
+
+    User can reuse this function to reorder communicators for SHArP.
+    This funcion is kept for backward compatibility.
+    """
+    return create_distributed_pgs(distributed_size=torch.cuda.device_count())
 
 
 class MegatronDistributedFusedAdam(DistributedFusedAdam):
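
The hunks above rename the module-level cache to _distributed_pgs, generalize the helper into create_distributed_pgs(*, distributed_size), and keep create_distribute_within_nodes_pgs as a thin backward-compatibility wrapper. A minimal sketch of calling the new helper directly, for instance to create the communicators up front and reorder them for SHArP as the in-code comments suggest (not part of the commit; it assumes torch.distributed has already been initialized by the training script, and that distributed_size evenly divides the world size):

    import torch
    from nemo.core.optim.distributed_adam import create_distributed_pgs

    # Create the optimizer-state sharding groups up front. Using the local GPU
    # count here reproduces the old within-node behaviour.
    pg_infos = create_distributed_pgs(distributed_size=torch.cuda.device_count())
    if pg_infos:
        # Per the comments in the diff, an allreduce over redundant_process_group
        # should run before the other communicators to lock the SHArP tree.
        t = torch.zeros(1, device='cuda')
        torch.distributed.all_reduce(t, group=pg_infos['redundant_process_group'])
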
@@ -209,10 +223,17 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        distributed_size: Optional[int] = None,
         lock_timeout: Optional[float] = None,
         **kwargs,
     ):
 
+        # Update distributed_size settings
+        if distribute_within_nodes:
+            if distributed_size is not None and distributed_size != torch.cuda.device_count():
+                raise ValueError("Inconsistent distributed_size value")
+            distributed_size = torch.cuda.device_count()
+
         # Initialize process groups
         if 'process_group' not in kwargs and parallel_state.is_initialized():
             kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True)
@@ -222,13 +243,13 @@ def __init__(
             self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
             kwargs['distributed_process_group'] = self_groups[rank]
             kwargs['redundant_process_group'] = kwargs['process_group']
-        elif distribute_within_nodes:
-            dist_pg_infos = create_distribute_within_nodes_pgs()
+        elif distributed_size is not None:
+            dist_pg_infos = create_distributed_pgs(distributed_size=distributed_size)
             if dist_pg_infos:
                 kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
                 kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
-                global _distribute_within_nodes_pgs
-                _distribute_within_nodes_pgs = {}
+                global _distributed_pgs
+                _distributed_pgs = {}
 
         # Make sure dtypes are in right type
         for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
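
The remaining hunks expose the new distributed_size argument on MegatronDistributedFusedAdam and route it into create_distributed_pgs; distribute_within_nodes=True is mapped onto distributed_size=torch.cuda.device_count(), and the two must agree if both are given. A hedged usage sketch (not part of the commit; model, lr, and the group size of 16 are placeholders, and distributed_size is expected to evenly divide the world size):

    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

    # `model` is assumed to be an already-constructed torch.nn.Module.

    # Previous behaviour: shard optimizer state across the GPUs of one node.
    optimizer = MegatronDistributedFusedAdam(
        model.parameters(), lr=1e-4, distribute_within_nodes=True,
    )

    # New option from this commit: shard optimizer state across an explicit
    # number of ranks, e.g. two 8-GPU nodes.
    optimizer = MegatronDistributedFusedAdam(
        model.parameters(), lr=1e-4, distributed_size=16,
    )
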
