Add option distributed_size to MegatronDistributedFusedAdam #12728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 2 commits
49 changes: 36 additions & 13 deletions nemo/core/optim/distributed_adam.py
@@ -126,22 +126,27 @@ def get_fp8_scale_and_amax(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return _get_fp8_scale_and_amax_impl(tensor)


_distribute_within_nodes_pgs = {}
_distributed_pgs = {}


def create_distribute_within_nodes_pgs():
"""Create process groups for distributing with nodes.
def create_distributed_pgs(*, distributed_size: int) -> Dict:
"""Create process groups for distributing within multiple devices.

User can reuse this function to reorder communicators for SHArP.

Arguments:
distributed_size (int): the number of devices to distribute optimizer
state over.

"""
global _distribute_within_nodes_pgs
global _distributed_pgs
assert torch.distributed.is_initialized()
if _distribute_within_nodes_pgs:
return _distribute_within_nodes_pgs
if _distributed_pgs:
return _distributed_pgs

world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = torch.cuda.device_count()
devices = distributed_size
nodes = world_size // devices

if nodes * devices != world_size:
@@ -167,7 +172,7 @@ def create_distribute_within_nodes_pgs():
# we have to expose redundant_process_group to user.
# User has to invoke allreduce through redundant_process_group
# before all other communicators to lock SHArP tree.
_distribute_within_nodes_pgs = {
_distributed_pgs = {
'world_size': world_size,
'rank': rank,
'devices': devices,
@@ -177,7 +182,16 @@ def create_distribute_within_nodes_pgs():
'distributed_process_group': distributed_pgs[node_id],
'redundant_process_group': redundant_pgs[device_id],
}
return _distribute_within_nodes_pgs
return _distributed_pgs


def create_distribute_within_nodes_pgs():
"""Create process groups for distributing within nodes.

User can reuse this function to reorder communicators for SHArP.
This function is kept for backward compatibility.
"""
return create_distributed_pgs(distributed_size=torch.cuda.device_count())


class MegatronDistributedFusedAdam(DistributedFusedAdam):
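For context, a minimal usage sketch of the new helper, assuming torch.distributed is already initialized (e.g. via torchrun) and that the world size is divisible by the chosen distributed_size. The value 8 and the wrapper function are illustrative assumptions; the returned keys 'distributed_process_group' and 'redundant_process_group' come from the dict built in the diff above.

```python
# Hypothetical usage sketch of create_distributed_pgs (assumes torch.distributed
# is initialized and world_size % 8 == 0).
import torch.distributed

from nemo.core.optim.distributed_adam import create_distributed_pgs


def build_optimizer_process_groups():
    # Shard optimizer state across groups of 8 ranks (assumed value); ranks
    # that hold the same shard fall into the redundant process group.
    pg_infos = create_distributed_pgs(distributed_size=8)
    dist_pg = pg_infos['distributed_process_group']
    redundant_pg = pg_infos['redundant_process_group']
    return dist_pg, redundant_pg
```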
@@ -197,6 +211,8 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
but requires larger memory than distributing within all
ranks, especially for pure data parallel models.
(default: False).
distributed_size (int, optional): the number of devices to
distribute optimizer state over.
lock_timeout (float, optional): timeout for callback mutex in
seconds.
**kwargs: keyword arguments to pass to Apex
@@ -209,10 +225,17 @@ def __init__(
params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
disable_distributed_parameters: bool = False,
distribute_within_nodes: bool = False,
distributed_size: Optional[int] = None,
lock_timeout: Optional[float] = None,
**kwargs,
):

# Update distributed_size settings
if distribute_within_nodes:
if distributed_size is not None and distributed_size != torch.cuda.device_count():
raise ValueError("Inconsistent distributed_size value")
distributed_size = torch.cuda.device_count()

# Initialize process groups
if 'process_group' not in kwargs and parallel_state.is_initialized():
kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True)
@@ -222,13 +245,13 @@ def __init__(
self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
kwargs['distributed_process_group'] = self_groups[rank]
kwargs['redundant_process_group'] = kwargs['process_group']
elif distribute_within_nodes:
dist_pg_infos = create_distribute_within_nodes_pgs()
elif distributed_size is not None:
dist_pg_infos = create_distributed_pgs(distributed_size=distributed_size)
if dist_pg_infos:
kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
global _distribute_within_nodes_pgs
_distribute_within_nodes_pgs = {}
global _distributed_pgs
_distributed_pgs = {}

# Make sure dtypes are in right type
for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
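To round out the diff, a hedged sketch of how the optimizer might be constructed with the new option. The model, learning rate, and distributed_size=8 are illustrative assumptions; the legacy distribute_within_nodes path shown for comparison maps to distributed_size=torch.cuda.device_count(), as handled in the __init__ change above.

```python
# Illustrative only: constructing MegatronDistributedFusedAdam with the new
# distributed_size option (model, lr, and the value 8 are assumed).
import torch

from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

model = torch.nn.Linear(1024, 1024).cuda()

# New path: shard optimizer state over groups of 8 ranks.
optimizer = MegatronDistributedFusedAdam(
    model.parameters(),
    lr=1e-4,
    distributed_size=8,
)

# Legacy path, kept for backward compatibility: equivalent to
# distributed_size=torch.cuda.device_count(); passing both with
# inconsistent values raises ValueError.
legacy_optimizer = MegatronDistributedFusedAdam(
    model.parameters(),
    lr=1e-4,
    distribute_within_nodes=True,
)
```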