Skip to content

Commit

Permalink
Fix dasdae format writes with datetime-like coords
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Mar 21, 2024
1 parent f93ac93 commit c23d4e5
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 11 deletions.
41 changes: 36 additions & 5 deletions dascore/core/coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from collections import defaultdict
from collections.abc import Mapping, Sequence, Sized
from functools import reduce
from itertools import zip_longest
from operator import and_, or_
from typing import Annotated, Any, TypeVar

Expand Down Expand Up @@ -560,6 +561,23 @@ def set_dims(self, **kwargs: str) -> Self:
dim_map[coord_name] = tuple(old_to_new[x] for x in coord_dims)
return self.__class__(dims=dims, coord_map=coord_map, dim_map=dim_map)

def _get_single_dim_kwarg_list(self, kwargs):
    """
    Split kwargs into dicts which each use a dimension at most once.

    Parameters
    ----------
    kwargs
        Mapping of coordinate name to its select argument. Keys which
        are not in ``self.coord_map`` are dropped.

    Returns
    -------
    A list of dicts. If no two used coordinates map to the same entry
    in ``self.dim_map`` the list holds a single dict with all of the
    used coordinates (an empty dict when none are used). Otherwise the
    coordinates are spread over several dicts so that no dict holds two
    coordinates with the same ``dim_map`` entry.
    """
    # Sorting makes the output order independent of set iteration order.
    used_coords = sorted(set(self.coord_map) & set(kwargs))
    dims = [self.dim_map[x] for x in used_coords]
    # No duplicate usage (this also covers the empty case), just return.
    if len(set(dims)) == len(dims):
        return [{i: kwargs[i] for i in used_coords}]
    # Group the coordinate names by the dim_map entry they point to.
    dim_dicts = defaultdict(list)
    for coord, dim in zip(used_coords, dims):
        dim_dicts[dim].append(coord)
    # The i-th dict takes the i-th coordinate of every group, so each
    # dimension appears at most once per dict. zip_longest pads the
    # shorter groups with None, which is filtered out here.
    return [
        {x: kwargs[x] for x in args if x is not None}
        for args in zip_longest(*dim_dicts.values())
    ]

def select(
self, array: MaybeArray = None, relative=False, samples=False, **kwargs
) -> tuple[Self, MaybeArray]:
Expand All @@ -578,11 +596,24 @@ def select(
Used to specify select arguments. Can be of the form
{coord_name: (lower_limit, upper_limit)}.
"""
new_coords, indexers = _get_indexers_and_new_coords_dict(
self, kwargs, samples=samples, relative=relative
)
new_cm = self.update(**new_coords)
return new_cm, self._get_new_data(indexers, array)
# Relative or sample queries cannot be performed multiple times on
# the same dimension (since multiple coords can reference the same dim)
if relative or samples:
used_dims = [self.dim_map[x] for x in kwargs if x in self.coord_map]
if len(set(used_dims)) < len(used_dims):
msg = (
f"Cannot use {kwargs} for query; some coords " f"share a dimension."
)
raise CoordError(msg)
# Otherwise, we need to sort through kwargs and call in a loop.
kwarg_list = self._get_single_dim_kwarg_list(kwargs)
for kwargs in kwarg_list:
new_coords, indexers = _get_indexers_and_new_coords_dict(
self, kwargs, samples=samples, relative=relative
)
self = self.update(**new_coords)
array = self._get_new_data(indexers, array)
return self, array

def _get_new_data(self, indexer, array: MaybeArray) -> MaybeArray:
"""Get new data array after applying some trimming."""
Expand Down
6 changes: 6 additions & 0 deletions dascore/core/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,12 @@ def select(self, args, relative=False) -> tuple[Self, slice | ArrayLike]:
out = out & (values <= val2)
if not np.any(out):
return self.empty(), out
if np.all(out):
return self, slice(None, None)
# Convert boolean to int indexes because these are supported for
# indexing pytables arrays but booleans are not.
if len(self.shape) == 1:
out = np.arange(len(out))[out]
return self.new(values=values[out]), out

def sort(self, reverse=False) -> tuple[BaseCoord, slice | ArrayLike]:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_core/test_coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,26 @@ def test_samples_slice(self, coord_manager):
new, _ = coord_manager.select(time=..., samples=True)
assert new == coord_manager

def test_select_shared_dims(self, coord_manager):
    """Selecting with several coords tied to one dimension is handled."""
    n_distance = len(coord_manager.get_coord("distance"))
    index_values = np.arange(n_distance)
    # Attach two extra coords, both associated with the distance dim.
    shared = coord_manager.update_coords(
        d1=("distance", index_values),
        d2=("distance", index_values),
    )
    # Relative and sample based queries must raise when more than one
    # coord targets the same dimension.
    for mode in ("relative", "samples"):
        with pytest.raises(CoordError):
            shared.select(d1=(3, None), d2=(None, 6), **{mode: True})
    # Plain value queries are applied together; here the distance
    # dimension should end up with a length of 4.
    selected, _ = shared.select(d1=(3, None), d2=(None, 6))
    axis = selected.dims.index("distance")
    assert selected.shape[axis] == 4


class TestEquals:
"""Tests for coord manager equality."""
Expand Down
3 changes: 1 addition & 2 deletions tests/test_core/test_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,10 +1027,9 @@ def test_select(self, random_coord):
min_v, max_v = np.min(random_coord.values), np.max(random_coord.values)
dist = max_v - min_v
val1, val2 = min_v + 0.2 * dist, max_v - 0.2 * dist
new, bool_array = random_coord.select((val1, val2))
new, ind_array = random_coord.select((val1, val2))
assert np.all(new.values >= val1)
assert np.all(new.values <= val2)
assert bool_array.sum() == len(new)

def test_sort(self, random_coord):
"""Ensure the coord can be ordered."""
Expand Down
33 changes: 29 additions & 4 deletions tests/test_io/test_dasdae/test_dasdae.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def test_indexed_vs_unindexed(
class TestRoundTrips:
"""Tests for round-tripping various patches/spools."""

formatter = DASDAEV1()

def test_write_patch_with_lat_lon(
self, random_patch_with_lat_lon, tmp_path_factory
):
Expand Down Expand Up @@ -202,8 +204,7 @@ def test_roundtrip_empty_time_patch(self, tmp_path_factory, random_patch):
time_max = time.max() + 3 * time.step
empty_patch = patch.select(time=(time_max, ...))
empty_patch.io.write(path, "dasdae")
formatter = DASDAEV1()
spool = formatter.read(path)
spool = self.formatter.read(path)
new_patch = spool[0]
assert empty_patch.equals(new_patch)

Expand All @@ -217,7 +218,31 @@ def test_roundtrip_dim_1_patch(self, tmp_path_factory, random_patch):
time_min="2023-06-13T15:38:00.49953408",
)
patch.io.write(path, "dasdae")
formatter = DASDAEV1()
spool = formatter.read(path)

spool = self.formatter.read(path)
new_patch = spool[0]
assert patch.equals(new_patch)

def test_roundtrip_datetime_coord(self, tmp_path_factory, random_patch):
    """A patch with an attached datetime coord can be written and read."""
    out_dir = tmp_path_factory.mktemp("roundtrip_datetme_coord")
    h5_path = out_dir / "out.h5"
    distance = random_patch.get_coord("distance")
    # Build the datetime values from an array of zeros shaped like the
    # distance coord, then overwrite the first element with a real date.
    datetimes = dc.to_datetime64(np.zeros_like(distance))
    datetimes[0] = dc.to_datetime64("2017-09-17")
    with_dt = random_patch.update_coords(dt=("distance", datetimes))
    with_dt.io.write(h5_path, "dasdae")
    # Only the type of the loaded object is checked, not its contents.
    spool = dc.spool(h5_path, file_format="DASDAE")
    assert isinstance(spool[0], dc.Patch)

def test_roundtrip_nullish_datetime_coord(self, tmp_path_factory, random_patch):
    """A patch with a datetime coord containing NaT can be written and read."""
    out_dir = tmp_path_factory.mktemp("roundtrip_datetme_coord")
    h5_path = out_dir / "out.h5"
    distance = random_patch.get_coord("distance")
    datetimes = dc.to_datetime64(np.zeros_like(distance))
    # Every element whose boolean cast is False becomes NaT.
    is_falsy = ~datetimes.astype(bool)
    datetimes[is_falsy] = np.datetime64("nat")
    # Put real dates at the first and fourth-from-last positions.
    datetimes[0] = dc.to_datetime64("2017-09-17")
    datetimes[-4] = dc.to_datetime64("2020-01-03")
    with_dt = random_patch.update_coords(dt=("distance", datetimes))
    with_dt.io.write(h5_path, "dasdae")
    # Only the type of the loaded object is checked, not its contents.
    spool = dc.spool(h5_path, file_format="DASDAE")
    assert isinstance(spool[0], dc.Patch)

0 comments on commit c23d4e5

Please sign in to comment.