@@ -176,6 +176,45 @@ def forward(
176
176
177
177
return model
178
178
179
+ def generate_model_for_vbe_kjt (self ) -> nn .Module :
180
+ class Model (nn .Module ):
181
+ def __init__ (self , ebc ):
182
+ super ().__init__ ()
183
+ self .ebc1 = ebc
184
+
185
+ def forward (
186
+ self ,
187
+ features : KeyedJaggedTensor ,
188
+ ) -> List [torch .Tensor ]:
189
+ kt1 = self .ebc1 (features )
190
+ res : List [torch .Tensor ] = []
191
+
192
+ for kt in [kt1 ]:
193
+ res .extend (KeyedTensor .regroup ([kt ], [[key ] for key in kt .keys ()]))
194
+
195
+ return res
196
+
197
+ config1 = EmbeddingBagConfig (
198
+ name = "t1" ,
199
+ embedding_dim = 3 ,
200
+ num_embeddings = 10 ,
201
+ feature_names = ["f1" ],
202
+ )
203
+ config2 = EmbeddingBagConfig (
204
+ name = "t2" ,
205
+ embedding_dim = 4 ,
206
+ num_embeddings = 10 ,
207
+ feature_names = ["f2" ],
208
+ )
209
+ ebc = EmbeddingBagCollection (
210
+ tables = [config1 , config2 ],
211
+ is_weighted = False ,
212
+ )
213
+
214
+ model = Model (ebc )
215
+
216
+ return model
217
+
179
218
def test_serialize_deserialize_ebc (self ) -> None :
180
219
model = self .generate_model ()
181
220
id_list_features = KeyedJaggedTensor .from_offsets_sync (
@@ -253,6 +292,86 @@ def test_serialize_deserialize_ebc(self) -> None:
253
292
self .assertEqual (deserialized .shape , orginal .shape )
254
293
self .assertTrue (torch .allclose (deserialized , orginal ))
255
294
295
+ @unittest .skip ("Adding test for demonstrating VBE KJT flattening issue for now." )
296
+ def test_serialize_deserialize_ebc_with_vbe_kjt (self ) -> None :
297
+ model = self .generate_model_for_vbe_kjt ()
298
+ id_list_features = KeyedJaggedTensor (
299
+ keys = ["f1" , "f2" ],
300
+ values = torch .tensor ([5 , 6 , 7 , 1 , 2 , 3 , 0 , 1 ]),
301
+ lengths = torch .tensor ([3 , 3 , 2 ]),
302
+ stride_per_key_per_rank = [[2 ], [1 ]],
303
+ inverse_indices = (["f1" , "f2" ], torch .tensor ([[0 , 1 , 0 ], [0 , 0 , 0 ]])),
304
+ )
305
+
306
+ eager_out = model (id_list_features )
307
+
308
+ # Serialize EBC
309
+ model , sparse_fqns = encapsulate_ir_modules (model , JsonSerializer )
310
+ ep = torch .export .export (
311
+ model ,
312
+ (id_list_features ,),
313
+ {},
314
+ strict = False ,
315
+ # Allows KJT to not be unflattened and run a forward on unflattened EP
316
+ preserve_module_call_signature = (tuple (sparse_fqns )),
317
+ )
318
+
319
+ # Run forward on ExportedProgram
320
+ ep_output = ep .module ()(id_list_features )
321
+
322
+ for i , tensor in enumerate (ep_output ):
323
+ self .assertEqual (eager_out [i ].shape , tensor .shape )
324
+
325
+ # Deserialize EBC
326
+ unflatten_ep = torch .export .unflatten (ep )
327
+ deserialized_model = decapsulate_ir_modules (unflatten_ep , JsonSerializer )
328
+
329
+ # check EBC config
330
+ for i in range (5 ):
331
+ ebc_name = f"ebc{ i + 1 } "
332
+ self .assertIsInstance (
333
+ getattr (deserialized_model , ebc_name ), EmbeddingBagCollection
334
+ )
335
+
336
+ for deserialized , orginal in zip (
337
+ getattr (deserialized_model , ebc_name ).embedding_bag_configs (),
338
+ getattr (model , ebc_name ).embedding_bag_configs (),
339
+ ):
340
+ self .assertEqual (deserialized .name , orginal .name )
341
+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
342
+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
343
+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
344
+
345
+ # check FPEBC config
346
+ for i in range (2 ):
347
+ fpebc_name = f"fpebc{ i + 1 } "
348
+ assert isinstance (
349
+ getattr (deserialized_model , fpebc_name ),
350
+ FeatureProcessedEmbeddingBagCollection ,
351
+ )
352
+
353
+ for deserialized , orginal in zip (
354
+ getattr (
355
+ deserialized_model , fpebc_name
356
+ )._embedding_bag_collection .embedding_bag_configs (),
357
+ getattr (
358
+ model , fpebc_name
359
+ )._embedding_bag_collection .embedding_bag_configs (),
360
+ ):
361
+ self .assertEqual (deserialized .name , orginal .name )
362
+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
363
+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
364
+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
365
+
366
+ # Run forward on deserialized model and compare the output
367
+ deserialized_model .load_state_dict (model .state_dict ())
368
+ deserialized_out = deserialized_model (id_list_features )
369
+
370
+ self .assertEqual (len (deserialized_out ), len (eager_out ))
371
+ for deserialized , orginal in zip (deserialized_out , eager_out ):
372
+ self .assertEqual (deserialized .shape , orginal .shape )
373
+ self .assertTrue (torch .allclose (deserialized , orginal ))
374
+
256
375
def test_dynamic_shape_ebc_disabled_in_oss_compatibility (self ) -> None :
257
376
model = self .generate_model ()
258
377
feature1 = KeyedJaggedTensor .from_offsets_sync (
0 commit comments