Skip to content

Commit 231a1f9

Browse files
authored
Merge pull request #201 from ironArray/fixFiltering
Adding field-indexing to caterva2. Fixes #187.
2 parents 03ea726 + f2bd661 commit 231a1f9

File tree

3 files changed

+73
-19
lines changed

3 files changed

+73
-19
lines changed

caterva2/client.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -313,13 +313,13 @@ def get_download_url(self):
313313
"""
314314
return api_utils.get_download_url(self.path, self.urlbase)
315315

316-
def __getitem__(self, key):
316+
def __getitem__(self, item):
317317
"""
318318
Retrieves a slice of the dataset.
319319
320320
Parameters
321321
----------
322-
key : int, slice, tuple of ints and slices, or None
322+
item : int, slice, tuple of ints and slices, str, or None
323323
Specifies the slice to fetch.
324324
325325
Returns
@@ -340,7 +340,17 @@ def __getitem__(self, key):
340340
>>> ds[0:10]
341341
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
342342
"""
343-
return self.slice(key, as_blosc2=False)
343+
if isinstance(item, str): # used a filter or field to index so want blosc2 array as result
344+
fields = np.dtype(eval(self.dtype)).fields
345+
if fields is None:
346+
raise ValueError("The array is not structured (its dtype does not have fields)")
347+
if item in fields:
348+
# A shortcut to access fields
349+
return self.client.get_slice(self.path, as_blosc2=True, field=item) # arg key is None
350+
else: # used a filter (possibly lazyexpr)
351+
return self.client.get_slice(self.path, item, as_blosc2=True)
352+
else:
353+
return self.slice(item, as_blosc2=False)
344354

345355
def slice(
346356
self, key: int | slice | Sequence[slice], as_blosc2: bool = True
@@ -871,24 +881,25 @@ def fetch(self, path, slice_=None):
871881
[(1.0000500e-02, 1.0100005), (1.0050503e-02, 1.0100505)]],
872882
dtype=[('a', '<f4'), ('b', '<f8')])
873883
"""
874-
urlbase, path = _format_paths(self.urlbase, path)
875-
slice_ = api_utils.slice_to_string(slice_) # convert to string
876-
return api_utils.fetch_data(path, urlbase, {"slice_": slice_}, auth_cookie=self.cookie)
884+
# Does the same as get_slice but forces return of np array
885+
return self.get_slice(path, key=slice_, as_blosc2=False)
877886

878-
def get_slice(self, path, key=None, as_blosc2=True):
887+
def get_slice(self, path, key=None, as_blosc2=True, field=None):
879888
"""Get a slice of a File/Dataset.
880889
881890
Parameters
882891
----------
883-
key : int, slice, or sequence of slices
892+
key : int, slice, sequence of slices or str
884893
The slice to retrieve. If a single slice is provided, it will be
885894
applied to the first dimension. If a sequence of slices is
886895
provided, each slice will be applied to the corresponding
887-
dimension.
896+
dimension. If a str, it is interpreted as a filter.
888897
as_blosc2 : bool
889898
If True (default), the result will be returned as a Blosc2 object
890899
(either a `SChunk` or `NDArray`). If False, it will be returned
891900
as a NumPy array (equivalent to `self[key]`).
901+
field : str
902+
Shortcut to access a field in a structured array. If provided, `key` is ignored.
892903
893904
Returns
894905
-------
@@ -905,12 +916,20 @@ def get_slice(self, path, key=None, as_blosc2=True):
905916
dtype=[('a', '<f4'), ('b', '<f8')])
906917
"""
907918
urlbase, path = _format_paths(self.urlbase, path)
908-
# Convert slices to strings
909-
slice_ = api_utils.slice_to_string(key)
910-
# Fetch and return the data as a Blosc2 object / NumPy array
911-
return api_utils.fetch_data(
912-
path, urlbase, {"slice_": slice_}, auth_cookie=self.cookie, as_blosc2=as_blosc2
913-
)
919+
if field: # blosc2 doesn't support indexing of multiple fields
920+
return api_utils.fetch_data(
921+
path, urlbase, {"field": field}, auth_cookie=self.cookie, as_blosc2=as_blosc2
922+
)
923+
if isinstance(key, str): # A filter has been passed
924+
return api_utils.fetch_data(
925+
path, urlbase, {"filter": key}, auth_cookie=self.cookie, as_blosc2=as_blosc2
926+
)
927+
else: # Convert slices to strings
928+
slice_ = api_utils.slice_to_string(key)
929+
# Fetch and return the data as a Blosc2 object / NumPy array
930+
return api_utils.fetch_data(
931+
path, urlbase, {"slice_": slice_}, auth_cookie=self.cookie, as_blosc2=as_blosc2
932+
)
914933

915934
def get_chunk(self, path, nchunk):
916935
"""

caterva2/services/sub.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import tarfile
2525
import typing
2626
import zipfile
27+
from argparse import ArgumentError
2728

2829
# Requirements
2930
import blosc2
@@ -779,6 +780,8 @@ async def fetch_data(
779780
path: pathlib.Path,
780781
slice_: str | None = None,
781782
user: db.User = Depends(optional_user),
783+
filter: str | None = None,
784+
field: str | None = None,
782785
):
783786
"""
784787
Fetch a dataset.
@@ -789,6 +792,10 @@ async def fetch_data(
789792
The path to the dataset.
790793
slice_ : str
791794
The slice to fetch.
795+
filter : str
796+
The filter to apply to the dataset.
797+
field : str
798+
The desired field of the dataset. Cannot be combined with `filter`.
792799
793800
Returns
794801
-------
@@ -804,12 +811,20 @@ async def fetch_data(
804811
abspath, dataprep = abspath_and_dataprep(path, slice_, user=user)
805812
# This is still needed and will only update the necessary chunks
806813
await dataprep()
807-
container = open_b2(abspath, path)
814+
815+
if filter:
816+
if field:
817+
raise ArgumentError("Cannot handle both field and filter parameters at the same time")
818+
filter = filter.strip()
819+
container, _ = get_filtered_array(abspath, path, filter, sortby=None)
820+
else:
821+
container = open_b2(abspath, path)
808822

809823
if isinstance(container, blosc2.Proxy):
810824
container = container._cache
825+
container = container[field] if field else container
811826

812-
if isinstance(container, blosc2.NDArray | blosc2.LazyExpr | hdf5.HDF5Proxy):
827+
if isinstance(container, blosc2.NDArray | blosc2.LazyExpr | hdf5.HDF5Proxy | blosc2.NDField):
813828
array = container
814829
schunk = getattr(array, "schunk", None) # not really needed
815830
typesize = array.dtype.itemsize
@@ -834,14 +849,14 @@ async def fetch_data(
834849
for sl, sh in zip(slice_, shape, strict=False)
835850
)
836851

837-
if whole and not isinstance(array, blosc2.LazyExpr | hdf5.HDF5Proxy):
852+
if whole and (not isinstance(array, blosc2.LazyExpr | hdf5.HDF5Proxy | blosc2.NDField)) and (not filter):
838853
# Send the data in the file straight to the client,
839854
# avoiding slicing and re-compression.
840855
return FileResponse(abspath, filename=abspath.name, media_type="application/octet-stream")
841856

842857
if isinstance(array, hdf5.HDF5Proxy):
843858
data = array.to_cframe(() if slice_ is None else slice_)
844-
elif isinstance(array, blosc2.LazyExpr):
859+
elif isinstance(array, blosc2.LazyExpr | blosc2.NDField):
845860
data = array[() if slice_ is None else slice_]
846861
data = blosc2.asarray(data)
847862
data = data.to_cframe()
@@ -1895,13 +1910,15 @@ def get_filtered_array(abspath, path, filter, sortby):
18951910

18961911
# Filter rows only for NDArray with fields
18971912
if filter:
1913+
arr = arr._cache if isinstance(arr, blosc2.Proxy) else arr
18981914
# Check whether filter is the name of a field
18991915
if filter in arr.fields:
19001916
if arr.dtype.fields[filter][0] == bool: # noqa: E721
19011917
# If boolean, give the filter a boolean expression
19021918
filter = f"{filter} == True"
19031919
else:
19041920
raise IndexError("Filter should be a boolean expression")
1921+
19051922
# Let's create a LazyExpr with the filter
19061923
larr = arr[filter]
19071924
# TODO: do some benchmarking to see if this is worth it

caterva2/tests/test_api.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,24 @@ def test_lazyexpr_getchunk(auth_client):
740740
np.testing.assert_array_equal(out, out_expr)
741741

742742

743+
def test_lazyexpr_fields(auth_client):
744+
if not auth_client:
745+
pytest.skip("authentication support needed")
746+
747+
oppt = f"{TEST_CATERVA2_ROOT}/ds-1d-fields.b2nd"
748+
auth_client.subscribe(TEST_CATERVA2_ROOT)
749+
750+
# Test a field
751+
arr = auth_client.get(oppt)
752+
field = arr["a"]
753+
np.testing.assert_allclose(field[:], arr[:]["a"])
754+
755+
# Test a lazyexpr
756+
servered = arr["(a < 500) & (b >= .1)"][:]
757+
downloaded = arr.slice(None)["(a < 500) & (b >= .1)"][:]
758+
[np.testing.assert_array_equal(servered[f], downloaded[f]) for f in downloaded.dtype.fields]
759+
760+
743761
def test_expr_from_expr(auth_client):
744762
if not auth_client:
745763
pytest.skip("authentication support needed")

0 commit comments

Comments
 (0)