Skip to content

Commit a7059ac

Browse files
committed
Add a function to get the number of innermost-level values in each array at the outermost level.
PiperOrigin-RevId: 630526074
1 parent 0ba56da commit a7059ac

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

tensorflow_data_validation/arrow/arrow_util.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,47 @@ def get_column(record_batch: pa.RecordBatch,
246246
return None
247247
raise KeyError('missing column %s' % feature_name)
248248
return record_batch.column(idx)
249+
250+
251+
def get_arries_innermost_level_value_counts(array: pa.Array) -> np.ndarray:
  """Gets the number of innermost-level values under each top-level element.

  For a nested (list-like) array, the result holds one entry per non-null
  top-level element: the total number of innermost values found beneath that
  element, summed across all intermediate nesting levels.

  Null/None handling:
    - A null in place of a top-level element is treated as a missing array
      and contributes no entry to the result.
    - A null in place of a list below the top level is treated like an empty
      list.
    - A null in place of a concrete innermost value (e.g., an int) is counted
      as a value.
    - If the array contains no non-null innermost value at all, every non-null
      top-level element is reported as having 0 values.

  Args:
    array: A pa.Array. It may be a slice of a larger array.

  Returns:
    A numpy array containing the number of innermost values for each non-null
    top-level element, or an empty numpy array if `array` is not list-like.
  """
  # Must be captured before `array` is rebound to its flattened children.
  non_null = ~np.asarray(array.is_null())

  # Offsets of every nesting level, ordered from outermost to innermost.
  offsets = []
  while array_util.is_list_like(array.type):
    level_offsets = np.asarray(array.offsets)
    if level_offsets.size:
      # `offsets` are positions in the full child array, whereas `flatten()`
      # only returns the children covered by this (possibly sliced) array.
      # Rebase to 0 so this level's offsets index correctly into the offsets of
      # the flattened next level. This is a no-op for unsliced arrays.
      level_offsets = level_offsets - level_offsets[0]
    offsets.append(level_offsets)
    array = array.flatten()

  if not offsets:
    # Not a nested array: there is no per-element count to report.
    return np.array([])

  flattened_arr = array.filter(array.is_valid())
  if len(flattened_arr) == 0:
    # An empty array should have 0 values, whereas null does not contain any
    # values.
    return np.array([0] * np.count_nonzero(non_null))

  # Compose the levels from the innermost outwards: looking up an outer level's
  # offsets in the inner boundaries maps each outer element directly to its
  # start/end position in the innermost values.
  example_indices = offsets[-1]
  for level_offsets in reversed(offsets[:-1]):
    example_indices = example_indices[level_offsets]
  return np.diff(example_indices)[non_null]

tensorflow_data_validation/arrow/arrow_util_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from absl.testing import absltest
2525
from absl.testing import parameterized
2626
import numpy as np
27+
from numpy import testing as np_testing
2728
import pyarrow as pa
2829
import six
2930
from tensorflow_data_validation import types
@@ -451,5 +452,47 @@ def testGetColumn(self):
451452
with self.assertRaises(KeyError):
452453
arrow_util.get_column(_INPUT_RECORD_BATCH, "xyz")
453454

455+
# Each case is (testcase_name, array, expected_counts). The cases cover nulls
# at the outermost, intermediate and innermost nesting levels, plus empty and
# non-nested inputs.
@parameterized.named_parameters(
    (
        "all_values_present",
        pa.array([[[1, 2], [3]], [[3], [4]]]),
        np.array([3, 2]),
    ),
    (
        "none_in_inner_level",
        pa.array([[[1, 2], None], [[3]], [None]]),
        np.array([2, 1, 0]),
    ),
    (
        "none_in_innermost_level",
        pa.array([[[1, 2]], [[3, None]]]),
        np.array([2, 2]),
    ),
    (
        "none_in_outermost_level",
        pa.array([[[1, 2]], None]),
        np.array([2]),
    ),
    (
        "all_nones",
        pa.array([None, [None, None], [[None]]]),
        np.array([0, 0]),
    ),
    (
        "empty_array",
        pa.array([[[]]]),
        np.array([0]),
    ),
    (
        "non_nested_array",
        pa.array([1, 2, 3]),
        np.array([]),
    ),
)
def testGetArriesInnermostLevelValueCounts(self, array, expected_counts):
  # The counts must match element-wise, including the number of entries.
  actual_counts = arrow_util.get_arries_innermost_level_value_counts(array)
  np_testing.assert_array_equal(actual_counts, expected_counts)
495+
496+
454497
# Run the tests through absltest when this module is executed as a script.
if __name__ == "__main__":
455498
absltest.main()

0 commit comments

Comments
 (0)