
Commit 103f396

Add unsharded module property to sharded modules (#2901)
Summary: Pull Request resolved: #2901. Adds a simple unsharded-module reference to sharded modules. Dynamic Sharding will use this in `DistributedModelParallel` (DMP) to reshard an already-sharded module. Because DMP was designed with only a one-way (unsharded-to-sharded) relationship in mind, exposing the unsharded module type helps determine which sharder to use when resharding. Most of the changes simply add the property wherever ShardedModule or its wrapper ShardedEmbeddingModule is used. Differential Revision: D73407830
1 parent: a28ac22 · commit: 103f396

14 files changed (+107, -3 lines)
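
To make the resharding flow concrete, here is a minimal sketch (not part of this commit) of how a caller such as `DistributedModelParallel` could use the new `unsharded_module_type` property to look up the sharder registered for the original unsharded module type. The `resolve_sharder_for_resharding` helper and the `sharder_map` argument are hypothetical illustrations; only the `unsharded_module_type` property comes from this change.

# Hypothetical sketch; only `unsharded_module_type` is from this commit.
from typing import Dict, Type

import torch.nn as nn

from torchrec.distributed.types import ModuleSharder, ShardedModule


def resolve_sharder_for_resharding(
    sharded_module: ShardedModule,
    sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]],
) -> ModuleSharder[nn.Module]:
    # Sharders are keyed by *unsharded* module type, so the new property
    # lets us map an already-sharded module back to a sharder that can
    # reshard it.
    unsharded_type = sharded_module.unsharded_module_type
    if unsharded_type not in sharder_map:
        raise KeyError(f"No sharder registered for {unsharded_type.__name__}")
    return sharder_map[unsharded_type]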

torchrec/distributed/embedding.py (+4)

@@ -1409,6 +1409,10 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim

+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingCollection]:
+        return EmbeddingCollection
+
     def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])

torchrec/distributed/embedding_tower_sharding.py (+8)

@@ -438,6 +438,10 @@ def named_modules(
     def create_context(self) -> NullShardedModuleContext:
         return NullShardedModuleContext()

+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingTower]:
+        return EmbeddingTower
+

 class ShardedEmbeddingTowerCollection(
     ShardedEmbeddingModule[

@@ -941,6 +945,10 @@ def embedding_feature_names(
             kjt_features.extend(config.feature_names)
         return kjt_features, wkjt_features

+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingTowerCollection]:
+        return EmbeddingTowerCollection
+

 class EmbeddingTowerCollectionSharder(BaseEmbeddingSharder[EmbeddingTowerCollection]):
     def __init__(

torchrec/distributed/embedding_types.py (+22, -1)

@@ -11,7 +11,18 @@
 import copy
 from dataclasses import dataclass
 from enum import Enum, unique
-from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation

@@ -399,6 +410,16 @@ def train(self, mode: bool = True): # pyre-ignore[3]

         return self

+    @property
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        """
+        As this is the generic ShardedEmbeddingModule class, simply
+        return the generic nn.Module type. In the inherited classes of
+        ShardedEmbeddingModule, the specific unsharded module type will
+        be returned.
+        """
+        return nn.Module
+

 M = TypeVar("M", bound=nn.Module)
404425

torchrec/distributed/embeddingbag.py (+8)

@@ -1598,6 +1598,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
     def extend_shard_name(shard_name: str) -> str:
         return f"embedding_bags.{shard_name}.weight"

+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingBagCollection]:
+        return EmbeddingBagCollection
+

 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """

@@ -1887,6 +1891,10 @@ def fused_optimizer(self) -> KeyedOptimizer:
     def create_context(self) -> NullShardedModuleContext:
         return NullShardedModuleContext()

+    @property
+    def unsharded_module_type(self) -> Type[nn.EmbeddingBag]:
+        return nn.EmbeddingBag
+

 class EmbeddingBagSharder(BaseEmbeddingSharder[nn.EmbeddingBag]):
     """

torchrec/distributed/fp_embeddingbag.py (+4)

@@ -161,6 +161,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             if "_embedding_bag_collection" in fqn:
                 yield append_prefix(prefix, fqn)

+    @property
+    def unsharded_module_type(self) -> Type[FeatureProcessedEmbeddingBagCollection]:
+        return FeatureProcessedEmbeddingBagCollection
+

 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]

torchrec/distributed/fused_embeddingbag.py (+4)

@@ -85,6 +85,10 @@ def __init__(
         # We need to ensure that a checkpoint from DDP and a checkpoint from a
         # model parallel version are compatible.

+    @property
+    def unsharded_module_type(self) -> Type[FusedEmbeddingBagCollection]:
+        return FusedEmbeddingBagCollection
+

 class FusedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FusedEmbeddingBagCollection]

torchrec/distributed/itep_embeddingbag.py (+8)

@@ -274,6 +274,10 @@ def _group_lookups_and_table_unpruned_size_map(

         return grouped_lookups, grouped_table_unpruned_size_map

+    @property
+    def unsharded_module_type(self) -> Type[ITEPEmbeddingBagCollection]:
+        return ITEPEmbeddingBagCollection
+

 class ITEPEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[ITEPEmbeddingBagCollection]

@@ -523,6 +527,10 @@ def _group_lookups_and_table_unpruned_size_map(

         return grouped_lookups, grouped_table_unpruned_size_map

+    @property
+    def unsharded_module_type(self) -> Type[ITEPEmbeddingCollection]:
+        return ITEPEmbeddingCollection
+

 class ITEPEmbeddingCollectionSharder(BaseEmbeddingSharder[ITEPEmbeddingCollection]):
     def __init__(

torchrec/distributed/mc_embedding.py (+4)

@@ -97,6 +97,10 @@ def create_context(
     ) -> ManagedCollisionEmbeddingCollectionContext:
         return ManagedCollisionEmbeddingCollectionContext(sharding_contexts=[])

+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingCollection]:
+        return ManagedCollisionEmbeddingCollection
+

 class ManagedCollisionEmbeddingCollectionSharder(
     BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection]

torchrec/distributed/mc_embeddingbag.py (+4)

@@ -82,6 +82,10 @@ def create_context(
     ) -> ManagedCollisionEmbeddingBagCollectionContext:
         return ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts=[])

+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
+        return ManagedCollisionEmbeddingBagCollection
+

 class ManagedCollisionEmbeddingBagCollectionSharder(
     BaseManagedCollisionEmbeddingCollectionSharder[

torchrec/distributed/quant_embedding.py (+4)

@@ -1320,6 +1320,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for fqn, _ in self.named_buffers():
             yield append_prefix(prefix, fqn)

+    @property
+    def unsharded_module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]:
+        return QuantManagedCollisionEmbeddingCollection
+

 class QuantManagedCollisionEmbeddingCollectionSharder(
     BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection]

torchrec/distributed/quant_embeddingbag.py (+4)

@@ -383,6 +383,10 @@ def create_context(self) -> NullShardedModuleContext:

         return NullShardedModuleContext()

+    @property
+    def unsharded_module_type(self) -> Type[QuantEmbeddingBagCollection]:
+        return QuantEmbeddingBagCollection
+

 class QuantEmbeddingBagCollectionSharder(
     BaseQuantEmbeddingSharder[QuantEmbeddingBagCollection]

torchrec/distributed/quant_state.py (+13, -1)

@@ -10,12 +10,13 @@
 import copy
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union

 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
+from torch import nn
 from torch.distributed import _remote_device
 from torch.distributed._shard.sharded_tensor import (
     Shard,

@@ -367,6 +368,17 @@ def _load_from_state_dict(
         missing_keys.extend(_missing_keys)
         unexpected_keys.extend(_unexpected_keys)

+    @property
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        """
+        Since ShardedQuantEmbeddingModuleState is not exactly a sharded module
+        but rather a class to utilize generic helper functions. Returns generic
+        nn.Module type.
+        """
+
+        # TODO: Add test in TorchRec for using ShardedQuantEmbeddingModuleState
+        return nn.Module
+

 @dataclass
 class WeightSpec:

torchrec/distributed/tests/test_embedding_types.py (+7, -1)

@@ -8,9 +8,10 @@
 # pyre-strict

 import unittest
-from typing import Dict, List
+from typing import Dict, List, Type

 import torch
+from torch import nn
 from torchrec.distributed.embedding_types import KJTList, ShardedEmbeddingModule
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionContext
 from torchrec.distributed.types import Awaitable, LazyAwaitable

@@ -55,6 +56,11 @@ def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
     def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
         pass

+    @property
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        # Since this is a fake sharded embedding module, just returning default module
+        return nn.Module
+

 class TestShardedEmbeddingModule(unittest.TestCase):
     def test_train_mode(self) -> None:

torchrec/distributed/types.py (+13)

@@ -1034,6 +1034,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key

+    @property
+    @abc.abstractmethod
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        """
+        This property is added as part of dynamic sharding implementation.
+
+        When resharding an already-sharded module wrapped in DMP, the unsharded
+        module type is needed to identify the proper sharder to reshard. This is
+        due to DistributedModelParellel (DMP) references module Sharders based
+        on the unsharded module type.
+        """
+        ...
+

 def get_tensor_size_bytes(t: torch.Tensor) -> int:
     b: int = t.numel() * t.element_size()
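
Because `unsharded_module_type` is declared abstract on `ShardedModule` in types.py, the sharded modules touched by this commit each override it to return their unsharded counterpart, as the per-file changes above show. Below is a condensed, hypothetical example of the override pattern; the class names are illustrative stand-ins, not TorchRec modules.

# Hypothetical, condensed example of the override pattern in this commit.
from typing import Type

import torch.nn as nn


class MyUnshardedCollection(nn.Module):
    """Stand-in for an unsharded module such as EmbeddingBagCollection."""


class MyShardedCollection(nn.Module):
    """Stand-in for the corresponding sharded module."""

    @property
    def unsharded_module_type(self) -> Type[MyUnshardedCollection]:
        # Returning the unsharded counterpart lets the resharding path pick
        # the sharder that was registered for MyUnshardedCollection.
        return MyUnshardedCollection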
