Expand file tree Collapse file tree 9 files changed +36
-0
lines changed Original file line number Diff line number Diff line change @@ -1409,6 +1409,10 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
1409
1409
def fused_optimizer (self ) -> KeyedOptimizer :
1410
1410
return self ._optim
1411
1411
1412
+ @property
1413
+ def unsharded_module_type (self ) -> Type [EmbeddingCollection ]:
1414
+ return EmbeddingCollection
1415
+
1412
1416
def create_context (self ) -> EmbeddingCollectionContext :
1413
1417
return EmbeddingCollectionContext (sharding_contexts = [])
1414
1418
Original file line number Diff line number Diff line change @@ -1598,6 +1598,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
1598
1598
def extend_shard_name (shard_name : str ) -> str :
1599
1599
return f"embedding_bags.{ shard_name } .weight"
1600
1600
1601
+ @property
1602
+ def unsharded_module_type (self ) -> Type [EmbeddingBagCollection ]:
1603
+ return EmbeddingBagCollection
1604
+
1601
1605
1602
1606
class EmbeddingBagCollectionSharder (BaseEmbeddingSharder [EmbeddingBagCollection ]):
1603
1607
"""
Original file line number Diff line number Diff line change @@ -161,6 +161,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
161
161
if "_embedding_bag_collection" in fqn :
162
162
yield append_prefix (prefix , fqn )
163
163
164
+ @property
165
+ def unsharded_module_type (self ) -> Type [FeatureProcessedEmbeddingBagCollection ]:
166
+ return FeatureProcessedEmbeddingBagCollection
167
+
164
168
165
169
class FeatureProcessedEmbeddingBagCollectionSharder (
166
170
BaseEmbeddingSharder [FeatureProcessedEmbeddingBagCollection ]
Original file line number Diff line number Diff line change @@ -85,6 +85,10 @@ def __init__(
85
85
# We need to ensure that a checkpoint from DDP and a checkpoint from a
86
86
# model parallel version are compatible.
87
87
88
+ @property
89
+ def unsharded_module_type (self ) -> Type [FusedEmbeddingBagCollection ]:
90
+ return FusedEmbeddingBagCollection
91
+
88
92
89
93
class FusedEmbeddingBagCollectionSharder (
90
94
BaseEmbeddingSharder [FusedEmbeddingBagCollection ]
Original file line number Diff line number Diff line change @@ -97,6 +97,10 @@ def create_context(
97
97
) -> ManagedCollisionEmbeddingCollectionContext :
98
98
return ManagedCollisionEmbeddingCollectionContext (sharding_contexts = [])
99
99
100
+ @property
101
+ def unsharded_module_type (self ) -> Type [ManagedCollisionEmbeddingCollection ]:
102
+ return ManagedCollisionEmbeddingCollection
103
+
100
104
101
105
class ManagedCollisionEmbeddingCollectionSharder (
102
106
BaseManagedCollisionEmbeddingCollectionSharder [ManagedCollisionEmbeddingCollection ]
Original file line number Diff line number Diff line change @@ -82,6 +82,10 @@ def create_context(
82
82
) -> ManagedCollisionEmbeddingBagCollectionContext :
83
83
return ManagedCollisionEmbeddingBagCollectionContext (sharding_contexts = [])
84
84
85
+ @property
86
+ def unsharded_module_type (self ) -> Type [ManagedCollisionEmbeddingBagCollection ]:
87
+ return ManagedCollisionEmbeddingBagCollection
88
+
85
89
86
90
class ManagedCollisionEmbeddingBagCollectionSharder (
87
91
BaseManagedCollisionEmbeddingCollectionSharder [
Original file line number Diff line number Diff line change @@ -1320,6 +1320,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
1320
1320
for fqn , _ in self .named_buffers ():
1321
1321
yield append_prefix (prefix , fqn )
1322
1322
1323
+ @property
1324
+ def unsharded_module_type (self ) -> Type [QuantManagedCollisionEmbeddingCollection ]:
1325
+ return QuantManagedCollisionEmbeddingCollection
1326
+
1323
1327
1324
1328
class QuantManagedCollisionEmbeddingCollectionSharder (
1325
1329
BaseQuantEmbeddingSharder [QuantManagedCollisionEmbeddingCollection ]
Original file line number Diff line number Diff line change @@ -383,6 +383,10 @@ def create_context(self) -> NullShardedModuleContext:
383
383
384
384
return NullShardedModuleContext ()
385
385
386
+ @property
387
+ def unsharded_module_type (self ) -> Type [QuantEmbeddingBagCollection ]:
388
+ return QuantEmbeddingBagCollection
389
+
386
390
387
391
class QuantEmbeddingBagCollectionSharder (
388
392
BaseQuantEmbeddingSharder [QuantEmbeddingBagCollection ]
Original file line number Diff line number Diff line change @@ -1034,6 +1034,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
1034
1034
for key , _ in self .named_parameters (prefix ):
1035
1035
yield key
1036
1036
1037
+ @property
1038
+ @abc .abstractmethod
1039
+ def unsharded_module_type (self ) -> Type [nn .Module ]: ...
1040
+
1037
1041
1038
1042
def get_tensor_size_bytes (t : torch .Tensor ) -> int :
1039
1043
b : int = t .numel () * t .element_size ()
0 commit comments