Skip to content

Commit f3f7939

Browse files
iindyktfx-copybara
authored and committed
Switching from deprecated tf.raw_ops.BoostedTreesBucketize to tf.searchsorted in tft.apply_buckets. This fixes a bug where large int64 values were incorrectly mapped, and allows simplifying the tft.apply_buckets implementation.
PiperOrigin-RevId: 431295261
1 parent 6f08265 commit f3f7939

File tree

6 files changed

+152
-202
lines changed

6 files changed

+152
-202
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@
2222
in deep copy optimization. The reason of adding these tags is to prevent
2323
root Reads that are generated from deep copy being merged due to common
2424
subexpression elimination.
25+
* Fixed an issue when large int64 values would be incorrectly bucketized in
26+
`tft.apply_buckets`.
2527
* Depends on `apache-beam[gcp]>=2.36,<3`.
2628
* Depends on
2729
`tensorflow>=1.15.5,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,<3`.

tensorflow_transform/analyzers.py

Lines changed: 9 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -644,11 +644,7 @@ def sum( # pylint: disable=redefined-builtin
644644
"""
645645
with tf.compat.v1.name_scope(name, 'sum'):
646646
if reduce_instance_dims:
647-
if isinstance(x, tf.SparseTensor):
648-
x = x.values
649-
elif isinstance(x, tf.RaggedTensor):
650-
x = x.flat_values
651-
x = tf.reduce_sum(input_tensor=x)
647+
x = tf.reduce_sum(input_tensor=tf_utils.get_values(x))
652648
elif isinstance(x, tf.SparseTensor):
653649
if x.dtype == tf.uint8 or x.dtype == tf.uint16:
654650
x = tf.cast(x, tf.int64)
@@ -669,6 +665,11 @@ def sum( # pylint: disable=redefined-builtin
669665
output_dtypes=[output_dtype])[0]
670666

671667

668+
def remove_leftmost_boundary(boundaries: tf.Tensor) -> tf.Tensor:
669+
"""Removes the leftmost boundary from [1, None]-shaped `Tensor` of buckets."""
670+
return boundaries[:, 1:]
671+
672+
672673
@common.log_api_use(common.ANALYZER_COLLECTION)
673674
def histogram(x: common_types.TensorType,
674675
boundaries: Optional[Union[tf.Tensor, int]] = None,
@@ -704,13 +705,7 @@ def histogram(x: common_types.TensorType,
704705
"""
705706

706707
with tf.compat.v1.name_scope(name, 'histogram'):
707-
# We need to flatten because BoostedTreesBucketize expects a rank-1 input
708-
if isinstance(x, tf.SparseTensor):
709-
x = x.values
710-
elif isinstance(x, tf.RaggedTensor):
711-
x = x.flat_values
712-
else:
713-
x = tf.reshape(x, [-1])
708+
x = tf.reshape(tf_utils.get_values(x), [-1])
714709
if categorical:
715710
x_dtype = x.dtype
716711
x = x if x_dtype == tf.string else tf.strings.as_string(x)
@@ -732,10 +727,8 @@ def histogram(x: common_types.TensorType,
732727
# and due to the fact that the rightmost boundary is essentially ignored.
733728
boundaries = tf.expand_dims(tf.cast(boundaries, tf.float32), 0) - 0.0001
734729

735-
bucket_indices = tf_utils.apply_bucketize_op(tf.cast(x, tf.float32),
736-
boundaries,
737-
remove_leftmost_boundary=True)
738-
730+
bucket_indices = tf_utils.assign_buckets(
731+
tf.cast(x, tf.float32), remove_leftmost_boundary(boundaries))
739732
bucket_vocab, counts = count_per_key(tf.strings.as_string(bucket_indices))
740733
counts = tf_utils.reorder_histogram(bucket_vocab, counts,
741734
tf.size(boundaries) - 1)

tensorflow_transform/mappers.py

Lines changed: 22 additions & 123 deletions
Original file line number | Diff line number | Diff line change
@@ -138,13 +138,12 @@ def _scale_to_gaussian_internal(
138138
x, reduce_instance_dims=not elementwise, output_dtype=output_dtype)
139139

140140
compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
141-
x_values = x
141+
x_values = tf_utils.get_values(x)
142142

143143
x_var = analyzers.var(x, reduce_instance_dims=not elementwise,
144144
output_dtype=output_dtype)
145145

146146
if isinstance(x, tf.SparseTensor):
147-
x_values = x.values
148147
if elementwise:
149148
x_loc = tf.gather_nd(x_loc, x.indices[:, 1:])
150149
x_scale = tf.gather_nd(x_scale, x.indices[:, 1:])
@@ -155,7 +154,6 @@ def _scale_to_gaussian_internal(
155154
if elementwise:
156155
raise NotImplementedError(
157156
'Elementwise scale_to_gaussian does not support RaggedTensors.')
158-
x_values = x.flat_values
159157

160158
numerator = tf.cast(x_values, x_loc.dtype) - x_loc
161159
is_long_tailed = tf.math.logical_or(hl > 0.0, hr > 0.0)
@@ -390,19 +388,17 @@ def _scale_by_min_max_internal(
390388
-minus_min_max_for_key[:, 0], minus_min_max_for_key[:, 1])
391389

392390
compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
393-
x_values = x
391+
x_values = tf_utils.get_values(x)
394392
if isinstance(x, tf.SparseTensor):
395393
if elementwise:
396394
min_x_value = tf.gather_nd(
397395
tf.broadcast_to(min_x_value, x.dense_shape), x.indices)
398396
max_x_value = tf.gather_nd(
399397
tf.broadcast_to(max_x_value, x.dense_shape), x.indices)
400-
x_values = x.values
401398
elif isinstance(x, tf.RaggedTensor):
402399
if elementwise:
403400
raise NotImplementedError(
404401
'Elementwise min_and_max does not support RaggedTensors.')
405-
x_values = x.flat_values
406402

407403
# If min>=max, then the corresponding input to the min_and_max analyzer either
408404
# was empty and the analyzer returned default values, or contained only one
@@ -640,18 +636,16 @@ def _scale_to_z_score_internal(
640636
x_mean, x_var = (mean_var_for_key[:, 0], mean_var_for_key[:, 1])
641637

642638
compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
643-
x_values = x
639+
x_values = tf_utils.get_values(x)
644640

645641
if isinstance(x, tf.SparseTensor):
646-
x_values = x.values
647642
if elementwise:
648643
x_mean = tf.gather_nd(tf.broadcast_to(x_mean, x.dense_shape), x.indices)
649644
x_var = tf.gather_nd(tf.broadcast_to(x_var, x.dense_shape), x.indices)
650645
elif isinstance(x, tf.RaggedTensor):
651646
if elementwise:
652647
raise NotImplementedError(
653648
'Elementwise scale_to_z_score does not support RaggedTensors')
654-
x_values = x.flat_values
655649

656650
numerator = tf.cast(x_values, x_mean.dtype) - x_mean
657651
denominator = tf.sqrt(x_var)
@@ -1142,7 +1136,7 @@ def _construct_table(asset_filepath):
11421136
return table
11431137

11441138
compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
1145-
x_values = _get_values_if_composite(x)
1139+
x_values = tf_utils.get_values(x)
11461140
result, table_size = tf_utils.construct_and_lookup_table(
11471141
_construct_table, deferred_vocab_filename_tensor, x_values)
11481142
result = compose_result_fn(result)
@@ -1159,7 +1153,7 @@ def _construct_table(asset_filepath):
11591153
min_value = tf.minimum(min_value, default_value)
11601154
max_value = tf.maximum(max_value, default_value)
11611155
schema_inference.set_tensor_schema_override(
1162-
_get_values_if_composite(result), min_value, max_value)
1156+
tf_utils.get_values(result), min_value, max_value)
11631157
return result
11641158

11651159

@@ -1682,7 +1676,7 @@ def hash_strings(
16821676
strings, hash_buckets, key, name=name)
16831677
else:
16841678
compose_result_fn = _make_composite_tensor_wrapper_if_composite(strings)
1685-
values = _get_values_if_composite(strings)
1679+
values = tf_utils.get_values(strings)
16861680
return compose_result_fn(hash_strings(values, hash_buckets, key))
16871681

16881682

@@ -1747,7 +1741,7 @@ def bucketize(x: common_types.ConsistentTensorType,
17471741
# See explanation in args documentation for epsilon.
17481742
epsilon = min(1.0 / num_buckets, 0.01)
17491743

1750-
x_values = _get_values_if_composite(x)
1744+
x_values = tf_utils.get_values(x)
17511745
bucket_boundaries = analyzers.quantiles(
17521746
x_values,
17531747
num_buckets,
@@ -1821,11 +1815,11 @@ def bucketize_per_key(
18211815
(key_vocab, bucket_boundaries, scale_factor_per_key, shift_per_key,
18221816
actual_num_buckets) = (
18231817
analyzers._quantiles_per_key( # pylint: disable=protected-access
1824-
_get_values_if_composite(x),
1825-
_get_values_if_composite(key),
1818+
tf_utils.get_values(x),
1819+
tf_utils.get_values(key),
18261820
num_buckets,
18271821
epsilon,
1828-
weights=_get_values_if_composite(weights)))
1822+
weights=tf_utils.get_values(weights)))
18291823
return _apply_buckets_with_keys(x, key, key_vocab, bucket_boundaries,
18301824
scale_factor_per_key, shift_per_key,
18311825
actual_num_buckets)
@@ -1847,15 +1841,6 @@ def from_nested_row_splits(values):
18471841
return lambda values: values
18481842

18491843

1850-
def _get_values_if_composite(x: common_types.TensorType) -> tf.Tensor:
1851-
if isinstance(x, tf.SparseTensor):
1852-
return x.values
1853-
elif isinstance(x, tf.RaggedTensor):
1854-
return x.flat_values
1855-
else:
1856-
return x
1857-
1858-
18591844
def _fill_shape(value, shape, dtype):
18601845
return tf.cast(tf.fill(shape, value), dtype)
18611846

@@ -1887,9 +1872,9 @@ def _apply_buckets_with_keys(
18871872
`key` is not present in `key_vocab` then the resulting bucket will be -1.
18881873
"""
18891874
with tf.compat.v1.name_scope(name, 'apply_buckets_with_keys'):
1890-
x_values = tf.cast(_get_values_if_composite(x), tf.float32)
1875+
x_values = tf.cast(tf_utils.get_values(x), tf.float32)
18911876
compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
1892-
key_values = _get_values_if_composite(key)
1877+
key_values = tf_utils.get_values(key)
18931878

18941879
# Convert `key_values` to indices in key_vocab.
18951880
key_indices = tf_utils.lookup_key(key_values, key_vocab)
@@ -1906,8 +1891,8 @@ def _apply_buckets_with_keys(
19061891

19071892
transformed_x = x_values * scale_factors + shifts
19081893

1909-
offset_buckets = _assign_buckets_all_shapes(
1910-
transformed_x, tf.expand_dims(bucket_boundaries, 0))
1894+
offset_buckets = tf_utils.assign_buckets(
1895+
transformed_x, bucket_boundaries, side=tf_utils.Side.RIGHT)
19111896

19121897
max_bucket = num_buckets - 1
19131898

@@ -1972,7 +1957,7 @@ def apply_buckets_with_interpolation(
19721957
with tf.compat.v1.name_scope(name, 'buckets_with_interpolation'):
19731958
bucket_boundaries = tf.convert_to_tensor(bucket_boundaries)
19741959
tf.compat.v1.assert_rank(bucket_boundaries, 2)
1975-
x_values = _get_values_if_composite(x)
1960+
x_values = tf_utils.get_values(x)
19761961
compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
19771962
if not (x_values.dtype.is_floating or x_values.dtype.is_integer):
19781963
raise ValueError(
@@ -1992,7 +1977,8 @@ def apply_buckets_with_interpolation(
19921977
tf.constant(0, tf.int64),
19931978
name='assert_1_or_more_finite_boundaries')
19941979
with tf.control_dependencies([assert_some_finite_boundaries]):
1995-
bucket_indices = _assign_buckets_all_shapes(x_values, bucket_boundaries)
1980+
bucket_indices = tf_utils.assign_buckets(
1981+
x_values, bucket_boundaries, side=tf_utils.Side.RIGHT)
19961982
# Get max, min, and width of the corresponding bucket for each element.
19971983
bucket_max = tf.cast(
19981984
tf.gather(
@@ -2086,7 +2072,8 @@ def apply_buckets(
20862072
bucket_boundaries = tf.convert_to_tensor(bucket_boundaries)
20872073
tf.compat.v1.assert_rank(bucket_boundaries, 2)
20882074

2089-
bucketized_values = _assign_buckets_all_shapes(x, bucket_boundaries)
2075+
bucketized_values = tf_utils.assign_buckets(
2076+
tf_utils.get_values(x), bucket_boundaries, side=tf_utils.Side.RIGHT)
20902077

20912078
# Attach the relevant metadata to result, so that the corresponding
20922079
# output feature will have this metadata set.
@@ -2099,94 +2086,6 @@ def apply_buckets(
20992086
return compose_result_fn(bucketized_values)
21002087

21012088

2102-
def _assign_buckets_all_shapes(x: common_types.TensorType,
2103-
bucket_boundaries: tf.Tensor) -> tf.Tensor:
2104-
"""Assigns every value in x to a bucket index defined by bucket_boundaries.
2105-
2106-
Depending on the shape of the x input, we split into individual vectors
2107-
so that the actual _assign_buckets function can operate as expected.
2108-
2109-
Args:
2110-
x: a `Tensor` or `CompositeTensor` with no more than 2 dimensions.
2111-
bucket_boundaries: The bucket boundaries represented as a rank 2 `Tensor`.
2112-
2113-
Returns:
2114-
A `Tensor` of the same shape as `x`, with each element in the
2115-
returned tensor representing the bucketized value. Bucketized value is
2116-
in the range [0, len(bucket_boundaries)].
2117-
"""
2118-
with tf.compat.v1.name_scope(None, 'assign_buckets_all_shapes'):
2119-
bucket_boundaries = tf.cast(bucket_boundaries, tf.float32)
2120-
x = tf.cast(_get_values_if_composite(x), tf.float32)
2121-
2122-
# We expect boundaries in final dimension but have to satisfy other shapes.
2123-
if bucket_boundaries.shape[0] != 1:
2124-
bucket_boundaries = tf.transpose(bucket_boundaries)
2125-
2126-
if x.get_shape().ndims == 1:
2127-
buckets = _assign_buckets(x, bucket_boundaries)
2128-
elif x.get_shape().ndims is None or x.get_shape().ndims == 0:
2129-
buckets = tf.squeeze(_assign_buckets(
2130-
tf.expand_dims(x, axis=0), bucket_boundaries))
2131-
elif x.get_shape().ndims == 2:
2132-
# For x with 2 dimensions, assign buckets to each column separately.
2133-
buckets = [_assign_buckets(x_column, bucket_boundaries)
2134-
for x_column in tf.unstack(x, axis=1)]
2135-
# Ex: x = [[1,2], [3,4]], boundaries = [[2.5]]
2136-
# results in [[0,1], [0,1]] transposed to [[0,0], [1,1]].
2137-
buckets = tf.transpose(buckets)
2138-
else:
2139-
raise ValueError('Assign buckets requires at most 2 dimensions')
2140-
return buckets
2141-
2142-
2143-
# TODO(b/148278398): Determine how NaN values should be assigned to buckets.
2144-
# Currently it maps to the highest bucket.
2145-
def _assign_buckets(x_values: tf.Tensor,
2146-
bucket_boundaries: tf.Tensor) -> tf.Tensor:
2147-
"""Assigns every value in x to a bucket index defined by bucket_boundaries.
2148-
2149-
Args:
2150-
x_values: a `Tensor` of dtype float32 with no more than one dimension.
2151-
bucket_boundaries: The bucket boundaries represented as a rank 2 `Tensor`.
2152-
Should be sorted.
2153-
2154-
Returns:
2155-
A `Tensor` of the same shape as `x_values`, with each element in the
2156-
returned tensor representing the bucketized value. Bucketized value is
2157-
in the range [0, len(bucket_boundaries)].
2158-
"""
2159-
with tf.compat.v1.name_scope(None, 'assign_buckets'):
2160-
max_value = tf.cast(tf.shape(input=bucket_boundaries)[1], dtype=tf.int64)
2161-
2162-
# We need to reverse the negated boundaries and x_values and add a final
2163-
# max boundary to work with the new bucketize op.
2164-
bucket_boundaries = tf.reverse(-bucket_boundaries, [-1])
2165-
bucket_boundaries = tf.concat([
2166-
bucket_boundaries, [[tf.reduce_max(-x_values)]]], axis=-1)
2167-
2168-
if x_values.get_shape().ndims > 1:
2169-
x_values = tf.squeeze(x_values)
2170-
2171-
# BoostedTreesBucketize assigns to lower bound instead of upper bound, so
2172-
# we need to reverse both boundaries and x_values and make them negative
2173-
# to make cases exactly at the boundary consistent.
2174-
buckets = tf_utils.apply_bucketize_op(-x_values, bucket_boundaries)
2175-
# After reversing the inputs, the assigned buckets are exactly reversed
2176-
# and need to be re-reversed to their original index.
2177-
buckets = tf.subtract(max_value, buckets)
2178-
2179-
if buckets.shape.ndims <= 1:
2180-
# As a result of the above squeeze, there might be too few bucket dims
2181-
# and we want the output shape to match the input.
2182-
if not buckets.shape.ndims:
2183-
buckets = tf.expand_dims(buckets, -1)
2184-
elif x_values.shape.ndims is not None and x_values.shape.ndims > 1:
2185-
buckets = tf.expand_dims(buckets, -1)
2186-
2187-
return buckets
2188-
2189-
21902089
def _annotate_buckets(x: tf.Tensor, bucket_boundaries: tf.Tensor) -> None:
21912090
"""Annotates a bucketized tensor with the boundaries that were applied.
21922091
@@ -2299,9 +2198,9 @@ def estimated_probability_density(x: tf.Tensor,
22992198
tf.cast(tf.size(probabilities), tf.float32))
23002199
bucket_densities = probabilities / bin_width
23012200

2302-
bucket_indices = tf_utils.apply_bucketize_op(
2303-
tf.cast(x, tf.float32), boundaries, True)
2304-
2201+
bucket_indices = tf_utils.assign_buckets(
2202+
tf.cast(x, tf.float32),
2203+
analyzers.remove_leftmost_boundary(boundaries))
23052204
bucket_indices = tf_utils._align_dims(bucket_indices, xdims) # pylint: disable=protected-access
23062205

23072206
# In the categorical case, when keys are missing, the indices may be -1,

0 commit comments

Comments (0)