@@ -1756,6 +1756,7 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
1756
1756
"_weights" ,
1757
1757
"_lengths" ,
1758
1758
"_offsets" ,
1759
+ "_stride_per_key" ,
1759
1760
]
1760
1761
1761
1762
def __init__ (
@@ -3033,7 +3034,17 @@ def dist_init(
3033
3034
def _kjt_flatten (
3034
3035
t : KeyedJaggedTensor ,
3035
3036
) -> Tuple [List [Optional [torch .Tensor ]], List [str ]]:
3036
- return [getattr (t , a ) for a in KeyedJaggedTensor ._fields ], t ._keys
3037
+ values , keys = [getattr (t , a ) for a in KeyedJaggedTensor ._fields [:- 1 ]], t ._keys
3038
+
3039
+ # append stride_per_key to the tensors
3040
+ stride_per_key = t .stride_per_key ()
3041
+ values .append (
3042
+ torch .tensor (stride_per_key , dtype = torch .int32 , device = t .device ())
3043
+ if stride_per_key is not None
3044
+ else None
3045
+ )
3046
+
3047
+ return values , keys
3037
3048
3038
3049
3039
3050
def _kjt_flatten_with_keys (
@@ -3049,13 +3060,22 @@ def _kjt_flatten_with_keys(
3049
3060
def _kjt_unflatten (
3050
3061
values : List [Optional [torch .Tensor ]], context : List [str ] # context is the _keys
3051
3062
) -> KeyedJaggedTensor :
3052
- return KeyedJaggedTensor (context , * values )
3063
+ return KeyedJaggedTensor (context , * values [: - 1 ], stride_per_key = values [ - 1 ] )
3053
3064
3054
3065
3055
3066
def _kjt_flatten_spec (
3056
3067
t : KeyedJaggedTensor , spec : TreeSpec
3057
3068
) -> List [Optional [torch .Tensor ]]:
3058
- return [getattr (t , a ) for a in KeyedJaggedTensor ._fields ]
3069
+ stride_per_key = t .stride_per_key ()
3070
+ stride_per_key_tensor : Optional [torch .Tensor ] = (
3071
+ torch .tensor (stride_per_key , dtype = torch .int32 , device = t .device ())
3072
+ if stride_per_key is not None
3073
+ else None
3074
+ )
3075
+ values = [getattr (t , a ) for a in KeyedJaggedTensor ._fields [:- 1 ]]
3076
+ values .append (stride_per_key_tensor )
3077
+
3078
+ return values
3059
3079
3060
3080
3061
3081
register_pytree_node (
0 commit comments