Skip to content

Commit ce9124e

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Test include stride_per_key to KJT's flatten and unflatten
Differential Revision: D73051959
1 parent 9eaec09 commit ce9124e

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

torchrec/sparse/jagged_tensor.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,7 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
17561756
"_weights",
17571757
"_lengths",
17581758
"_offsets",
1759+
"_stride_per_key",
17591760
]
17601761

17611762
def __init__(
@@ -3033,7 +3034,17 @@ def dist_init(
30333034
def _kjt_flatten(
30343035
t: KeyedJaggedTensor,
30353036
) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
3036-
return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
3037+
values, keys = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]], t._keys
3038+
3039+
# append stride_per_key to the tensors
3040+
stride_per_key = t.stride_per_key()
3041+
values.append(
3042+
torch.tensor(stride_per_key, dtype=torch.int32, device=t.device())
3043+
if stride_per_key is not None
3044+
else None
3045+
)
3046+
3047+
return values, keys
30373048

30383049

30393050
def _kjt_flatten_with_keys(
@@ -3049,13 +3060,22 @@ def _kjt_flatten_with_keys(
30493060
def _kjt_unflatten(
30503061
values: List[Optional[torch.Tensor]], context: List[str] # context is the _keys
30513062
) -> KeyedJaggedTensor:
3052-
return KeyedJaggedTensor(context, *values)
3063+
return KeyedJaggedTensor(context, *values[:-1], stride_per_key=values[-1])
30533064

30543065

30553066
def _kjt_flatten_spec(
30563067
t: KeyedJaggedTensor, spec: TreeSpec
30573068
) -> List[Optional[torch.Tensor]]:
3058-
return [getattr(t, a) for a in KeyedJaggedTensor._fields]
3069+
stride_per_key = t.stride_per_key()
3070+
stride_per_key_tensor: Optional[torch.Tensor] = (
3071+
torch.tensor(stride_per_key, dtype=torch.int32, device=t.device())
3072+
if stride_per_key is not None
3073+
else None
3074+
)
3075+
values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
3076+
values.append(stride_per_key_tensor)
3077+
3078+
return values
30593079

30603080

30613081
register_pytree_node(

torchrec/sparse/tests/test_keyed_jagged_tensor.py

+16
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,22 @@ def test_meta_device_compatibility(self) -> None:
10171017
lengths=torch.tensor([], device=torch.device("meta")),
10181018
)
10191019

1020+
def test_flatten_unflatten_with_vbe(self) -> None:
1021+
kjt = KeyedJaggedTensor(
1022+
keys=["f1", "f2"],
1023+
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
1024+
lengths=torch.tensor([3, 3, 2]),
1025+
stride_per_key_per_rank=[[2], [1]],
1026+
inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
1027+
)
1028+
1029+
flat_kjt, spec = pytree.tree_flatten(kjt)
1030+
unflattened_kjt = pytree.tree_unflatten(flat_kjt, spec)
1031+
1032+
self.assertEqual(
1033+
kjt.stride_per_key(), unflattened_kjt.stride_per_key().tolist()
1034+
)
1035+
10201036

10211037
class TestKeyedJaggedTensorScripting(unittest.TestCase):
10221038
def test_scriptable_forward(self) -> None:

0 commit comments

Comments
 (0)