Skip to content

Commit 1db86e8

Browse files
committed
Improve support for pandas Extension Arrays (pydata#10301)
1 parent 97f02b4 commit 1db86e8

File tree

5 files changed

+398
-84
lines changed

5 files changed

+398
-84
lines changed

xarray/core/dtypes.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
6363
# N.B. these casting rules should match pandas
6464
dtype_: np.typing.DTypeLike
6565
fill_value: Any
66-
if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
66+
if is_extension_array_dtype(dtype):
67+
return dtype, dtype.na_value
68+
elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
6769
# for now, we always promote string dtypes to object for consistency with existing behavior
6870
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
6971
dtype_ = object
@@ -222,19 +224,51 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
222224
return xp.isdtype(dtype, kind)
223225

224226

225-
def preprocess_types(t):
226-
if isinstance(t, str | bytes):
227-
return type(t)
228-
elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and (
229-
np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)
230-
):
227+
def maybe_promote_to_variable_width(
228+
array_or_dtype: np.typing.ArrayLike | np.typing.DTypeLike,
229+
) -> np.typing.ArrayLike | np.typing.DTypeLike:
230+
if isinstance(array_or_dtype, str | bytes):
231+
return type(array_or_dtype)
232+
elif isinstance(
233+
dtype := getattr(array_or_dtype, "dtype", array_or_dtype), np.dtype
234+
) and (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)):
231235
# drop the length from numpy's fixed-width string dtypes, it is better to
232236
# recalculate
233237
# TODO(keewis): remove once the minimum version of `numpy.result_type` does this
234238
# for us
235239
return dtype.type
236240
else:
237-
return t
241+
return array_or_dtype
242+
243+
244+
def should_promote_to_object(
245+
arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, xp
246+
) -> bool:
247+
"""
248+
Test whether the given arrays_and_dtypes, when evaluated individually, match the
249+
type promotion rules found in PROMOTE_TO_OBJECT.
250+
"""
251+
np_result_types = set()
252+
for arr_or_dtype in arrays_and_dtypes:
253+
try:
254+
result_type = array_api_compat.result_type(
255+
maybe_promote_to_variable_width(arr_or_dtype), xp=xp
256+
)
257+
if isinstance(result_type, np.dtype):
258+
np_result_types.add(result_type)
259+
except TypeError:
260+
# passing individual objects to xp.result_type means NEP-18 implementations won't have
261+
# a chance to intercept special values (such as NA) that numpy core cannot handle
262+
pass
263+
264+
if np_result_types:
265+
for left, right in PROMOTE_TO_OBJECT:
266+
if any(np.issubdtype(t, left) for t in np_result_types) and any(
267+
np.issubdtype(t, right) for t in np_result_types
268+
):
269+
return True
270+
271+
return False
238272

239273

240274
def result_type(
@@ -263,19 +297,9 @@ def result_type(
263297
if xp is None:
264298
xp = get_array_namespace(arrays_and_dtypes)
265299

266-
types = {
267-
array_api_compat.result_type(preprocess_types(t), xp=xp)
268-
for t in arrays_and_dtypes
269-
}
270-
if any(isinstance(t, np.dtype) for t in types):
271-
# only check if there's numpy dtypes – the array API does not
272-
# define the types we're checking for
273-
for left, right in PROMOTE_TO_OBJECT:
274-
if any(np.issubdtype(t, left) for t in types) and any(
275-
np.issubdtype(t, right) for t in types
276-
):
277-
return np.dtype(object)
300+
if should_promote_to_object(arrays_and_dtypes, xp):
301+
return np.dtype(object)
278302

279303
return array_api_compat.result_type(
280-
*map(preprocess_types, arrays_and_dtypes), xp=xp
304+
*map(maybe_promote_to_variable_width, arrays_and_dtypes), xp=xp
281305
)

xarray/core/duck_array_ops.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
from xarray.compat import dask_array_compat, dask_array_ops
2828
from xarray.compat.array_api_compat import get_array_namespace
2929
from xarray.core import dtypes, nputils
30+
from xarray.core.extension_array import (
31+
PandasExtensionArray,
32+
as_extension_array,
33+
is_scalar,
34+
)
3035
from xarray.core.options import OPTIONS
3136
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
3237
from xarray.namedarray.parallelcompat import get_chunked_array_type
@@ -239,7 +244,14 @@ def astype(data, dtype, *, xp=None, **kwargs):
239244

240245

241246
def asarray(data, xp=np, dtype=None):
242-
converted = data if is_duck_array(data) else xp.asarray(data)
247+
if is_duck_array(data):
248+
converted = data
249+
elif is_extension_array_dtype(dtype):
250+
# data may or may not be an ExtensionArray, so we can't rely on
251+
# np.asarray to call our NEP-18 handler; gotta hook it ourselves
252+
converted = PandasExtensionArray(as_extension_array(data, dtype))
253+
else:
254+
converted = xp.asarray(data, dtype=dtype)
243255

244256
if dtype is None or converted.dtype == dtype:
245257
return converted
@@ -252,19 +264,6 @@ def asarray(data, xp=np, dtype=None):
252264

253265
def as_shared_dtype(scalars_or_arrays, xp=None):
254266
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
255-
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
256-
extension_array_types = [
257-
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
258-
]
259-
if len(extension_array_types) == len(scalars_or_arrays) and all(
260-
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
261-
):
262-
return scalars_or_arrays
263-
raise ValueError(
264-
"Cannot cast arrays to shared type, found"
265-
f" array types {[x.dtype for x in scalars_or_arrays]}"
266-
)
267-
268267
# Avoid calling array_type("cupy") repeatidely in the any check
269268
array_type_cupy = array_type("cupy")
270269
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
@@ -384,7 +383,12 @@ def where(condition, x, y):
384383
else:
385384
condition = astype(condition, dtype=dtype, xp=xp)
386385

387-
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
386+
promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp)
387+
388+
# pd.where won't broadcast 0-dim arrays across a series; scalar y's must be preserved
389+
maybe_promoted_y = y if is_extension_array_dtype(x) and is_scalar(y) else promoted_y
390+
391+
return xp.where(condition, promoted_x, maybe_promoted_y)
388392

389393

390394
def where_method(data, cond, other=dtypes.NA):

0 commit comments

Comments
 (0)