Skip to content

Commit a28ac22

Browse files
jd7-tr authored and facebook-github-bot committed Apr 23, 2025
Add test case for exporting EBC with VBE KJT (#2907)
Summary: Pull Request resolved: #2907 # Context * Currently torchrec IR serializer can't handle variable batch use case. * `torch.export` only captures the keys, values, lengths, weights, offsets of a KJT, however, some variable-batch related parameters like `stride_per_rank` or `inverse_indices` would be ignored. * This test case (expected failure right now) covers the vb-KJT scenario for verifying that the serialize_deserialize_ebc use case works fine with KJTs with variable batch size. # Ref Reviewed By: TroyGarden Differential Revision: D73454558 fbshipit-source-id: 93268154a7bc88e07707c2e9b95de8aab286bed8
1 parent 9f0bd7e commit a28ac22

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed
 

‎torchrec/ir/tests/test_serializer.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,45 @@ def forward(
176176

177177
return model
178178

179+
def generate_model_for_vbe_kjt(self) -> nn.Module:
    """Build a small single-EBC model used by the variable-batch (VBE) KJT test.

    The returned module holds one ``EmbeddingBagCollection`` (exposed as
    ``ebc1``) over two tables ("t1"/"f1" dim 3, "t2"/"f2" dim 4) and, in
    ``forward``, regroups the pooled output so each feature key becomes its
    own tensor in the returned list.
    """

    class Model(nn.Module):
        def __init__(self, ebc):
            super().__init__()
            self.ebc1 = ebc

        def forward(
            self,
            features: KeyedJaggedTensor,
        ) -> List[torch.Tensor]:
            pooled = self.ebc1(features)
            outputs: List[torch.Tensor] = []
            # One regroup call per KeyedTensor; each key gets its own group.
            for kt in (pooled,):
                outputs.extend(
                    KeyedTensor.regroup([kt], [[key] for key in kt.keys()])
                )
            return outputs

    tables = [
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=3,
            num_embeddings=10,
            feature_names=["f1"],
        ),
        EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f2"],
        ),
    ]
    ebc = EmbeddingBagCollection(tables=tables, is_weighted=False)
    return Model(ebc)
217+
179218
def test_serialize_deserialize_ebc(self) -> None:
180219
model = self.generate_model()
181220
id_list_features = KeyedJaggedTensor.from_offsets_sync(
@@ -253,6 +292,86 @@ def test_serialize_deserialize_ebc(self) -> None:
253292
self.assertEqual(deserialized.shape, orginal.shape)
254293
self.assertTrue(torch.allclose(deserialized, orginal))
255294

295+
@unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
    """Round-trip an EBC through serialize -> export -> unflatten -> deserialize
    using a KJT with variable batch sizes, and compare against eager output.

    The KJT carries ``stride_per_key_per_rank`` and ``inverse_indices``,
    which ``torch.export`` currently drops when flattening a KJT — hence
    the ``unittest.skip`` above (expected failure for now).
    """
    model = self.generate_model_for_vbe_kjt()
    # Variable-batch KJT: f1 has stride 2, f2 has stride 1 (3 lengths total).
    id_list_features = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
        lengths=torch.tensor([3, 3, 2]),
        stride_per_key_per_rank=[[2], [1]],
        inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
    )

    eager_out = model(id_list_features)

    # Serialize EBC
    model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
    ep = torch.export.export(
        model,
        (id_list_features,),
        {},
        strict=False,
        # Allows KJT to not be unflattened and run a forward on unflattened EP
        preserve_module_call_signature=(tuple(sparse_fqns)),
    )

    # Run forward on ExportedProgram
    ep_output = ep.module()(id_list_features)

    for i, tensor in enumerate(ep_output):
        self.assertEqual(eager_out[i].shape, tensor.shape)

    # Deserialize EBC
    unflatten_ep = torch.export.unflatten(ep)
    deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

    # Check EBC config. NOTE: generate_model_for_vbe_kjt builds exactly one
    # EBC ("ebc1") and no FPEBC; the previous copy-pasted loops over
    # ebc1..ebc5 and fpebc1..fpebc2 referenced submodules that do not exist
    # on this model and would raise AttributeError once the skip is lifted.
    self.assertIsInstance(
        getattr(deserialized_model, "ebc1"), EmbeddingBagCollection
    )
    for deserialized, original in zip(
        getattr(deserialized_model, "ebc1").embedding_bag_configs(),
        getattr(model, "ebc1").embedding_bag_configs(),
    ):
        self.assertEqual(deserialized.name, original.name)
        self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
        self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
        self.assertEqual(deserialized.feature_names, original.feature_names)

    # Run forward on deserialized model and compare the output
    deserialized_model.load_state_dict(model.state_dict())
    deserialized_out = deserialized_model(id_list_features)

    self.assertEqual(len(deserialized_out), len(eager_out))
    for deserialized, original in zip(deserialized_out, eager_out):
        self.assertEqual(deserialized.shape, original.shape)
        self.assertTrue(torch.allclose(deserialized, original))
374+
256375
def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
257376
model = self.generate_model()
258377
feature1 = KeyedJaggedTensor.from_offsets_sync(

0 commit comments

Comments
 (0)