Add unsharded module reference to sharded modules #2901

4 changes: 4 additions & 0 deletions torchrec/distributed/embedding.py
@@ -1409,6 +1409,10 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
def fused_optimizer(self) -> KeyedOptimizer:
return self._optim

@property
def unsharded_module_type(self) -> Type[EmbeddingCollection]:
return EmbeddingCollection

def create_context(self) -> EmbeddingCollectionContext:
return EmbeddingCollectionContext(sharding_contexts=[])

8 changes: 8 additions & 0 deletions torchrec/distributed/embedding_tower_sharding.py
@@ -438,6 +438,10 @@ def named_modules(
def create_context(self) -> NullShardedModuleContext:
return NullShardedModuleContext()

@property
def unsharded_module_type(self) -> Type[EmbeddingTower]:
return EmbeddingTower


class ShardedEmbeddingTowerCollection(
ShardedEmbeddingModule[
@@ -941,6 +945,10 @@ def embedding_feature_names(
kjt_features.extend(config.feature_names)
return kjt_features, wkjt_features

@property
def unsharded_module_type(self) -> Type[EmbeddingTowerCollection]:
return EmbeddingTowerCollection


class EmbeddingTowerCollectionSharder(BaseEmbeddingSharder[EmbeddingTowerCollection]):
def __init__(
23 changes: 22 additions & 1 deletion torchrec/distributed/embedding_types.py
@@ -11,7 +11,18 @@
import copy
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
@@ -399,6 +410,16 @@ def train(self, mode: bool = True): # pyre-ignore[3]

return self

@property
def unsharded_module_type(self) -> Type[nn.Module]:
"""
As this is the generic ShardedEmbeddingModule class, return the
generic nn.Module type. Subclasses of ShardedEmbeddingModule
override this property to return the specific unsharded module type.
"""
return nn.Module


M = TypeVar("M", bound=nn.Module)

8 changes: 8 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -1598,6 +1598,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
def extend_shard_name(shard_name: str) -> str:
return f"embedding_bags.{shard_name}.weight"

@property
def unsharded_module_type(self) -> Type[EmbeddingBagCollection]:
return EmbeddingBagCollection


class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
"""
@@ -1887,6 +1891,10 @@ def fused_optimizer(self) -> KeyedOptimizer:
def create_context(self) -> NullShardedModuleContext:
return NullShardedModuleContext()

@property
def unsharded_module_type(self) -> Type[nn.EmbeddingBag]:
return nn.EmbeddingBag


class EmbeddingBagSharder(BaseEmbeddingSharder[nn.EmbeddingBag]):
"""
4 changes: 4 additions & 0 deletions torchrec/distributed/fp_embeddingbag.py
@@ -161,6 +161,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
if "_embedding_bag_collection" in fqn:
yield append_prefix(prefix, fqn)

@property
def unsharded_module_type(self) -> Type[FeatureProcessedEmbeddingBagCollection]:
return FeatureProcessedEmbeddingBagCollection


class FeatureProcessedEmbeddingBagCollectionSharder(
BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
4 changes: 4 additions & 0 deletions torchrec/distributed/fused_embeddingbag.py
@@ -85,6 +85,10 @@ def __init__(
# We need to ensure that a checkpoint from DDP and a checkpoint from a
# model parallel version are compatible.

@property
def unsharded_module_type(self) -> Type[FusedEmbeddingBagCollection]:
return FusedEmbeddingBagCollection


class FusedEmbeddingBagCollectionSharder(
BaseEmbeddingSharder[FusedEmbeddingBagCollection]
8 changes: 8 additions & 0 deletions torchrec/distributed/itep_embeddingbag.py
@@ -274,6 +274,10 @@ def _group_lookups_and_table_unpruned_size_map(

return grouped_lookups, grouped_table_unpruned_size_map

@property
def unsharded_module_type(self) -> Type[ITEPEmbeddingBagCollection]:
return ITEPEmbeddingBagCollection


class ITEPEmbeddingBagCollectionSharder(
BaseEmbeddingSharder[ITEPEmbeddingBagCollection]
@@ -523,6 +527,10 @@ def _group_lookups_and_table_unpruned_size_map(

return grouped_lookups, grouped_table_unpruned_size_map

@property
def unsharded_module_type(self) -> Type[ITEPEmbeddingCollection]:
return ITEPEmbeddingCollection


class ITEPEmbeddingCollectionSharder(BaseEmbeddingSharder[ITEPEmbeddingCollection]):
def __init__(
4 changes: 4 additions & 0 deletions torchrec/distributed/mc_embedding.py
@@ -97,6 +97,10 @@ def create_context(
) -> ManagedCollisionEmbeddingCollectionContext:
return ManagedCollisionEmbeddingCollectionContext(sharding_contexts=[])

@property
def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingCollection]:
return ManagedCollisionEmbeddingCollection


class ManagedCollisionEmbeddingCollectionSharder(
BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection]
4 changes: 4 additions & 0 deletions torchrec/distributed/mc_embeddingbag.py
@@ -82,6 +82,10 @@ def create_context(
) -> ManagedCollisionEmbeddingBagCollectionContext:
return ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts=[])

@property
def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
return ManagedCollisionEmbeddingBagCollection


class ManagedCollisionEmbeddingBagCollectionSharder(
BaseManagedCollisionEmbeddingCollectionSharder[
4 changes: 4 additions & 0 deletions torchrec/distributed/quant_embedding.py
@@ -1320,6 +1320,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
for fqn, _ in self.named_buffers():
yield append_prefix(prefix, fqn)

@property
def unsharded_module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]:
return QuantManagedCollisionEmbeddingCollection


class QuantManagedCollisionEmbeddingCollectionSharder(
BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection]
4 changes: 4 additions & 0 deletions torchrec/distributed/quant_embeddingbag.py
@@ -383,6 +383,10 @@ def create_context(self) -> NullShardedModuleContext:

return NullShardedModuleContext()

@property
def unsharded_module_type(self) -> Type[QuantEmbeddingBagCollection]:
return QuantEmbeddingBagCollection


class QuantEmbeddingBagCollectionSharder(
BaseQuantEmbeddingSharder[QuantEmbeddingBagCollection]
14 changes: 13 additions & 1 deletion torchrec/distributed/quant_state.py
@@ -10,12 +10,13 @@
import copy
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch import nn
from torch.distributed import _remote_device
from torch.distributed._shard.sharded_tensor import (
Shard,
@@ -367,6 +368,17 @@ def _load_from_state_dict(
missing_keys.extend(_missing_keys)
unexpected_keys.extend(_unexpected_keys)

@property
def unsharded_module_type(self) -> Type[nn.Module]:
"""
ShardedQuantEmbeddingModuleState is not a sharded module itself but
rather a class that provides generic helper functions, so return the
generic nn.Module type.
"""

# TODO: Add test in TorchRec for using ShardedQuantEmbeddingModuleState
return nn.Module


@dataclass
class WeightSpec:
8 changes: 7 additions & 1 deletion torchrec/distributed/tests/test_embedding_types.py
@@ -8,9 +8,10 @@
# pyre-strict

import unittest
from typing import Dict, List
from typing import Dict, List, Type

import torch
from torch import nn
from torchrec.distributed.embedding_types import KJTList, ShardedEmbeddingModule
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionContext
from torchrec.distributed.types import Awaitable, LazyAwaitable
@@ -55,6 +56,11 @@ def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
pass

@property
def unsharded_module_type(self) -> Type[nn.Module]:
# Since this is a fake sharded embedding module, return the default nn.Module type
return nn.Module


class TestShardedEmbeddingModule(unittest.TestCase):
def test_train_mode(self) -> None:
13 changes: 13 additions & 0 deletions torchrec/distributed/types.py
@@ -1034,6 +1034,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
for key, _ in self.named_parameters(prefix):
yield key

@property
@abc.abstractmethod
def unsharded_module_type(self) -> Type[nn.Module]:
"""
This property is added as part of the dynamic sharding implementation.
When resharding an already-sharded module wrapped in DMP, the unsharded
module type is needed to identify the proper sharder to use, because
DistributedModelParallel (DMP) references module sharders by their
unsharded module type.
"""
...


def get_tensor_size_bytes(t: torch.Tensor) -> int:
b: int = t.numel() * t.element_size()
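
Below is a minimal, self-contained sketch (hypothetical class and function names, not TorchRec APIs) of the lookup pattern the types.py docstring describes: a registry of sharders keyed by unsharded module type, consulted through the new unsharded_module_type property when resharding an already-sharded module.

from typing import Dict, Type

import torch.nn as nn


class FakeEmbeddingCollection(nn.Module):
    """Stand-in for an unsharded module such as EmbeddingCollection."""


class FakeShardedEmbeddingCollection(nn.Module):
    """Stand-in for a sharded module exposing the new property."""

    @property
    def unsharded_module_type(self) -> Type[nn.Module]:
        return FakeEmbeddingCollection


class FakeSharder:
    """Stand-in for a sharder; keyed in the registry by its module type."""

    @property
    def module_type(self) -> Type[nn.Module]:
        return FakeEmbeddingCollection


def find_sharder_for_resharding(
    sharded_module: FakeShardedEmbeddingCollection,
    sharders_by_module_type: Dict[Type[nn.Module], FakeSharder],
) -> FakeSharder:
    # Resharding starts from an already-sharded module, so the unsharded
    # type is recovered via the property rather than from type(module).
    return sharders_by_module_type[sharded_module.unsharded_module_type]


sharder = FakeSharder()
registry = {sharder.module_type: sharder}
assert find_sharder_for_resharding(FakeShardedEmbeddingCollection(), registry) is sharder

The actual DMP resharding path is more involved; the sketch only illustrates why a sharded module needs to remember its unsharded type.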