Skip to content

Commit f2bd661

Browse files
author
Luke Shaw
committed
Enabled field filtering for Proxy sources
1 parent 9f742a9 commit f2bd661

File tree

3 files changed

+35
-13
lines changed

3 files changed

+35
-13
lines changed

caterva2/client.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -313,13 +313,13 @@ def get_download_url(self):
313313
"""
314314
return api_utils.get_download_url(self.path, self.urlbase)
315315

316-
def __getitem__(self, key):
316+
def __getitem__(self, item):
317317
"""
318318
Retrieves a slice of the dataset.
319319
320320
Parameters
321321
----------
322-
key : int, slice, tuple of ints and slices, or None
322+
item : int, slice, tuple of ints and slices, or None
323323
Specifies the slice to fetch.
324324
325325
Returns
@@ -340,10 +340,17 @@ def __getitem__(self, key):
340340
>>> ds[0:10]
341341
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
342342
"""
343-
if isinstance(key, str): # used a filter or field to index so want blosc2 array as result
344-
return self.client.get_slice(self.path, key, as_blosc2=True)
343+
if isinstance(item, str): # used a filter or field to index so want blosc2 array as result
344+
fields = np.dtype(eval(self.dtype)).fields
345+
if fields is None:
346+
raise ValueError("The array is not structured (its dtype does not have fields)")
347+
if item in fields:
348+
# A shortcut to access fields
349+
return self.client.get_slice(self.path, as_blosc2=True, field=item) # arg key is None
350+
else: # used a filter (possibly lazyexpr)
351+
return self.client.get_slice(self.path, item, as_blosc2=True)
345352
else:
346-
return self.slice(key, as_blosc2=False)
353+
return self.slice(item, as_blosc2=False)
347354

348355
def slice(
349356
self, key: int | slice | Sequence[slice], as_blosc2: bool = True
@@ -874,7 +881,7 @@ def fetch(self, path, slice_=None):
874881
# Does the same as get_slice but forces return of np array
875882
return self.get_slice(path, key=slice_, as_blosc2=False)
876883

877-
def get_slice(self, path, key=None, as_blosc2=True):
884+
def get_slice(self, path, key=None, as_blosc2=True, field=None):
878885
"""Get a slice of a File/Dataset.
879886
880887
Parameters
@@ -888,6 +895,8 @@ def get_slice(self, path, key=None, as_blosc2=True):
888895
If True (default), the result will be returned as a Blosc2 object
889896
(either a `SChunk` or `NDArray`). If False, it will be returned
890897
as a NumPy array (equivalent to `self[key]`).
898+
field: str
899+
Shortcut to access a field in a structured array. If provided, `key` is ignored.
891900
892901
Returns
893902
-------
@@ -904,7 +913,11 @@ def get_slice(self, path, key=None, as_blosc2=True):
904913
dtype=[('a', '<f4'), ('b', '<f8')])
905914
"""
906915
urlbase, path = _format_paths(self.urlbase, path)
907-
if isinstance(key, str): # A filter or field has been passed
916+
if field: # blosc2 doesn't support indexing of multiple fields
917+
return api_utils.fetch_data(
918+
path, urlbase, {"field": field}, auth_cookie=self.cookie, as_blosc2=as_blosc2
919+
)
920+
if isinstance(key, str): # A filter has been passed
908921
return api_utils.fetch_data(
909922
path, urlbase, {"filter": key}, auth_cookie=self.cookie, as_blosc2=as_blosc2
910923
)

caterva2/services/sub.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import tarfile
2525
import typing
2626
import zipfile
27+
from argparse import ArgumentError
2728

2829
# Requirements
2930
import blosc2
@@ -747,6 +748,7 @@ async def fetch_data(
747748
slice_: str | None = None,
748749
user: db.User = Depends(optional_user),
749750
filter: str | None = None,
751+
field: str | None = None,
750752
):
751753
"""
752754
Fetch a dataset.
@@ -759,6 +761,8 @@ async def fetch_data(
759761
The slice to fetch.
760762
filter : str
761763
The filter to apply to the dataset.
764+
field : str
765+
The desired field of dataset. If provided, filter is ignored.
762766
763767
Returns
764768
-------
@@ -774,14 +778,18 @@ async def fetch_data(
774778
abspath, dataprep = abspath_and_dataprep(path, slice_, user=user)
775779
# This is still needed and will only update the necessary chunks
776780
await dataprep()
781+
777782
if filter:
783+
if field:
784+
raise ArgumentError("Cannot handle both field and filter parameters at the same time")
778785
filter = filter.strip()
779786
container, _ = get_filtered_array(abspath, path, filter, sortby=None)
780787
else:
781788
container = open_b2(abspath, path)
782789

783790
if isinstance(container, blosc2.Proxy):
784791
container = container._cache
792+
container = container[field] if field else container
785793

786794
if isinstance(container, blosc2.NDArray | blosc2.LazyExpr | hdf5.HDF5Proxy | blosc2.NDField):
787795
array = container
@@ -808,7 +816,7 @@ async def fetch_data(
808816
for sl, sh in zip(slice_, shape, strict=False)
809817
)
810818

811-
if whole and not isinstance(array, blosc2.LazyExpr | hdf5.HDF5Proxy | blosc2.NDField):
819+
if whole and (not isinstance(array, blosc2.LazyExpr | hdf5.HDF5Proxy | blosc2.NDField)) and (not filter):
812820
# Send the data in the file straight to the client,
813821
# avoiding slicing and re-compression.
814822
return FileResponse(abspath, filename=abspath.name, media_type="application/octet-stream")
@@ -1869,13 +1877,15 @@ def get_filtered_array(abspath, path, filter, sortby):
18691877

18701878
# Filter rows only for NDArray with fields
18711879
if filter:
1880+
arr = arr._cache if isinstance(arr, blosc2.Proxy) else arr
18721881
# Check whether filter is the name of a field
18731882
if filter in arr.fields:
18741883
if arr.dtype.fields[filter][0] == bool: # noqa: E721
18751884
# If boolean, give the filter a boolean expression
18761885
filter = f"{filter} == True"
18771886
else:
1878-
return arr[filter], idx # just return the blosc2 NDfield associated to field
1887+
raise IndexError("Filter should be a boolean expression")
1888+
18791889
# Let's create a LazyExpr with the filter
18801890
larr = arr[filter]
18811891
# TODO: do some benchmarking to see if this is worth it

caterva2/tests/test_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,6 @@ def test_lazyexpr_fields(auth_client):
698698
if not auth_client:
699699
pytest.skip("authentication support needed")
700700

701-
opnm = "ds"
702701
oppt = f"{TEST_CATERVA2_ROOT}/ds-1d-fields.b2nd"
703702
auth_client.subscribe(TEST_CATERVA2_ROOT)
704703

@@ -708,9 +707,9 @@ def test_lazyexpr_fields(auth_client):
708707
np.testing.assert_allclose(field[:], arr[:]["a"])
709708

710709
# Test a lazyexpr
711-
servered = arr["a < 500 & b >= .1"]
712-
downloaded = arr.slice(None)["a < 500 & b >= .1"]
713-
np.testing.assert_allclose(servered[:], downloaded[:])
710+
servered = arr["(a < 500) & (b >= .1)"][:]
711+
downloaded = arr.slice(None)["(a < 500) & (b >= .1)"][:]
712+
[np.testing.assert_array_equal(servered[f], downloaded[f]) for f in downloaded.dtype.fields]
714713

715714

716715
def test_expr_from_expr(auth_client):

0 commit comments

Comments
 (0)