Commit 09fcea8

aporialiao authored and facebook-github-bot committed
Add unsharded module reference to sharded modules (#2901)
Summary: Adding a simple unsharded module reference to sharded modules. This will be used in Dynamic Sharding by `DistributedModelParallel` (DMP) to reshard an already-sharded module. Because DMP was built with only a one-way relationship in mind, access to the unsharded module type helps determine which sharder to use when resharding.

Differential Revision: D73407830
1 parent 9eaec09 commit 09fcea8

10 files changed: +43 −1
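For illustration, the sketch below (not part of this commit) shows how a resharding path could consume the new `unsharded_module_type` property to pick the sharder matching an already-sharded module. The helper `select_sharder_for_resharding` and its `sharder_map` argument are hypothetical names used only for this example, under the assumption that sharders are keyed by the unsharded module type.

# Hypothetical helper, not a TorchRec API: resolve the sharder for an
# already-sharded module by asking it for its unsharded module type.
from typing import Dict, Type

import torch.nn as nn

from torchrec.distributed.types import ModuleSharder, ShardedModule


def select_sharder_for_resharding(
    sharded_module: ShardedModule,
    sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]],
) -> ModuleSharder[nn.Module]:
    # The sharded module now reports the type it was created from, so the
    # caller no longer needs to keep a reference to the original module.
    unsharded_type = sharded_module.unsharded_module_type
    if unsharded_type not in sharder_map:
        raise KeyError(f"No sharder registered for {unsharded_type.__name__}")
    return sharder_map[unsharded_type]

For example, a sharded EmbeddingBagCollection would report EmbeddingBagCollection as its unsharded type, which keys the lookup above.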

torchrec/distributed/embedding.py (+4)

@@ -1409,6 +1409,10 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim

+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingCollection]:
+        return EmbeddingCollection
+
     def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])

torchrec/distributed/embeddingbag.py (+4)

@@ -1598,6 +1598,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
     def extend_shard_name(shard_name: str) -> str:
         return f"embedding_bags.{shard_name}.weight"

+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingBagCollection]:
+        return EmbeddingBagCollection
+


 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """

torchrec/distributed/fp_embeddingbag.py (+4)

@@ -161,6 +161,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             if "_embedding_bag_collection" in fqn:
                 yield append_prefix(prefix, fqn)

+    @property
+    def unsharded_module_type(self) -> Type[FeatureProcessedEmbeddingBagCollection]:
+        return FeatureProcessedEmbeddingBagCollection
+


 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]

torchrec/distributed/fused_embeddingbag.py (+4)

@@ -85,6 +85,10 @@ def __init__(
         # We need to ensure that a checkpoint from DDP and a checkpoint from a
         # model parallel version are compatible.

+    @property
+    def unsharded_module_type(self) -> Type[FusedEmbeddingBagCollection]:
+        return FusedEmbeddingBagCollection
+


 class FusedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FusedEmbeddingBagCollection]

torchrec/distributed/mc_embedding.py (+4)

@@ -97,6 +97,10 @@ def create_context(
     ) -> ManagedCollisionEmbeddingCollectionContext:
         return ManagedCollisionEmbeddingCollectionContext(sharding_contexts=[])

+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingCollection]:
+        return ManagedCollisionEmbeddingCollection
+


 class ManagedCollisionEmbeddingCollectionSharder(
     BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection]

torchrec/distributed/mc_embeddingbag.py (+4)

@@ -82,6 +82,10 @@ def create_context(
     ) -> ManagedCollisionEmbeddingBagCollectionContext:
         return ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts=[])

+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
+        return ManagedCollisionEmbeddingBagCollection
+


 class ManagedCollisionEmbeddingBagCollectionSharder(
     BaseManagedCollisionEmbeddingCollectionSharder[

torchrec/distributed/quant_embedding.py (+4)

@@ -1320,6 +1320,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for fqn, _ in self.named_buffers():
             yield append_prefix(prefix, fqn)

+    @property
+    def unsharded_module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]:
+        return QuantManagedCollisionEmbeddingCollection
+


 class QuantManagedCollisionEmbeddingCollectionSharder(
     BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection]

torchrec/distributed/quant_embeddingbag.py (+4)

@@ -383,6 +383,10 @@ def create_context(self) -> NullShardedModuleContext:

         return NullShardedModuleContext()

+    @property
+    def unsharded_module_type(self) -> Type[QuantEmbeddingBagCollection]:
+        return QuantEmbeddingBagCollection
+


 class QuantEmbeddingBagCollectionSharder(
     BaseQuantEmbeddingSharder[QuantEmbeddingBagCollection]

torchrec/distributed/tests/test_embedding_types.py (+7 −1)

@@ -8,9 +8,10 @@
 # pyre-strict

 import unittest
-from typing import Dict, List
+from typing import Dict, List, Type

 import torch
+from torch import nn
 from torchrec.distributed.embedding_types import KJTList, ShardedEmbeddingModule
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionContext
 from torchrec.distributed.types import Awaitable, LazyAwaitable
@@ -55,6 +56,11 @@ def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
     def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
         pass

+    @property
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        # Since this is a fake sharded embedding module, just returning default module
+        return nn.Module
+


 class TestShardedEmbeddingModule(unittest.TestCase):
     def test_train_mode(self) -> None:

torchrec/distributed/types.py (+4)

@@ -1034,6 +1034,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key

+    @property
+    @abc.abstractmethod
+    def unsharded_module_type(self) -> Type[nn.Module]: ...
+


 def get_tensor_size_bytes(t: torch.Tensor) -> int:
     b: int = t.numel() * t.element_size()
