Skip to content

Commit 35df75a

Browse files
iindyktf-transform-team
authored and
tf-transform-team
committed
Adding tf.RaggedTensor support to numeric TFT mappers and analyzers.
PiperOrigin-RevId: 404598912
1 parent 84f3eeb commit 35df75a

File tree

9 files changed

+612
-263
lines changed

9 files changed

+612
-263
lines changed

RELEASE.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44

55
## Major Features and Improvements
66

7-
* Added `tf.RaggedTensor` support to `tft.bucketize`,
8-
`tft.compute_and_apply_vocabulary`, `tft.scale_to_gaussian` and related
9-
analyzers and mappers with `reduce_instance_dims=True`.
7+
* Added `tf.RaggedTensor` support to all analyzers and mappers with
8+
`reduce_instance_dims=True`.
109

1110
## Bug Fixes and Other Changes
1211

tensorflow_transform/analyzers.py

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112

113113
def _apply_cacheable_combiner(
114114
combiner: analyzer_nodes.Combiner,
115-
*tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]:
115+
*tensor_inputs: common_types.InputTensorType) -> Tuple[tf.Tensor, ...]:
116116
"""Applies the combiner over the whole dataset possibly utilizing cache."""
117117
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
118118
tensor_inputs)
@@ -137,7 +137,7 @@ def _apply_cacheable_combiner(
137137

138138
def _apply_cacheable_combiner_per_key(
139139
combiner: analyzer_nodes.Combiner,
140-
*tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]:
140+
*tensor_inputs: common_types.InputTensorType) -> Tuple[tf.Tensor, ...]:
141141
"""Similar to _apply_cacheable_combiner but this is computed per key."""
142142
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
143143
tensor_inputs)
@@ -162,7 +162,7 @@ def _apply_cacheable_combiner_per_key(
162162

163163
def _apply_cacheable_combiner_per_key_large(
164164
combiner: analyzer_nodes.Combiner, key_vocabulary_filename: str,
165-
*tensor_inputs: common_types.TensorType
165+
*tensor_inputs: common_types.InputTensorType
166166
) -> Union[tf.Tensor, common_types.Asset]:
167167
"""Similar to above but saves the combined result to a file."""
168168
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
@@ -382,16 +382,16 @@ def _numeric_combine(inputs: List[tf.Tensor],
382382

383383
@common.log_api_use(common.ANALYZER_COLLECTION)
384384
def min( # pylint: disable=redefined-builtin
385-
x: common_types.TensorType,
385+
x: common_types.InputTensorType,
386386
reduce_instance_dims: bool = True,
387387
name: Optional[str] = None) -> tf.Tensor:
388388
"""Computes the minimum of the values of a `Tensor` over the whole dataset.
389389
390-
In the case of a `SparseTensor` missing values will be used in return value:
391-
for float, NaN is used and for other dtypes the max is used.
390+
In the case of a `CompositeTensor` missing values will be used in return
391+
value: for float, NaN is used and for other dtypes the max is used.
392392
393393
Args:
394-
x: A `Tensor` or `SparseTensor`.
394+
x: A `Tensor` or `CompositeTensor`.
395395
reduce_instance_dims: By default collapses the batch and instance dimensions
396396
to arrive at a single scalar output. If False, only collapses the batch
397397
dimension and outputs a `Tensor` of the same shape as the input.
@@ -409,16 +409,16 @@ def min( # pylint: disable=redefined-builtin
409409

410410
@common.log_api_use(common.ANALYZER_COLLECTION)
411411
def max( # pylint: disable=redefined-builtin
412-
x: common_types.TensorType,
412+
x: common_types.InputTensorType,
413413
reduce_instance_dims: bool = True,
414414
name: Optional[str] = None) -> tf.Tensor:
415415
"""Computes the maximum of the values of a `Tensor` over the whole dataset.
416416
417-
In the case of a `SparseTensor` missing values will be used in return value:
418-
for float, NaN is used and for other dtypes the min is used.
417+
In the case of a `CompositeTensor` missing values will be used in return
418+
value: for float, NaN is used and for other dtypes the min is used.
419419
420420
Args:
421-
x: A `Tensor` or `SparseTensor`.
421+
x: A `Tensor` or `CompositeTensor`.
422422
reduce_instance_dims: By default collapses the batch and instance dimensions
423423
to arrive at a single scalar output. If False, only collapses the batch
424424
dimension and outputs a vector of the same shape as the input.
@@ -433,19 +433,20 @@ def max( # pylint: disable=redefined-builtin
433433
return _min_and_max(x, reduce_instance_dims, name)[1]
434434

435435

436-
def _min_and_max(x: common_types.TensorType,
436+
def _min_and_max(x: common_types.InputTensorType,
437437
reduce_instance_dims: bool = True,
438438
name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
439-
"""Computes the min and max of the values of a `Tensor` or `SparseTensor`.
439+
"""Computes the min and max of the values of a `Tensor` or `CompositeTensor`.
440440
441-
In the case of a `SparseTensor` missing values will be used in return value:
441+
In the case of a `CompositeTensor` missing values will be used in return
442+
value:
442443
for float, NaN is used and for other dtypes the min is used.
443444
444445
Args:
445-
x: A `Tensor` or `SparseTensor`.
446+
x: A `Tensor` or `CompositeTensor`.
446447
reduce_instance_dims: By default collapses the batch and instance dimensions
447-
to arrive at a single scalar output. If False, only collapses the batch
448-
dimension and outputs a vector of the same shape as the input.
448+
to arrive at a single scalar output. If False, only collapses the batch
449+
dimension and outputs a vector of the same shape as the input.
449450
name: (Optional) A name for this operation.
450451
451452
Returns:
@@ -461,6 +462,9 @@ def _min_and_max(x: common_types.TensorType,
461462
combine_fn = np.nanmax
462463
default_accumulator_value = (np.nan if x.dtype.is_floating else
463464
-output_dtype.max)
465+
elif not reduce_instance_dims and isinstance(x, tf.RaggedTensor):
466+
raise NotImplementedError(
467+
'Elemenwise min_and_max does not support RaggedTensors.')
464468
else:
465469
combine_fn = np.max
466470
default_accumulator_value = (-np.inf if x.dtype.is_floating else
@@ -478,31 +482,31 @@ def _min_and_max(x: common_types.TensorType,
478482

479483

480484
def _min_and_max_per_key(
481-
x: common_types.TensorType,
482-
key: common_types.TensorType,
485+
x: common_types.InputTensorType,
486+
key: common_types.InputTensorType,
483487
reduce_instance_dims: bool = True,
484488
key_vocabulary_filename: Optional[str] = None,
485489
name: Optional[str] = None
486490
) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]:
487-
"""Computes the min and max of the values of a `Tensor` or `SparseTensor`.
491+
"""Computes the min and max of the values of a `Tensor` or `CompositeTensor`.
488492
489-
In the case of a `SparseTensor` missing values will be used in return value:
490-
for float, NaN is used and for other dtypes the min is used.
493+
In the case of a `CompositeTensor` missing values will be used in return
494+
value: for float, NaN is used and for other dtypes the min is used.
491495
492496
This function operates under the assumption that the size of the key set
493497
is small enough to fit in memory. Anything above a certain size larger is not
494498
guaranteed to be handled properly, but support for larger key sets may be
495499
available in a future version.
496500
497501
Args:
498-
x: A `Tensor` or `SparseTensor`.
499-
key: A Tensor or `SparseTensor` of dtype tf.string. If `x` is
500-
a `SparseTensor`, `key` must exactly match `x` in everything except
502+
x: A `Tensor` or `CompositeTensor`.
503+
key: A Tensor or `CompositeTensor` of dtype tf.string. If `x` is a
504+
`CompositeTensor`, `key` must exactly match `x` in everything except
501505
values.
502506
reduce_instance_dims: By default collapses the batch and instance dimensions
503-
to arrive at a single scalar output. If False, only collapses the batch
504-
dimension and outputs a vector of the same shape as the input.
505-
The False case is not currently supported for _min_and_max_per_key.
507+
to arrive at a single scalar output. If False, only collapses the batch
508+
dimension and outputs a vector of the same shape as the input. The False
509+
case is not currently supported for _min_and_max_per_key.
506510
key_vocabulary_filename: (Optional) The file name for the key-output mapping
507511
file. If None and key are provided, this combiner assumes the keys fit in
508512
memory and will not store the result in a file. If empty string, a file
@@ -528,8 +532,9 @@ def _min_and_max_per_key(
528532

529533
with tf.compat.v1.name_scope(name, 'min_and_max_per_key'):
530534
output_dtype = x.dtype
531-
if (not reduce_instance_dims and isinstance(x, tf.SparseTensor) and
532-
x.dtype.is_floating):
535+
if (not reduce_instance_dims and
536+
isinstance(x,
537+
(tf.SparseTensor, tf.RaggedTensor)) and x.dtype.is_floating):
533538
combine_fn = np.nanmax
534539
default_accumulator_value = (np.nan if x.dtype.is_floating else
535540
-output_dtype.max)
@@ -572,13 +577,13 @@ def _sum_combine_fn_and_dtype(
572577

573578
@common.log_api_use(common.ANALYZER_COLLECTION)
574579
def sum( # pylint: disable=redefined-builtin
575-
x: common_types.TensorType,
580+
x: common_types.InputTensorType,
576581
reduce_instance_dims: bool = True,
577582
name: Optional[str] = None) -> tf.Tensor:
578583
"""Computes the sum of the values of a `Tensor` over the whole dataset.
579584
580585
Args:
581-
x: A `Tensor` or `SparseTensor`. Its type must be floating point
586+
x: A `Tensor` or `CompositeTensor`. Its type must be floating point
582587
(float{16|32|64}),integral (int{8|16|32|64}), or
583588
unsigned integral (uint{8|16})
584589
reduce_instance_dims: By default collapses the batch and instance dimensions
@@ -600,13 +605,18 @@ def sum( # pylint: disable=redefined-builtin
600605
if reduce_instance_dims:
601606
if isinstance(x, tf.SparseTensor):
602607
x = x.values
608+
elif isinstance(x, tf.RaggedTensor):
609+
x = x.flat_values
603610
x = tf.reduce_sum(input_tensor=x)
604611
elif isinstance(x, tf.SparseTensor):
605612
if x.dtype == tf.uint8 or x.dtype == tf.uint16:
606613
x = tf.cast(x, tf.int64)
607614
elif x.dtype == tf.uint32 or x.dtype == tf.uint64:
608615
TypeError('Data type %r is not supported' % x.dtype)
609616
x = tf.sparse.reduce_sum(x, axis=0)
617+
elif isinstance(x, tf.RaggedTensor):
618+
raise NotImplementedError(
619+
'Elementwise sum does not support RaggedTensors.')
610620
else:
611621
x = tf.reduce_sum(input_tensor=x, axis=0)
612622
output_dtype, sum_fn = _sum_combine_fn_and_dtype(x.dtype)
@@ -619,7 +629,7 @@ def sum( # pylint: disable=redefined-builtin
619629

620630

621631
@common.log_api_use(common.ANALYZER_COLLECTION)
622-
def histogram(x: common_types.TensorType,
632+
def histogram(x: common_types.InputTensorType,
623633
boundaries: Optional[Union[tf.Tensor, int]] = None,
624634
categorical: Optional[bool] = False,
625635
name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -638,7 +648,7 @@ def histogram(x: common_types.TensorType,
638648
zip(classes, probabilities)))
639649
640650
Args:
641-
x: A `Tensor` or `SparseTensor`.
651+
x: A `Tensor` or `CompositeTensor`.
642652
boundaries: (Optional) A `Tensor` or `int` used to build the histogram;
643653
ignored if `categorical` is True. If possible, provide boundaries as
644654
multiple sorted values. Default to 10 intervals over the 0-1 range, or
@@ -654,7 +664,12 @@ def histogram(x: common_types.TensorType,
654664

655665
with tf.compat.v1.name_scope(name, 'histogram'):
656666
# We need to flatten because BoostedTreesBucketize expects a rank-1 input
657-
x = x.values if isinstance(x, tf.SparseTensor) else tf.reshape(x, [-1])
667+
if isinstance(x, tf.SparseTensor):
668+
x = x.values
669+
elif isinstance(x, tf.RaggedTensor):
670+
x = x.flat_values
671+
else:
672+
x = tf.reshape(x, [-1])
658673
if categorical:
659674
x_dtype = x.dtype
660675
x = x if x_dtype == tf.string else tf.strings.as_string(x)
@@ -687,13 +702,13 @@ def histogram(x: common_types.TensorType,
687702

688703

689704
@common.log_api_use(common.ANALYZER_COLLECTION)
690-
def size(x: common_types.TensorType,
705+
def size(x: common_types.InputTensorType,
691706
reduce_instance_dims: bool = True,
692707
name: Optional[str] = None) -> tf.Tensor:
693708
"""Computes the total size of instances in a `Tensor` over the whole dataset.
694709
695710
Args:
696-
x: A `Tensor` or `SparseTensor`.
711+
x: A `Tensor` or `CompositeTensor`.
697712
reduce_instance_dims: By default collapses the batch and instance dimensions
698713
to arrive at a single scalar output. If False, only collapses the batch
699714
dimension and outputs a vector of the same shape as the input.
@@ -715,13 +730,13 @@ def size(x: common_types.TensorType,
715730

716731

717732
@common.log_api_use(common.ANALYZER_COLLECTION)
718-
def count_per_key(key: common_types.TensorType,
733+
def count_per_key(key: common_types.InputTensorType,
719734
key_vocabulary_filename: Optional[str] = None,
720735
name: Optional[str] = None):
721736
"""Computes the count of each element of a `Tensor`.
722737
723738
Args:
724-
key: A Tensor or `SparseTensor` of dtype tf.string or tf.int.
739+
key: A Tensor or `CompositeTensor` of dtype tf.string or tf.int.
725740
key_vocabulary_filename: (Optional) The file name for the key-output mapping
726741
file. If None and key are provided, this combiner assumes the keys fit in
727742
memory and will not store the result in a file. If empty string, a file
@@ -764,14 +779,14 @@ def count_per_key(key: common_types.TensorType,
764779

765780

766781
@common.log_api_use(common.ANALYZER_COLLECTION)
767-
def mean(x: common_types.TensorType,
782+
def mean(x: common_types.InputTensorType,
768783
reduce_instance_dims: bool = True,
769784
name: Optional[str] = None,
770785
output_dtype: Optional[tf.DType] = None) -> tf.Tensor:
771786
"""Computes the mean of the values of a `Tensor` over the whole dataset.
772787
773788
Args:
774-
x: A `Tensor` or `SparseTensor`. Its type must be floating point
789+
x: A `Tensor` or `CompositeTensor`. Its type must be floating point
775790
(float{16|32|64}), or integral ([u]int{8|16|32|64}).
776791
reduce_instance_dims: By default collapses the batch and instance dimensions
777792
to arrive at a single scalar output. If False, only collapses the batch
@@ -792,7 +807,7 @@ def mean(x: common_types.TensorType,
792807

793808

794809
@common.log_api_use(common.ANALYZER_COLLECTION)
795-
def var(x: common_types.TensorType,
810+
def var(x: common_types.InputTensorType,
796811
reduce_instance_dims: bool = True,
797812
name: Optional[str] = None,
798813
output_dtype: Optional[tf.DType] = None) -> tf.Tensor:
@@ -802,7 +817,7 @@ def var(x: common_types.TensorType,
802817
(x - mean(x))**2 / length(x).
803818
804819
Args:
805-
x: `Tensor` or `SparseTensor`. Its type must be floating point
820+
x: `Tensor` or `CompositeTensor`. Its type must be floating point
806821
(float{16|32|64}), or integral ([u]int{8|16|32|64}).
807822
reduce_instance_dims: By default collapses the batch and instance dimensions
808823
to arrive at a single scalar output. If False, only collapses the batch
@@ -822,12 +837,17 @@ def var(x: common_types.TensorType,
822837
return _mean_and_var(x, reduce_instance_dims, output_dtype)[1]
823838

824839

825-
def _mean_and_var(x, reduce_instance_dims=True, output_dtype=None):
840+
def _mean_and_var(x: common_types.InputTensorType,
841+
reduce_instance_dims: bool = True,
842+
output_dtype: Optional[tf.DType] = None):
826843
"""More efficient combined `mean` and `var`. See `var`."""
827844
if output_dtype is None:
828845
output_dtype = _FLOAT_OUTPUT_DTYPE_MAP.get(x.dtype)
829846
if output_dtype is None:
830847
raise TypeError('Tensor type %r is not supported' % x.dtype)
848+
if not reduce_instance_dims and isinstance(x, tf.RaggedTensor):
849+
raise NotImplementedError(
850+
'Elementwise mean_and_var does not support RaggedTensors.')
831851

832852
with tf.compat.v1.name_scope('mean_and_var'):
833853

@@ -1007,8 +1027,8 @@ def _tukey_parameters(
10071027

10081028

10091029
def _mean_and_var_per_key(
1010-
x: common_types.TensorType,
1011-
key: common_types.TensorType,
1030+
x: common_types.InputTensorType,
1031+
key: common_types.InputTensorType,
10121032
reduce_instance_dims: bool = True,
10131033
output_dtype: Optional[tf.DType] = None,
10141034
key_vocabulary_filename: Optional[str] = None
@@ -1017,9 +1037,9 @@ def _mean_and_var_per_key(
10171037
"""`mean_and_var` by group, specified by key.
10181038
10191039
Args:
1020-
x: A `Tensor` or `SparseTensor`.
1021-
key: A Tensor or `SparseTensor` of dtype tf.string. If `x` is
1022-
a `SparseTensor`, `key` must exactly match `x` in everything except
1040+
x: A `Tensor` or `CompositeTensor`.
1041+
key: A Tensor or `CompositeTensor` of dtype tf.string. If `x` is
1042+
a `CompositeTensor`, `key` must exactly match `x` in everything except
10231043
values.
10241044
reduce_instance_dims: (Optional) By default collapses the batch and instance
10251045
dimensions to arrive at a single scalar output. The False case is not

0 commit comments

Comments
 (0)