Skip to content

Commit 759c0f0

Browse files
authored
More helpful error when UDF arg type check failed (deephaven#5175)
1 parent e29e7c6 commit 759c0f0

File tree

3 files changed

+60
-39
lines changed

3 files changed

+60
-39
lines changed

py/server/deephaven/_udf.py

+43-34
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434

3535

3636
@dataclass
37-
class _ParsedParamAnnotation:
37+
class _ParsedParam:
38+
name: Union[str, int] = field(init=True)
3839
orig_types: set[type] = field(default_factory=set)
3940
encoded_types: set[str] = field(default_factory=set)
4041
none_allowed: bool = False
@@ -54,7 +55,7 @@ class _ParsedReturnAnnotation:
5455
@dataclass
5556
class _ParsedSignature:
5657
fn: Callable = None
57-
params: List[_ParsedParamAnnotation] = field(default_factory=list)
58+
params: List[_ParsedParam] = field(default_factory=list)
5859
ret_annotation: _ParsedReturnAnnotation = None
5960

6061
@property
@@ -93,22 +94,22 @@ def _encode_param_type(t: type) -> str:
9394
return tc
9495

9596

96-
def _parse_param_annotation(annotation: Any) -> _ParsedParamAnnotation:
97+
def _parse_param(name: str, annotation: Any) -> _ParsedParam:
9798
""" Parse a parameter annotation in a function's signature """
98-
p_annotation = _ParsedParamAnnotation()
99+
p_param = _ParsedParam(name)
99100

100101
if annotation is inspect._empty:
101-
p_annotation.encoded_types.add("O")
102-
p_annotation.none_allowed = True
102+
p_param.encoded_types.add("O")
103+
p_param.none_allowed = True
103104
elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union:
104105
for t in annotation.__args__:
105-
_parse_type_no_nested(annotation, p_annotation, t)
106+
_parse_type_no_nested(annotation, p_param, t)
106107
else:
107-
_parse_type_no_nested(annotation, p_annotation, annotation)
108-
return p_annotation
108+
_parse_type_no_nested(annotation, p_param, annotation)
109+
return p_param
109110

110111

111-
def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation, t: Union[type, str]) -> None:
112+
def _parse_type_no_nested(annotation: Any, p_param: _ParsedParam, t: Union[type, str]) -> None:
112113
""" Parse a specific type (top level or nested in a top-level Union annotation) without handling nested types
113114
(e.g. a nested Union). The result is stored in the given _ParsedAnnotation object.
114115
"""
@@ -117,25 +118,25 @@ def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation,
117118
# annotation is already a type, and we can remove this line.
118119
t = eval(t) if isinstance(t, str) else t
119120

120-
p_annotation.orig_types.add(t)
121+
p_param.orig_types.add(t)
121122
tc = _encode_param_type(t)
122123
if "[" in tc:
123-
p_annotation.has_array = True
124+
p_param.has_array = True
124125
if tc in {"N", "O"}:
125-
p_annotation.none_allowed = True
126+
p_param.none_allowed = True
126127
if tc in _NUMPY_INT_TYPE_CODES:
127-
if p_annotation.int_char and p_annotation.int_char != tc:
128+
if p_param.int_char and p_param.int_char != tc:
128129
raise DHError(message=f"multiple integer types in annotation: {annotation}, "
129-
f"types: {p_annotation.int_char}, {tc}. this is not supported because it is not "
130+
f"types: {p_param.int_char}, {tc}. this is not supported because it is not "
130131
f"clear which Deephaven null value to use when checking for nulls in the argument")
131-
p_annotation.int_char = tc
132+
p_param.int_char = tc
132133
if tc in _NUMPY_FLOATING_TYPE_CODES:
133-
if p_annotation.floating_char and p_annotation.floating_char != tc:
134+
if p_param.floating_char and p_param.floating_char != tc:
134135
raise DHError(message=f"multiple floating types in annotation: {annotation}, "
135-
f"types: {p_annotation.floating_char}, {tc}. this is not supported because it is not "
136+
f"types: {p_param.floating_char}, {tc}. this is not supported because it is not "
136137
f"clear which Deephaven null value to use when checking for nulls in the argument")
137-
p_annotation.floating_char = tc
138-
p_annotation.encoded_types.add(tc)
138+
p_param.floating_char = tc
139+
p_param.encoded_types.add(tc)
139140

140141

141142
def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation:
@@ -182,8 +183,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
182183
p_sig.ret_annotation.encoded_type = rt_char
183184

184185
if isinstance(fn, numba.np.ufunc.dufunc.DUFunc):
185-
for p in params:
186-
pa = _ParsedParamAnnotation()
186+
for i, p in enumerate(params):
187+
pa = _ParsedParam(i + 1)
187188
pa.encoded_types.add(p)
188189
if p in _NUMPY_INT_TYPE_CODES:
189190
pa.int_char = p
@@ -198,8 +199,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
198199
input_decl = re.sub("[()]", "", input_decl).split(",")
199200
output_decl = re.sub("[()]", "", output_decl)
200201

201-
for p, d in zip(params, input_decl):
202-
pa = _ParsedParamAnnotation()
202+
for i, (p, d) in enumerate(zip(params, input_decl)):
203+
pa = _ParsedParam(i + 1)
203204
if d:
204205
pa.encoded_types.add("[" + p)
205206
pa.has_array = True
@@ -225,9 +226,10 @@ def _parse_np_ufunc_signature(fn: numpy.ufunc) -> _ParsedSignature:
225226
# them in the future (https://github.com/deephaven/deephaven-core/issues/4762)
226227
p_sig = _ParsedSignature(fn)
227228
if fn.nin > 0:
228-
pa = _ParsedParamAnnotation()
229-
pa.encoded_types.add("O")
230-
p_sig.params = [pa] * fn.nin
229+
for i in range(fn.nin):
230+
pa = _ParsedParam(i + 1)
231+
pa.encoded_types.add("O")
232+
p_sig.params.append(pa)
231233
p_sig.ret_annotation = _ParsedReturnAnnotation()
232234
p_sig.ret_annotation.encoded_type = "O"
233235
return p_sig
@@ -249,7 +251,7 @@ def _parse_signature(fn: Callable) -> _ParsedSignature:
249251
else:
250252
sig = inspect.signature(fn)
251253
for n, p in sig.parameters.items():
252-
p_sig.params.append(_parse_param_annotation(p.annotation))
254+
p_sig.params.append(_parse_param(n, p.annotation))
253255

254256
p_sig.ret_annotation = _parse_return_annotation(sig.return_annotation)
255257
return p_sig
@@ -263,11 +265,11 @@ def _is_from_np_type(param_types: set[type], np_type_char: str) -> bool:
263265
return False
264266

265267

266-
def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
268+
def _convert_arg(param: _ParsedParam, arg: Any) -> Any:
267269
""" Convert a single argument to the type specified by the annotation """
268270
if arg is None:
269271
if not param.none_allowed:
270-
raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
272+
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.orig_types}")
271273
else:
272274
return None
273275

@@ -277,12 +279,17 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
277279
# if it matches one of the encoded types, convert it
278280
if encoded_type in param.encoded_types:
279281
dtype = dtypes.from_np_dtype(np_dtype)
280-
return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False)
282+
try:
283+
return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False)
284+
except Exception as e:
285+
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
286+
f" {param.encoded_types}"
287+
f"\n{str(e)}") from e
281288
# if the annotation is missing, or it is a generic object type, return the arg
282289
elif "O" in param.encoded_types:
283290
return arg
284291
else:
285-
raise TypeError(f"Argument {arg} is not compatible with annotation {param.encoded_types}")
292+
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.encoded_types}")
286293
else: # if the arg is not a Java array
287294
specific_types = param.encoded_types - {"N", "O"} # remove NoneType and object type
288295
if specific_types:
@@ -300,7 +307,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
300307
if param.none_allowed:
301308
return None
302309
else:
303-
raise DHError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
310+
raise DHError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
311+
f" {param.orig_types}")
304312
else:
305313
# return a numpy integer instance only if the annotation is a numpy type
306314
if _is_from_np_type(param.orig_types, param.int_char):
@@ -332,7 +340,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
332340
if "O" in param.encoded_types:
333341
return arg
334342
else:
335-
raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
343+
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
344+
f" {param.orig_types}")
336345
else: # if no annotation or generic object, return arg
337346
return arg
338347

py/server/tests/test_numba_guvectorize.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from numba import guvectorize, int64, int32
99

10-
from deephaven import empty_table, dtypes
10+
from deephaven import empty_table, dtypes, DHError
1111
from tests.testbase import BaseTestCase
1212

1313
a = np.arange(5, dtype=np.int64)
@@ -89,6 +89,18 @@ def g(x, res):
8989
t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y)")
9090
self.assertEqual(t.columns[2].data_type, dtypes.long_array)
9191

92+
def test_type_mismatch_error(self):
93+
# vector input to scalar output function (m)->()
94+
@guvectorize([(int64[:], int64[:])], "(m)->()", nopython=True)
95+
def g(x, res):
96+
res[0] = 0
97+
for xi in x:
98+
res[0] += xi
99+
100+
with self.assertRaises(DHError) as cm:
101+
t = empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)")
102+
self.assertIn("Argument 1", str(cm.exception))
103+
92104

93105
if __name__ == '__main__':
94106
unittest.main()

py/server/tests/test_udf_numpy_args.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool:
337337
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
338338
with self.assertRaises(DHError) as cm:
339339
t2 = t.update(["X1 = f3(null, Y )"])
340-
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
340+
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")
341341

342342
def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool:
343343
return bool(len(p1)) if p1 is not None else False
@@ -352,7 +352,7 @@ def f1(p1: str, p2=None) -> bool:
352352
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? `deephaven`: null"])
353353
with self.assertRaises(DHError) as cm:
354354
t1 = t.update(["X1 = f1(Y)"])
355-
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
355+
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")
356356

357357
def f11(p1: Union[str, None], p2=None) -> bool:
358358
return p1 is None
@@ -366,7 +366,7 @@ def f2(p1: np.datetime64, p2=None) -> bool:
366366
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? now() : null"])
367367
with self.assertRaises(DHError) as cm:
368368
t1 = t.update(["X1 = f2(Y)"])
369-
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
369+
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")
370370

371371
def f21(p1: Union[np.datetime64, None], p2=None) -> bool:
372372
return p1 is None
@@ -380,7 +380,7 @@ def f3(p1: np.bool_, p2=None) -> bool:
380380
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : null"])
381381
with self.assertRaises(DHError) as cm:
382382
t1 = t.update(["X1 = f3(Y)"])
383-
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
383+
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")
384384

385385
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : false"])
386386
t1 = t.update(["X1 = f3(Y)"])

0 commit comments

Comments
 (0)