Skip to content

Commit 1c2f366

Browse files
Bye-legumes, zhaoch23, gvspraveen
authored
[Data] remove pyarrow 8 check for arrow.py (#52404)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. -->
<!-- Please add a reviewer to the assignee section when you create a PR. If you don't have access to it, we will shortly find a reviewer and assign them to your PR. -->

## Why are these changes needed?

Since we only support pyarrow>=8.0.0, this condition check can be removed. @zhaoch23

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: zhilong <zhilong.chen@mail.mcgill.ca>
Co-authored-by: Zhaoch <c233zhao@uwaterloo.ca>
Co-authored-by: Praveen <praveeng@anyscale.com>
1 parent cca9b9e commit 1c2f366

File tree

1 file changed

+35
-79
lines changed
  • python/ray/air/util/tensor_extensions

1 file changed

+35
-79
lines changed

python/ray/air/util/tensor_extensions/arrow.py

Lines changed: 35 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,6 @@
2727
from ray.util.annotations import DeveloperAPI, PublicAPI
2828

2929
PYARROW_VERSION = get_pyarrow_version()
30-
# Minimum version of Arrow that supports ExtensionScalars.
31-
# TODO(Clark): Remove conditional definition once we only support Arrow 8.0.0+.
32-
MIN_PYARROW_VERSION_SCALAR = parse_version("8.0.0")
3330
# Minimum version of Arrow that supports subclassable ExtensionScalars.
3431
# TODO(Clark): Remove conditional definition once we only support Arrow 9.0.0+.
3532
MIN_PYARROW_VERSION_SCALAR_SUBCLASS = parse_version("9.0.0")
@@ -59,17 +56,6 @@ def __init__(self, data_str: str):
5956
super().__init__(message)
6057

6158

62-
def _arrow_supports_extension_scalars():
63-
"""
64-
Whether Arrow ExtensionScalars are supported in the current pyarrow version.
65-
66-
This returns True if the pyarrow version is 8.0.0+, or if the pyarrow version is
67-
unknown.
68-
"""
69-
# TODO(Clark): Remove utility once we only support Arrow 8.0.0+.
70-
return PYARROW_VERSION is None or PYARROW_VERSION >= MIN_PYARROW_VERSION_SCALAR
71-
72-
7359
def _arrow_extension_scalars_are_subclassable():
7460
"""
7561
Whether Arrow ExtensionScalars support subclassing in the current pyarrow version.
@@ -489,20 +475,16 @@ def __arrow_ext_scalar_class__(self):
489475
"""
490476
return ArrowTensorScalar
491477

492-
if _arrow_supports_extension_scalars():
493-
# TODO(Clark): Remove this version guard once we only support Arrow 8.0.0+.
494-
def _extension_scalar_to_ndarray(
495-
self, scalar: pa.ExtensionScalar
496-
) -> np.ndarray:
497-
"""
498-
Convert an ExtensionScalar to a tensor element.
499-
"""
500-
raw_values = scalar.value.values
501-
shape = scalar.type.shape
502-
value_type = raw_values.type
503-
offset = raw_values.offset
504-
data_buffer = raw_values.buffers()[1]
505-
return _to_ndarray_helper(shape, value_type, offset, data_buffer)
478+
def _extension_scalar_to_ndarray(self, scalar: "pa.ExtensionScalar") -> np.ndarray:
479+
"""
480+
Convert an ExtensionScalar to a tensor element.
481+
"""
482+
raw_values = scalar.value.values
483+
shape = scalar.type.shape
484+
value_type = raw_values.type
485+
offset = raw_values.offset
486+
data_buffer = raw_values.buffers()[1]
487+
return _to_ndarray_helper(shape, value_type, offset, data_buffer)
506488

507489
def __str__(self) -> str:
508490
return (
@@ -657,42 +639,20 @@ def to_pylist(self):
657639
# support (see comment in __getitem__).
658640
return list(self)
659641

660-
if _arrow_supports_extension_scalars():
661-
# NOTE(Clark): This __getitem__ override is only needed for Arrow 8.*,
662-
# before ExtensionScalar subclassing support was added.
663-
# TODO(Clark): Remove these methods once we only support Arrow 9.0.0+.
664-
def __getitem__(self, key):
665-
# This __getitem__ hook allows us to support proper indexing when
666-
# accessing a single tensor (a "scalar" item of the array). Without this
667-
# hook for integer keys, the indexing will fail on pyarrow < 9.0.0 due
668-
# to a lack of ExtensionScalar subclassing support.
669-
670-
# NOTE(Clark): We'd like to override the pa.Array.getitem() helper
671-
# instead, which would obviate the need for overriding __iter__(), but
672-
# unfortunately overriding Cython cdef methods with normal Python
673-
# methods isn't allowed.
674-
item = super().__getitem__(key)
675-
if not isinstance(key, slice):
676-
item = item.type._extension_scalar_to_ndarray(item)
677-
return item
642+
def __getitem__(self, key):
643+
# This __getitem__ hook allows us to support proper indexing when
644+
# accessing a single tensor (a "scalar" item of the array). Without this
645+
# hook for integer keys, the indexing will fail on pyarrow < 9.0.0 due
646+
# to a lack of ExtensionScalar subclassing support.
678647

679-
else:
680-
# NOTE(Clark): This __getitem__ override is only needed for Arrow < 8.0.0,
681-
# before any ExtensionScalar support was added.
682-
# TODO(Clark): Remove these methods once we only support Arrow 8.0.0+.
683-
def __getitem__(self, key):
684-
# This __getitem__ hook allows us to support proper indexing when
685-
# accessing a single tensor (a "scalar" item of the array). Without this
686-
# hook for integer keys, the indexing will fail on pyarrow < 8.0.0 due
687-
# to a lack of ExtensionScalar support.
688-
689-
# NOTE(Clark): We'd like to override the pa.Array.getitem() helper
690-
# instead, which would obviate the need for overriding __iter__(), but
691-
# unfortunately overriding Cython cdef methods with normal Python
692-
# methods isn't allowed.
693-
if isinstance(key, slice):
694-
return super().__getitem__(key)
695-
return self._to_numpy(key)
648+
# NOTE(Clark): We'd like to override the pa.Array.getitem() helper
649+
# instead, which would obviate the need for overriding __iter__(), but
650+
# unfortunately overriding Cython cdef methods with normal Python
651+
# methods isn't allowed.
652+
item = super().__getitem__(key)
653+
if not isinstance(key, slice):
654+
item = item.type._extension_scalar_to_ndarray(item)
655+
return item
696656

697657

698658
# NOTE: We need to inherit from the mixin before pa.ExtensionArray to ensure that the
@@ -1109,22 +1069,18 @@ def __str__(self) -> str:
11091069
def __repr__(self) -> str:
11101070
return str(self)
11111071

1112-
if _arrow_supports_extension_scalars():
1113-
# TODO(Clark): Remove this version guard once we only support Arrow 8.0.0+.
1114-
def _extension_scalar_to_ndarray(
1115-
self, scalar: pa.ExtensionScalar
1116-
) -> np.ndarray:
1117-
"""
1118-
Convert an ExtensionScalar to a tensor element.
1119-
"""
1120-
data = scalar.value.get("data")
1121-
raw_values = data.values
1122-
1123-
shape = tuple(scalar.value.get("shape").as_py())
1124-
value_type = raw_values.type
1125-
offset = raw_values.offset
1126-
data_buffer = raw_values.buffers()[1]
1127-
return _to_ndarray_helper(shape, value_type, offset, data_buffer)
1072+
def _extension_scalar_to_ndarray(self, scalar: "pa.ExtensionScalar") -> np.ndarray:
1073+
"""
1074+
Convert an ExtensionScalar to a tensor element.
1075+
"""
1076+
data = scalar.value.get("data")
1077+
raw_values = data.values
1078+
1079+
shape = tuple(scalar.value.get("shape").as_py())
1080+
value_type = raw_values.type
1081+
offset = raw_values.offset
1082+
data_buffer = raw_values.buffers()[1]
1083+
return _to_ndarray_helper(shape, value_type, offset, data_buffer)
11281084

11291085

11301086
# NOTE: We need to inherit from the mixin before pa.ExtensionArray to ensure that the

0 commit comments

Comments
 (0)