Skip to content

Commit e4d69fd

Browse files
authored
Add support of 1D arrays in UDF vectorization (#5100)
* Add support of 1D arrays to UDF vectorization * Add more test cases * Add a null test case for boolean arrays
1 parent 4117159 commit e4d69fd

File tree

3 files changed

+115
-5
lines changed

3 files changed

+115
-5
lines changed

engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,20 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
4444

4545
// TODO: support for vectorizing functions that return arrays
4646
// https://github.com/deephaven/deephaven-core/issues/4649
47-
private static final Set<Class<?>> vectorizableReturnTypes = Set.of(int.class, long.class, short.class, float.class,
48-
double.class, byte.class, Boolean.class, String.class, Instant.class, PyObject.class);
47+
private static final Set<Class<?>> vectorizableReturnTypes = Set.of(
48+
boolean.class, boolean[].class,
49+
Boolean.class, Boolean[].class,
50+
byte.class, byte[].class,
51+
short.class, short[].class,
52+
char.class, char[].class,
53+
int.class, int[].class,
54+
long.class, long[].class,
55+
float.class, float[].class,
56+
double.class, double[].class,
57+
String.class, String[].class,
58+
Instant.class, Instant[].class,
59+
PyObject.class, PyObject[].class,
60+
Object.class, Object[].class);
4961

5062
@Override
5163
public boolean isVectorizableReturnType() {

py/server/deephaven/_udf.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def _dh_vectorize(fn):
411411
and (3) the input arrays.
412412
"""
413413
p_sig = _parse_signature(fn)
414+
return_array = p_sig.ret_annotation.has_array
414415
ret_dtype = dtypes.from_np_dtype(np.dtype(p_sig.ret_annotation.encoded_type[-1]))
415416

416417
@wraps(fn)
@@ -428,10 +429,18 @@ def wrapper(*args):
428429
for i in range(chunk_size):
429430
scalar_args = next(vectorized_args)
430431
converted_args = _convert_args(p_sig, scalar_args)
431-
chunk_result[i] = _scalar(fn(*converted_args), ret_dtype)
432+
ret = fn(*converted_args)
433+
if return_array:
434+
chunk_result[i] = dtypes.array(ret_dtype, ret)
435+
else:
436+
chunk_result[i] = _scalar(ret, ret_dtype)
432437
else:
433438
for i in range(chunk_size):
434-
chunk_result[i] = _scalar(fn(), ret_dtype)
439+
ret = fn()
440+
if return_array:
441+
chunk_result[i] = dtypes.array(ret_dtype, ret)
442+
else:
443+
chunk_result[i] = _scalar(ret, ret_dtype)
435444

436445
return chunk_result
437446

py/server/tests/test_vectorization.py

+90-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import random
55
import unittest
66

7-
from typing import Optional
7+
from typing import Optional, Union
88
import numpy as np
99

1010
from deephaven import DHError, empty_table, dtypes
@@ -15,6 +15,8 @@
1515
from deephaven._udf import _dh_vectorize as dh_vectorize
1616
from tests.testbase import BaseTestCase
1717

18+
from tests.test_udf_numpy_args import _J_TYPE_NULL_MAP, _J_TYPE_NP_DTYPE_MAP, _J_TYPE_J_ARRAY_TYPE_MAP
19+
1820

1921
class VectorizationTestCase(BaseTestCase):
2022
def setUp(self):
@@ -278,6 +280,93 @@ def pyfunc(p1: np.int32, p2: np.int32, p3: Optional[np.int32]) -> Optional[int]:
278280
self.assertEqual(t.columns[1].data_type, dtypes.long)
279281
self.assertEqual(t.columns[2].data_type, dtypes.long)
280282

283+
def test_1d_array_args_no_null(self):
284+
col1_formula = "Col1 = i % 3"
285+
for j_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items():
286+
col2_formula = f"Col2 = ({j_dtype})i"
287+
with self.subTest(j_dtype):
288+
tbl = empty_table(10).update([col1_formula, col2_formula]).group_by("Col1").update(
289+
"Col2 = Col2.toArray()")
290+
291+
func_str = f"""
292+
def test_udf(col1, col2: np.ndarray[{_J_TYPE_NP_DTYPE_MAP[j_dtype]}]) -> np.ndarray[{_J_TYPE_NP_DTYPE_MAP[j_dtype]}]:
293+
return col2 + 5
294+
"""
295+
exec(func_str, globals())
296+
297+
res = tbl.update("Col3 = test_udf(Col1, Col2)")
298+
self.assertEqual(res.columns[0].data_type, dtypes.int32)
299+
self.assertEqual(res.columns[1].data_type, _J_TYPE_J_ARRAY_TYPE_MAP[j_dtype])
300+
self.assertEqual(res.columns[2].data_type, _J_TYPE_J_ARRAY_TYPE_MAP[j_dtype])
301+
302+
self.assertEqual(_udf.vectorized_count, 1)
303+
_udf.vectorized_count = 0
304+
305+
def test_1d_array_args_null(self):
306+
col1_formula = "Col1 = i % 3"
307+
for j_dtype, null_name in _J_TYPE_NULL_MAP.items():
308+
col2_formula = f"Col2 = i % 3 == 0? {null_name} : ({j_dtype})i"
309+
with self.subTest(j_dtype):
310+
tbl = empty_table(10).update([col1_formula, col2_formula]).group_by("Col1").update("Col2 = Col2.toArray()")
311+
312+
func_str = f"""
313+
def test_udf(col1, col2: np.ndarray[{_J_TYPE_NP_DTYPE_MAP[j_dtype]}]) -> np.ndarray[{_J_TYPE_NP_DTYPE_MAP[j_dtype]}]:
314+
return col2 + 5
315+
"""
316+
exec(func_str, globals())
317+
318+
# for floating point types, DH nulls are auto converted to np.nan
319+
# for integer types, DH nulls in the array raise exceptions
320+
if j_dtype in ("float", "double"):
321+
res = tbl.update("Col3 = test_udf(Col1, Col2)")
322+
self.assertEqual(res.columns[0].data_type, dtypes.int32)
323+
self.assertEqual(res.columns[1].data_type, _J_TYPE_J_ARRAY_TYPE_MAP[j_dtype])
324+
self.assertEqual(res.columns[2].data_type, _J_TYPE_J_ARRAY_TYPE_MAP[j_dtype])
325+
else:
326+
with self.assertRaises(DHError) as cm:
327+
tbl.update("Col3 = test_udf(Col1, Col2)")
328+
329+
self.assertEqual(_udf.vectorized_count, 1)
330+
_udf.vectorized_count = 0
331+
332+
def test_1d_str_bool_datetime_array(self):
333+
with self.subTest("str"):
334+
def f1(p1: np.ndarray[str]) -> bool:
335+
return (p1 == 'None').any()
336+
337+
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? `deephaven`: null"]).group_by("X").update("Y = Y.toArray()")
338+
t1 = t.update(["X1 = f1(Y)"])
339+
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
340+
self.assertEqual(3, t1.to_string().count("true"))
341+
self.assertEqual(_udf.vectorized_count, 1)
342+
_udf.vectorized_count = 0
343+
344+
with self.subTest("datetime"):
345+
def f2(p1: np.ndarray[np.datetime64]) -> bool:
346+
return np.isnat(p1).any()
347+
348+
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? now() : null"]).group_by("X").update("Y = Y.toArray()")
349+
t1 = t.update(["X1 = f2(Y)"])
350+
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
351+
self.assertEqual(3, t1.to_string().count("true"))
352+
self.assertEqual(_udf.vectorized_count, 1)
353+
_udf.vectorized_count = 0
354+
355+
with self.subTest("boolean"):
356+
def f3(p1: np.ndarray[np.bool_]) -> np.ndarray[np.bool_]:
357+
return np.invert(p1)
358+
359+
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : false"]).group_by("X").update("Y = Y.toArray()")
360+
t1 = t.update(["X1 = f3(Y)"])
361+
self.assertEqual(_udf.vectorized_count, 1)
362+
_udf.vectorized_count = 0
363+
364+
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : null"]).group_by("X").update("Y = Y.toArray()")
365+
with self.assertRaises(DHError) as cm:
366+
t1 = t.update(["X1 = f3(Y)"])
367+
self.assertIn("Java java.lang.Boolean array contains Deephaven null values, but numpy int8 array does not support null values", str(cm.exception))
368+
self.assertEqual(_udf.vectorized_count, 1)
369+
_udf.vectorized_count = 0
281370

282371
if __name__ == "__main__":
283372
unittest.main()

0 commit comments

Comments
 (0)