From 9312d2be63028718482607b250a70c2c67824f9a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 May 2025 16:37:28 +0200 Subject: [PATCH 1/9] (fix): disallow `NumpyExtensionArray` --- properties/test_pandas_roundtrip.py | 28 ++++++++++++++++++++++------ xarray/core/dataset.py | 5 ++++- xarray/core/extension_array.py | 4 ++++ xarray/core/indexing.py | 7 +++++-- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 8fc32e75cbd..04babad1a23 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -134,10 +134,26 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: xr.testing.assert_identical(dataset, roundtripped.to_xarray()) -def test_roundtrip_1d_pandas_extension_array() -> None: - df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])}) - arr = xr.Dataset.from_dataframe(df)["cat"] +@pytest.mark.parametrize( + "extension_array", + [ + pd.Categorical(["a", "b", "c"]), + pd.array([1, 2, 3], dtype="int64"), + pd.array(["a", "b", "c"], dtype="string"), + pd.arrays.IntervalArray( + [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] + ), + np.array([1, 2, 3], dtype="int64"), + ], +) +def test_roundtrip_1d_pandas_extension_array(extension_array) -> None: + df = pd.DataFrame({"arr": extension_array}) + arr = xr.Dataset.from_dataframe(df)["arr"] roundtripped = arr.to_pandas() - assert (df["cat"] == roundtripped).all() - assert df["cat"].dtype == roundtripped.dtype - xr.testing.assert_identical(arr, roundtripped.to_xarray()) + assert (df["arr"] == roundtripped).all() + # `NumpyExtensionArray` types are not roundtripped, including `StringArray` which subtypes. + if isinstance(extension_array, pd.arrays.NumpyExtensionArray): + assert isinstance(arr.data, np.ndarray) + else: + assert df["arr"].dtype == roundtripped.dtype + xr.testing.assert_identical(arr, roundtripped.to_xarray()) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5a7f757ba8a..0c918afba63 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7271,7 +7271,10 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: extension_arrays = [] for k, v in dataframe.items(): if not is_extension_array_dtype(v) or isinstance( - v.array, pd.arrays.DatetimeArray | pd.arrays.TimedeltaArray + v.array, + pd.arrays.DatetimeArray + | pd.arrays.TimedeltaArray + | pd.arrays.NumpyExtensionArray, ): arrays.append((k, np.asarray(v))) else: diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index f1631a5ea9e..0c0312e0e23 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -82,6 +82,10 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin): def __post_init__(self): if not isinstance(self.array, pd.api.extensions.ExtensionArray): raise TypeError(f"{self.array} is not an pandas ExtensionArray.") + if isinstance(self.array, pd.arrays.NumpyExtensionArray): + raise TypeError( + "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally." + ) def __array_function__(self, func, types, args, kwargs): def replace_duck_with_extension_array(args) -> list: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index c1b847202c7..ffd43b4bca2 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1802,8 +1802,11 @@ def __array__( def get_duck_array(self) -> np.ndarray | PandasExtensionArray: # We return an PandasExtensionArray wrapper type that satisfies - # duck array protocols. This is what's needed for tests to pass. - if pd.api.types.is_extension_array_dtype(self.array): + # duck array protocols. + # `NumpyExtensionArray` is excluded + if pd.api.types.is_extension_array_dtype(self.array) and not isinstance( + self.array.array, pd.arrays.NumpyExtensionArray + ): from xarray.core.extension_array import PandasExtensionArray return PandasExtensionArray(self.array.array) From d09c0f5a33bcb01dd4b33653f86113f9a9bcffbf Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 16 May 2025 15:24:06 +0200 Subject: [PATCH 2/9] (fix): clarify permitted extension array behavior in `as_compatible_data` --- xarray/core/variable.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4e58b0d4b20..7b7cff623da 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -191,6 +191,13 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) + if isinstance( + data, + pd.arrays.DatetimeArray + | pd.arrays.TimedeltaArray + | pd.arrays.NumpyExtensionArray, + ): + return data.to_numpy() if isinstance(data, pd.api.extensions.ExtensionArray): return PandasExtensionArray(data) return data @@ -252,7 +259,19 @@ def convert_non_numpy_type(data): # we don't want nested self-described arrays if isinstance(data, pd.Series | pd.DataFrame): - pandas_data = data.values + if ( + isinstance(data, pd.Series) + and pd.api.types.is_extension_array_dtype(data) + and not isinstance( + data.array, + pd.arrays.DatetimeArray + | pd.arrays.TimedeltaArray + | pd.arrays.NumpyExtensionArray, + ) + ): + pandas_data = data.array + else: + pandas_data = data.values if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): return convert_non_numpy_type(pandas_data) else: From 88e484115d728e3f91f6b99943e72ec109356cee Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 19 May 2025 12:19:25 +0200 Subject: [PATCH 3/9] (refactor): centralize whitelist --- xarray/core/dataset.py | 6 ++---- xarray/core/extension_array.py | 6 ++++-- xarray/core/variable.py | 19 +++++++------------ 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0c918afba63..3c4eda74114 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -99,6 +99,7 @@ parse_dims_as_set, ) from xarray.core.variable import ( + UNSUPPORTED_EXTENSION_ARRAY_TYPES, IndexVariable, Variable, as_variable, @@ -7271,10 +7272,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: extension_arrays = [] for k, v in dataframe.items(): if not is_extension_array_dtype(v) or isinstance( - v.array, - pd.arrays.DatetimeArray - | pd.arrays.TimedeltaArray - | pd.arrays.NumpyExtensionArray, + v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES ): arrays.append((k, np.asarray(v))) else: diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 0c0312e0e23..e949f93a387 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -80,11 +80,13 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin): array: T_ExtensionArray def __post_init__(self): + from xarray.core.variable import UNSUPPORTED_EXTENSION_ARRAY_TYPES + if not isinstance(self.array, pd.api.extensions.ExtensionArray): raise TypeError(f"{self.array} is not an pandas ExtensionArray.") - if isinstance(self.array, pd.arrays.NumpyExtensionArray): + if isinstance(self.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES): raise TypeError( - "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally." + f"`{type(self.array)}` should be converted to a numpy array in `xarray` internally." ) def __array_function__(self, func, types, args, kwargs): diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 7b7cff623da..32fe55e2ac8 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -63,6 +63,11 @@ ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) +UNSUPPORTED_EXTENSION_ARRAY_TYPES = ( + pd.arrays.DatetimeArray, + pd.arrays.TimedeltaArray, + pd.arrays.NumpyExtensionArray, +) if TYPE_CHECKING: from xarray.core.types import ( @@ -191,12 +196,7 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) - if isinstance( - data, - pd.arrays.DatetimeArray - | pd.arrays.TimedeltaArray - | pd.arrays.NumpyExtensionArray, - ): + if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES): return data.to_numpy() if isinstance(data, pd.api.extensions.ExtensionArray): return PandasExtensionArray(data) @@ -262,12 +262,7 @@ def convert_non_numpy_type(data): if ( isinstance(data, pd.Series) and pd.api.types.is_extension_array_dtype(data) - and not isinstance( - data.array, - pd.arrays.DatetimeArray - | pd.arrays.TimedeltaArray - | pd.arrays.NumpyExtensionArray, - ) + and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES) ): pandas_data = data.array else: From 174274ddc58df77a152ed982c6151530db16d75f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 19 May 2025 12:26:57 +0200 Subject: [PATCH 4/9] (fix): allow through other types --- xarray/core/extension_array.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index e949f93a387..9377b442aab 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -80,13 +80,14 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin): array: T_ExtensionArray def __post_init__(self): - from xarray.core.variable import UNSUPPORTED_EXTENSION_ARRAY_TYPES - if not isinstance(self.array, pd.api.extensions.ExtensionArray): raise TypeError(f"{self.array} is not an pandas ExtensionArray.") - if isinstance(self.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES): + # This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because + # we do support extension arrays from datetime, for example, that need + # duck array support internally via this class. + if isinstance(self.array, pd.arrays.NumpyExtensionArray): raise TypeError( - f"`{type(self.array)}` should be converted to a numpy array in `xarray` internally." + "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally." ) def __array_function__(self, func, types, args, kwargs): From 6329964330368fec5b20aedb708538e00d641658 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 30 May 2025 11:48:50 +0200 Subject: [PATCH 5/9] (chore): add thorough test cases --- properties/test_pandas_roundtrip.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 04babad1a23..fea14baafe3 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -138,22 +138,34 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: "extension_array", [ pd.Categorical(["a", "b", "c"]), - pd.array([1, 2, 3], dtype="int64"), + pd.array([1, 2, 3], dtype="int64[pyarrow]"), pd.array(["a", "b", "c"], dtype="string"), pd.arrays.IntervalArray( [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] ), + pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])), + pd.arrays.DatetimeArray._from_sequence( + pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D") + ), np.array([1, 2, 3], dtype="int64"), ], + ids=["cat", "pyarrow", "string", "interval", "timedelta", "datetime", "numpy"], ) -def test_roundtrip_1d_pandas_extension_array(extension_array) -> None: +@pytest.mark.parametrize("is_index", [True, False]) +def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None: df = pd.DataFrame({"arr": extension_array}) + if is_index: + df = df.set_index("arr") arr = xr.Dataset.from_dataframe(df)["arr"] roundtripped = arr.to_pandas() - assert (df["arr"] == roundtripped).all() + df_arr_to_test = df.index if is_index else df["arr"] + assert (df_arr_to_test == roundtripped).all() # `NumpyExtensionArray` types are not roundtripped, including `StringArray` which subtypes. if isinstance(extension_array, pd.arrays.NumpyExtensionArray): assert isinstance(arr.data, np.ndarray) else: - assert df["arr"].dtype == roundtripped.dtype + assert ( + df_arr_to_test.dtype + == (roundtripped.index if is_index else roundtripped).dtype + ) xr.testing.assert_identical(arr, roundtripped.to_xarray()) From c6ac491156fee6f32da752765dac3cbf185a5770 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 30 May 2025 14:24:16 +0200 Subject: [PATCH 6/9] (fix): require pyarrow --- properties/test_pandas_roundtrip.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index fea14baafe3..045c7eb7d66 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -15,6 +15,7 @@ import hypothesis.extra.pandas as pdst # isort:skip import hypothesis.strategies as st # isort:skip from hypothesis import given # isort:skip +from xarray.tests import requires_pyarrow numeric_dtypes = st.one_of( npst.unsigned_integer_dtypes(endianness="="), @@ -138,7 +139,9 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: "extension_array", [ pd.Categorical(["a", "b", "c"]), - pd.array([1, 2, 3], dtype="int64[pyarrow]"), + pytest.param( + pd.array([1, 2, 3], dtype="int64[pyarrow]"), marks=requires_pyarrow + ), pd.array(["a", "b", "c"], dtype="string"), pd.arrays.IntervalArray( [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] From 50843ca2f69857d412b7a4ddc8528718d0abce9d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 30 May 2025 14:35:20 +0200 Subject: [PATCH 7/9] (fix): mypy --- properties/test_pandas_roundtrip.py | 2 +- xarray/core/extension_array.py | 2 +- xarray/core/indexing.py | 3 ++- xarray/core/variable.py | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 045c7eb7d66..0d56186f461 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -164,7 +164,7 @@ def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None: df_arr_to_test = df.index if is_index else df["arr"] assert (df_arr_to_test == roundtripped).all() # `NumpyExtensionArray` types are not roundtripped, including `StringArray` which subtypes. - if isinstance(extension_array, pd.arrays.NumpyExtensionArray): + if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined] assert isinstance(arr.data, np.ndarray) else: assert ( diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 9052f5ae0a0..e280422c6e5 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -96,7 +96,7 @@ def __post_init__(self): # This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because # we do support extension arrays from datetime, for example, that need # duck array support internally via this class. - if isinstance(self.array, pd.arrays.NumpyExtensionArray): + if isinstance(self.array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined] raise TypeError( "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally." ) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index c0193e066d1..1813a25d7af 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1805,7 +1805,8 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray: # duck array protocols. # `NumpyExtensionArray` is excluded if pd.api.types.is_extension_array_dtype(self.array) and not isinstance( - self.array.array, pd.arrays.NumpyExtensionArray + self.array.array, + pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined] ): from xarray.core.extension_array import PandasExtensionArray diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 910b524d42d..9c753a2ffa7 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -66,7 +66,7 @@ UNSUPPORTED_EXTENSION_ARRAY_TYPES = ( pd.arrays.DatetimeArray, pd.arrays.TimedeltaArray, - pd.arrays.NumpyExtensionArray, + pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined] ) if TYPE_CHECKING: @@ -265,7 +265,7 @@ def convert_non_numpy_type(data): ): pandas_data = data.array else: - pandas_data = data.values + pandas_data = data.values # type: ignore[assignment] if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): return convert_non_numpy_type(pandas_data) else: From b959345bdbb3c4c333acfb967d393f5e9accfca7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 30 May 2025 14:38:24 +0200 Subject: [PATCH 8/9] (fix): pyarrow check --- properties/test_pandas_roundtrip.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 0d56186f461..ade2869ea3f 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -15,7 +15,7 @@ import hypothesis.extra.pandas as pdst # isort:skip import hypothesis.strategies as st # isort:skip from hypothesis import given # isort:skip -from xarray.tests import requires_pyarrow +from xarray.tests import has_pyarrow numeric_dtypes = st.one_of( npst.unsigned_integer_dtypes(endianness="="), @@ -139,9 +139,6 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: "extension_array", [ pd.Categorical(["a", "b", "c"]), - pytest.param( - pd.array([1, 2, 3], dtype="int64[pyarrow]"), marks=requires_pyarrow - ), pd.array(["a", "b", "c"], dtype="string"), pd.arrays.IntervalArray( [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] @@ -151,8 +148,10 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D") ), np.array([1, 2, 3], dtype="int64"), - ], - ids=["cat", "pyarrow", "string", "interval", "timedelta", "datetime", "numpy"], + ] + + ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []), + ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"] + + (["pyarrow"] if has_pyarrow else []), ) @pytest.mark.parametrize("is_index", [True, False]) def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None: From 2d33aaa38cf48c2193b44b326849d99722f5e10a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 30 May 2025 15:52:22 +0200 Subject: [PATCH 9/9] (fix): remove extra ignore --- xarray/core/extension_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index e280422c6e5..9052f5ae0a0 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -96,7 +96,7 @@ def __post_init__(self): # This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because # we do support extension arrays from datetime, for example, that need # duck array support internally via this class. - if isinstance(self.array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined] + if isinstance(self.array, pd.arrays.NumpyExtensionArray): raise TypeError( "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally." )