Skip to content

Commit bf40237

Browse files
paulgc17 and tf-data-validation-team authored and tf-data-validation-team committed
Use arrow utilities from tfx_bsl
PiperOrigin-RevId: 273571365
1 parent b9f060a commit bf40237

21 files changed

41 additions and 1,500 deletions (lines changed)

tensorflow_data_validation/arrow/arrow_util.py

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -21,19 +21,9 @@
2121
import numpy as np
2222
from tensorflow_data_validation import types
2323
from tensorflow_data_validation.pyarrow_tf import pyarrow as pa
24-
from tensorflow_data_validation.pywrap import pywrap_tensorflow_data_validation as pywrap
24+
from tfx_bsl.arrow import array_util
2525
from typing import Iterable, Optional, Text, Tuple
2626

27-
# The following are function aliases thus valid function names.
28-
# pylint: disable=invalid-name
29-
ListLengthsFromListArray = pywrap.TFDV_Arrow_ListLengthsFromListArray
30-
GetFlattenedArrayParentIndices = pywrap.TFDV_Arrow_GetFlattenedArrayParentIndices
31-
GetArrayNullBitmapAsByteArray = pywrap.TFDV_Arrow_GetArrayNullBitmapAsByteArray
32-
GetBinaryArrayTotalByteSize = pywrap.TFDV_Arrow_GetBinaryArrayTotalByteSize
33-
ValueCounts = pywrap.TFDV_Arrow_ValueCounts
34-
MakeListArrayFromParentIndicesAndValues = (
35-
pywrap.TFDV_Arrow_MakeListArrayFromParentIndicesAndValues)
36-
3727

3828
def _get_weight_feature(input_table: pa.Table,
3929
weight_feature: Text) -> np.ndarray:
@@ -58,7 +48,7 @@ def _get_weight_feature(input_table: pa.Table,
5848
'table.'.format(weight_feature))
5949

6050
# Before flattening, check that there is a single value for each example.
61-
weight_lengths = ListLengthsFromListArray(weights).to_numpy()
51+
weight_lengths = array_util.ListLengthsFromListArray(weights).to_numpy()
6252
if not np.all(weight_lengths == 1):
6353
raise ValueError(
6454
'Weight feature "{}" must have exactly one value in each example.'
@@ -148,7 +138,8 @@ def _recursion_helper(
148138
flat_struct_array = array.flatten()
149139
flat_weights = None
150140
if weights is not None:
151-
flat_weights = weights[GetFlattenedArrayParentIndices(array).to_numpy()]
141+
flat_weights = weights[
142+
array_util.GetFlattenedArrayParentIndices(array).to_numpy()]
152143
for field in flat_struct_array.type:
153144
field_name = field.name
154145
# use "yield from" after PY 3.3.

tensorflow_data_validation/arrow/arrow_util_test.py

Lines changed: 0 additions & 207 deletions
Original file line number | Diff line number | Diff line change
@@ -20,220 +20,13 @@
2020
import itertools
2121

2222
from absl.testing import absltest
23-
from absl.testing import parameterized
2423
import numpy as np
2524
import six
2625
from tensorflow_data_validation import types
2726
from tensorflow_data_validation.arrow import arrow_util
2827
from tensorflow_data_validation.pyarrow_tf import pyarrow as pa
2928

3029

31-
class ArrowUtilTest(absltest.TestCase):
32-
33-
def test_invalid_input_type(self):
34-
35-
functions_expecting_list_array = [
36-
arrow_util.ListLengthsFromListArray,
37-
arrow_util.GetFlattenedArrayParentIndices,
38-
]
39-
functions_expecting_array = [arrow_util.GetArrayNullBitmapAsByteArray]
40-
functions_expecting_binary_array = [arrow_util.GetBinaryArrayTotalByteSize]
41-
for f in itertools.chain(functions_expecting_list_array,
42-
functions_expecting_array,
43-
functions_expecting_binary_array):
44-
with self.assertRaisesRegex(RuntimeError, "Could not unwrap Array"):
45-
f(1)
46-
47-
for f in functions_expecting_list_array:
48-
with self.assertRaisesRegex(RuntimeError, "Expected ListArray but got"):
49-
f(pa.array([1, 2, 3]))
50-
51-
for f in functions_expecting_binary_array:
52-
with self.assertRaisesRegex(RuntimeError, "Expected BinaryArray"):
53-
f(pa.array([[1, 2, 3]]))
54-
55-
def test_list_lengths(self):
56-
list_lengths = arrow_util.ListLengthsFromListArray(
57-
pa.array([], type=pa.list_(pa.int64())))
58-
self.assertTrue(list_lengths.equals(pa.array([], type=pa.int32())))
59-
list_lengths = arrow_util.ListLengthsFromListArray(
60-
pa.array([[1., 2.], [], [3.]]))
61-
self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int32())))
62-
list_lengths = arrow_util.ListLengthsFromListArray(
63-
pa.array([[1., 2.], None, [3.]]))
64-
self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int32())))
65-
66-
def test_get_array_null_bitmap_as_byte_array(self):
67-
array = pa.array([], type=pa.int32())
68-
null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
69-
self.assertTrue(null_masks.equals(pa.array([], type=pa.uint8())))
70-
71-
array = pa.array([1, 2, None, 3, None], type=pa.int32())
72-
null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
73-
self.assertTrue(
74-
null_masks.equals(pa.array([0, 0, 1, 0, 1], type=pa.uint8())))
75-
76-
array = pa.array([1, 2, 3])
77-
null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
78-
self.assertTrue(null_masks.equals(pa.array([0, 0, 0], type=pa.uint8())))
79-
80-
array = pa.array([None, None, None], type=pa.int32())
81-
null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
82-
self.assertTrue(null_masks.equals(pa.array([1, 1, 1], type=pa.uint8())))
83-
# Demonstrate that the returned array can be converted to a numpy boolean
84-
# array w/o copying
85-
np.testing.assert_equal(
86-
np.array([True, True, True]), null_masks.to_numpy().view(np.bool))
87-
88-
def test_get_flattened_array_parent_indices(self):
89-
indices = arrow_util.GetFlattenedArrayParentIndices(
90-
pa.array([], type=pa.list_(pa.int32())))
91-
self.assertTrue(indices.equals(pa.array([], type=pa.int32())))
92-
93-
indices = arrow_util.GetFlattenedArrayParentIndices(
94-
pa.array([[1.], [2.], [], [3.]]))
95-
self.assertTrue(indices.equals(pa.array([0, 1, 3], type=pa.int32())))
96-
97-
def test_get_binary_array_total_byte_size(self):
98-
binary_array = pa.array([b"abc", None, b"def", b"", b"ghi"])
99-
self.assertEqual(9, arrow_util.GetBinaryArrayTotalByteSize(binary_array))
100-
sliced_1_2 = binary_array.slice(1, 2)
101-
self.assertEqual(3, arrow_util.GetBinaryArrayTotalByteSize(sliced_1_2))
102-
sliced_2 = binary_array.slice(2)
103-
self.assertEqual(6, arrow_util.GetBinaryArrayTotalByteSize(sliced_2))
104-
105-
unicode_array = pa.array([u"abc"])
106-
self.assertEqual(3, arrow_util.GetBinaryArrayTotalByteSize(unicode_array))
107-
108-
empty_array = pa.array([], type=pa.binary())
109-
self.assertEqual(0, arrow_util.GetBinaryArrayTotalByteSize(empty_array))
110-
111-
def _value_counts_struct_array_to_dict(self, value_counts):
112-
result = {}
113-
for value_count in value_counts:
114-
value_count = value_count.as_py()
115-
result[value_count["values"]] = value_count["counts"]
116-
return result
117-
118-
def test_value_counts_binary(self):
119-
binary_array = pa.array([b"abc", b"ghi", b"def", b"ghi", b"ghi", b"def"])
120-
expected_result = {b"abc": 1, b"ghi": 3, b"def": 2}
121-
self.assertDictEqual(self._value_counts_struct_array_to_dict(
122-
arrow_util.ValueCounts(binary_array)), expected_result)
123-
124-
def test_value_counts_integer(self):
125-
int_array = pa.array([1, 4, 1, 3, 1, 4])
126-
expected_result = {1: 3, 4: 2, 3: 1}
127-
self.assertDictEqual(self._value_counts_struct_array_to_dict(
128-
arrow_util.ValueCounts(int_array)), expected_result)
129-
130-
def test_value_counts_empty(self):
131-
empty_array = pa.array([])
132-
expected_result = {}
133-
self.assertDictEqual(self._value_counts_struct_array_to_dict(
134-
arrow_util.ValueCounts(empty_array)), expected_result)
135-
136-
_MAKE_LIST_ARRAY_INVALID_INPUT_TEST_CASES = [
137-
dict(
138-
testcase_name="invalid_parent_index",
139-
num_parents=None,
140-
parent_indices=np.array([0], dtype=np.int64),
141-
values=pa.array([1]),
142-
expected_error=RuntimeError,
143-
expected_error_regexp="Expected integer"),
144-
dict(
145-
testcase_name="parent_indices_not_np",
146-
num_parents=1,
147-
parent_indices=[0],
148-
values=pa.array([1]),
149-
expected_error=TypeError,
150-
expected_error_regexp="to be a numpy array"
151-
),
152-
dict(
153-
testcase_name="parent_indices_not_1d",
154-
num_parents=1,
155-
parent_indices=np.array([[0]], dtype=np.int64),
156-
values=pa.array([1]),
157-
expected_error=TypeError,
158-
expected_error_regexp="to be a 1-D int64 numpy array"
159-
),
160-
dict(
161-
testcase_name="parent_indices_not_int64",
162-
num_parents=1,
163-
parent_indices=np.array([0], dtype=np.int32),
164-
values=pa.array([1]),
165-
expected_error=TypeError,
166-
expected_error_regexp="to be a 1-D int64 numpy array"
167-
),
168-
dict(
169-
testcase_name="parent_indices_length_not_equal_to_values_length",
170-
num_parents=1,
171-
parent_indices=np.array([0], dtype=np.int64),
172-
values=pa.array([1, 2]),
173-
expected_error=RuntimeError,
174-
expected_error_regexp="values array and parent indices array must be of the same length"
175-
),
176-
dict(
177-
testcase_name="num_parents_too_small",
178-
num_parents=1,
179-
parent_indices=np.array([1], dtype=np.int64),
180-
values=pa.array([1]),
181-
expected_error=RuntimeError,
182-
expected_error_regexp="Found a parent index 1 while num_parents was 1"
183-
)
184-
]
185-
186-
187-
_MAKE_LIST_ARRAY_TEST_CASES = [
188-
dict(
189-
testcase_name="parents_are_all_empty",
190-
num_parents=5,
191-
parent_indices=np.array([], dtype=np.int64),
192-
values=pa.array([], type=pa.int64()),
193-
expected=pa.array([None, None, None, None, None],
194-
type=pa.list_(pa.int64()))),
195-
dict(
196-
testcase_name="long_num_parent",
197-
num_parents=(long(1) if six.PY2 else 1),
198-
parent_indices=np.array([0], dtype=np.int64),
199-
values=pa.array([1]),
200-
expected=pa.array([[1]])
201-
),
202-
dict(
203-
testcase_name="leading nones",
204-
num_parents=3,
205-
parent_indices=np.array([2], dtype=np.int64),
206-
values=pa.array([1]),
207-
expected=pa.array([None, None, [1]]),
208-
),
209-
dict(
210-
testcase_name="same_parent_and_holes",
211-
num_parents=4,
212-
parent_indices=np.array([0, 0, 0, 3, 3], dtype=np.int64),
213-
values=pa.array(["a", "b", "c", "d", "e"]),
214-
expected=pa.array([["a", "b", "c"], None, None, ["d", "e"]])
215-
)
216-
]
217-
218-
219-
class MakeListArrayFromParentIndicesAndValuesTest(parameterized.TestCase):
220-
221-
@parameterized.named_parameters(*_MAKE_LIST_ARRAY_INVALID_INPUT_TEST_CASES)
222-
def testInvalidInput(self, num_parents, parent_indices, values,
223-
expected_error, expected_error_regexp):
224-
with self.assertRaisesRegex(expected_error, expected_error_regexp):
225-
arrow_util.MakeListArrayFromParentIndicesAndValues(
226-
num_parents, parent_indices, values)
227-
228-
@parameterized.named_parameters(*_MAKE_LIST_ARRAY_TEST_CASES)
229-
def testMakeListArray(self, num_parents, parent_indices, values, expected):
230-
actual = arrow_util.MakeListArrayFromParentIndicesAndValues(
231-
num_parents, parent_indices, values)
232-
self.assertTrue(
233-
actual.equals(expected),
234-
"actual: {}, expected: {}".format(actual, expected))
235-
236-
23730
class EnumerateArraysTest(absltest.TestCase):
23831

23932
def testInvalidWeightColumnMissingValue(self):

tensorflow_data_validation/arrow/cc/BUILD

Lines changed: 0 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -16,20 +16,6 @@ cc_library(
1616
],
1717
)
1818

19-
cc_library(
20-
name = "arrow_util",
21-
srcs = ["arrow_util.cc"],
22-
hdrs = ["arrow_util.h"],
23-
deps = [
24-
":common",
25-
":init_numpy",
26-
"@arrow",
27-
"@com_google_absl//absl/strings",
28-
"@com_google_absl//absl/types:span",
29-
"@local_config_python//:python_headers",
30-
],
31-
)
32-
3319
cc_library(
3420
name = "decoded_examples_to_arrow",
3521
srcs = ["decoded_examples_to_arrow.cc"],
@@ -47,22 +33,6 @@ cc_library(
4733
],
4834
)
4935

50-
cc_library(
51-
name = "merge",
52-
srcs = ["merge.cc"],
53-
hdrs = ["merge.h"],
54-
deps = [
55-
":common",
56-
":init_numpy",
57-
"@arrow",
58-
"@com_google_absl//absl/container:flat_hash_map",
59-
"@com_google_absl//absl/strings",
60-
"@com_google_absl//absl/types:variant",
61-
"@local_config_python//:numpy_headers",
62-
"@local_config_python//:python_headers",
63-
],
64-
)
65-
6636
cc_library(
6737
name = "init_numpy",
6838
srcs = ["init_numpy.cc"],

0 commit comments

Comments
 (0)