Skip to content

Commit 71dc7d8

Browse files
caveness, tf-data-validation-team
caveness
authored and
tf-data-validation-team
committed
Update get_weight_feature to be used with Arrow table inputs and use same in arrow_util.enumerate_arrays.
PiperOrigin-RevId: 272281600
1 parent 702c12c commit 71dc7d8

File tree

4 files changed

+65
-71
lines changed

4 files changed

+65
-71
lines changed

tensorflow_data_validation/arrow/arrow_util.py

Lines changed: 40 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,45 @@
3535
pywrap.TFDV_Arrow_MakeListArrayFromParentIndicesAndValues)
3636

3737

38+
def _get_weight_feature(input_table: pa.Table,
39+
weight_feature: Text) -> np.ndarray:
40+
"""Gets the weight column from the input table.
41+
42+
Args:
43+
input_table: Input table.
44+
weight_feature: Name of the weight feature.
45+
46+
Returns:
47+
A numpy array containing the weights of the examples in the input table.
48+
49+
Raises:
50+
ValueError: If the weight feature is not present in the input table or is
51+
not a valid weight feature (must be of numeric type and have a
52+
single value for each example).
53+
"""
54+
try:
55+
weights = input_table.column(weight_feature).data.chunk(0)
56+
except KeyError:
57+
raise ValueError('Weight feature "{}" not present in the input '
58+
'table.'.format(weight_feature))
59+
60+
# Before flattening, check that there is a single value for each example.
61+
weight_lengths = ListLengthsFromListArray(weights).to_numpy()
62+
if not np.all(weight_lengths == 1):
63+
raise ValueError(
64+
'Weight feature "{}" must have exactly one value in each example.'
65+
.format(weight_feature))
66+
weights = weights.flatten()
67+
# Before converting to numpy view, check the type (cannot convert string and
68+
# binary arrays to numpy view).
69+
weights_type = weights.type
70+
if pa.types.is_string(weights_type) or pa.types.is_binary(weights_type):
71+
raise ValueError(
72+
'Weight feature "{}" must be of numeric type. Found {}.'.format(
73+
weight_feature, weights_type))
74+
return weights.to_numpy()
75+
76+
3877
def primitive_array_to_numpy(primitive_array: pa.Array) -> np.ndarray:
3978
"""Converts a primitive Arrow array to a numpy 1-D ndarray.
4079
@@ -122,10 +161,7 @@ def _recursion_helper(
122161

123162
weights = None
124163
if weight_column is not None:
125-
weights = table.column(weight_column).data.chunk(0).flatten().to_numpy()
126-
if weights.size != table.num_rows:
127-
raise ValueError(
128-
'The weight feature must have exactly one value in each example')
164+
weights = _get_weight_feature(table, weight_column)
129165
for column in table.columns:
130166
column_name = column.name
131167
# use "yield from" after PY 3.3.

tensorflow_data_validation/arrow/arrow_util_test.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -236,17 +236,39 @@ def testMakeListArray(self, num_parents, parent_indices, values, expected):
236236

237237
class EnumerateArraysTest(absltest.TestCase):
238238

239-
def testInvalidWeightColumn(self):
239+
def testInvalidWeightColumnMissingValue(self):
240240
with self.assertRaisesRegex(
241241
ValueError,
242-
"weight feature must have exactly one value in each example"):
242+
'Weight feature "w" must have exactly one value.*'):
243243
for _ in arrow_util.enumerate_arrays(
244244
pa.Table.from_arrays([pa.array([[1], [2, 3]]),
245245
pa.array([[1], []])], ["v", "w"]),
246246
weight_column="w",
247247
enumerate_leaves_only=False):
248248
pass
249249

250+
def testInvalidWeightColumnTooManyValues(self):
251+
with self.assertRaisesRegex(
252+
ValueError,
253+
'Weight feature "w" must have exactly one value.*'):
254+
for _ in arrow_util.enumerate_arrays(
255+
pa.Table.from_arrays([pa.array([[1], [2, 3]]),
256+
pa.array([[1], [2, 2]])], ["v", "w"]),
257+
weight_column="w",
258+
enumerate_leaves_only=False):
259+
pass
260+
261+
def testInvalidWeightColumnStringValues(self):
262+
with self.assertRaisesRegex(
263+
ValueError,
264+
'Weight feature "w" must be of numeric type.*'):
265+
for _ in arrow_util.enumerate_arrays(
266+
pa.Table.from_arrays([pa.array([[1], [2, 3]]),
267+
pa.array([["two"], ["two"]])], ["v", "w"]),
268+
weight_column="w",
269+
enumerate_leaves_only=False):
270+
pass
271+
250272
def testEnumerate(self):
251273
input_table = pa.Table.from_arrays([
252274
pa.array([[1], [2, 3]]),

tensorflow_data_validation/utils/stats_util.py

Lines changed: 1 addition & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
from tensorflow_data_validation import types
2323
from tensorflow_data_validation.pyarrow_tf import pyarrow as pa
24-
from typing import Dict, List, Optional, Text, Union
24+
from typing import Dict, Optional, Text, Union
2525
from google.protobuf import text_format
2626
# TODO(b/125849585): Update to import from TF directly.
2727
from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import
@@ -175,43 +175,6 @@ def _make_feature_stats_proto(
175175
return result
176176

177177

178-
def get_weight_feature(input_batch: types.ExampleBatch,
179-
weight_feature: types.FeatureName) -> List[np.ndarray]:
180-
"""Gets the weight feature from the input batch.
181-
182-
Args:
183-
input_batch: Input batch of examples.
184-
weight_feature: Name of the weight feature.
185-
186-
Returns:
187-
A list containing the weights of the examples in the input batch.
188-
189-
Raises:
190-
ValueError: If the weight feature is not present in the input batch or is
191-
not a valid weight feature (must be of numeric type and have a
192-
single value).
193-
"""
194-
try:
195-
weights = input_batch[weight_feature]
196-
except KeyError:
197-
raise ValueError('Weight feature "{}" not present in the input '
198-
'batch.'.format(weight_feature))
199-
200-
# Check if we have a valid weight feature.
201-
for w in weights:
202-
if w is None:
203-
raise ValueError('Weight feature "{}" missing in an '
204-
'example.'.format(weight_feature))
205-
elif (get_feature_type(w.dtype) ==
206-
statistics_pb2.FeatureNameStatistics.STRING):
207-
raise ValueError('Weight feature "{}" must be of numeric type. '
208-
'Found {}.'.format(weight_feature, w))
209-
elif w.size != 1:
210-
raise ValueError('Weight feature "{}" must have a single value. '
211-
'Found {}.'.format(weight_feature, w))
212-
return weights # pytype: disable=bad-return-type
213-
214-
215178
def write_stats_text(stats: statistics_pb2.DatasetFeatureStatisticsList,
216179
output_path: bytes) -> None:
217180
"""Writes a DatasetFeatureStatisticsList proto to a file in text format.

tensorflow_data_validation/utils/stats_util_test.py

Lines changed: 0 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -122,33 +122,6 @@ def test_make_dataset_feature_stats_proto(self):
122122
expected[types.FeaturePath.from_proto(actual_feature_stats.path)],
123123
normalize_numbers=True)
124124

125-
def test_get_weight_feature_with_valid_weight_feature(self):
126-
batch = {'a': np.array([np.array([1, 2]), np.array([3])]),
127-
'w': np.array([np.array([10]), np.array([20])])}
128-
actual = stats_util.get_weight_feature(batch, 'w')
129-
np.testing.assert_equal(actual, batch['w'])
130-
131-
def test_get_weight_feature_invalid_weight_feature(self):
132-
batch = {'a': np.array([np.array([1])])}
133-
with self.assertRaisesRegexp(ValueError, 'Weight feature.*not present'):
134-
stats_util.get_weight_feature(batch, 'w')
135-
136-
def test_get_weight_feature_with_weight_feature_missing(self):
137-
batch = {'a': np.array([np.array([1])]), 'w': np.array([None])}
138-
with self.assertRaisesRegexp(ValueError, 'Weight feature.*missing'):
139-
stats_util.get_weight_feature(batch, 'w')
140-
141-
def test_get_weight_feature_with_weight_feature_string_type(self):
142-
batch = {'a': np.array([np.array([1])]), 'w': np.array([np.array(['a'])])}
143-
with self.assertRaisesRegexp(ValueError, 'Weight feature.*numeric type'):
144-
stats_util.get_weight_feature(batch, 'w')
145-
146-
def test_get_weight_feature_with_weight_feature_multiple_values(self):
147-
batch = {'a': np.array([np.array([1])]),
148-
'w': np.array([np.array([2, 3])])}
149-
with self.assertRaisesRegexp(ValueError, 'Weight feature.*single value'):
150-
stats_util.get_weight_feature(batch, 'w')
151-
152125
def test_get_utf8(self):
153126
self.assertEqual(u'This is valid.',
154127
stats_util.maybe_get_utf8(b'This is valid.'))

0 commit comments

Comments
 (0)