Skip to content

Commit

Permalink
fix synchronization issue with attrs.dims and coords.dims (#367)
Browse files Browse the repository at this point in the history
* fix synchronization issue with attrs.dims and coords.dims

* add dims consistency assert in Patch init
  • Loading branch information
d-chambers authored Apr 27, 2024
1 parent 6aad9f3 commit 38eed62
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 17 deletions.
3 changes: 3 additions & 0 deletions dascore/core/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def from_dict(
@property
def dim_tuple(self):
    """
    Return the dimension names as a tuple.

    The ``dims`` attribute is stored as a comma-separated string. When it
    is empty (or otherwise falsy) an empty tuple is returned rather than
    splitting it.
    """
    dims = self.dims
    return tuple(dims.split(",")) if dims else ()

def rename_dimension(self, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion dascore/core/coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def keys(self):
"""Return the keys (coordinates) in the coord manager."""
return self.coord_map.keys()

def to_summary_dict(self) -> dict[str, CoordSummary]:
def to_summary_dict(self) -> dict[str, CoordSummary | tuple[str, ...]]:
"""Convert the contents of the coordinate manager to a summary dict."""
dim_map = self.dim_map
out = {}
Expand Down
1 change: 1 addition & 0 deletions dascore/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
else:
# ensure attrs conforms to coords
attrs = dc.PatchAttrs.from_dict(attrs).update(coords=coords)
assert coords.dims == attrs.dim_tuple, "dim mismatch on coords and attrs"
self._coords = coords
self._attrs = attrs
self._data = array(self.coords.validate_data(data))
Expand Down
5 changes: 4 additions & 1 deletion dascore/proc/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ def _update_attrs_coord_units(patch: dc.Patch, data_units, coords):
"""Update attributes with new units."""
attrs = patch.attrs
# set data units
attrs = attrs.update(data_units=data_units, coords=coords.to_summary_dict())
attrs = attrs.update(
data_units=data_units,
coords=coords.to_summary_dict(),
)
return attrs


Expand Down
51 changes: 36 additions & 15 deletions dascore/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,16 +452,32 @@ def separate_coord_info(
coord_dict and attrs_dict.
"""

def _meets_required(coord_dict):
"""Return True coord dict meets the minimum required keys."""
def _meets_required(coord_dict, strict=True):
"""
Return True if coord dict meets the minimum required keys.
coord_dict represents potential coordinate fields.
Strict ensures all required values exist.
"""
if not coord_dict:
return False
if not required and (set(coord_dict) - cant_be_alone):
return True
return set(coord_dict).issuperset(required)
if required or not strict:
return set(coord_dict).issuperset(required)
return False

def _get_dims(obj):
"""Try to ascertain dims from keys in obj."""
# check first for coord manager
if isinstance(obj, dict) and hasattr(obj.get("coords", None), "dims"):
return obj["coords"].dims

# This object already has dims, just honor it.
if dims := obj.get("dims", None):
return tuple(dims.split(",")) if isinstance(dims, str) else dims

potential_keys = defaultdict(set)
for key in obj:
if not is_valid_coord_str(key):
Expand All @@ -481,7 +497,7 @@ def _get_coords_from_top_level(obj, out, dims):
warnings.warn(msg, DeprecationWarning, stacklevel=3)
potential_coord["step"] = obj[bad_name]

if _meets_required(potential_coord):
if _meets_required(potential_coord, strict=False):
out[dim] = potential_coord

def _get_coords_from_coord_level(obj, out):
Expand All @@ -494,7 +510,7 @@ def _get_coords_from_coord_level(obj, out):
value = value.to_summary()
if hasattr(value, "model_dump"):
value = value.model_dump()
if _meets_required(value):
if _meets_required(value, strict=False):
out[key] = value

def _pop_keys(obj, out):
Expand All @@ -509,23 +525,28 @@ def _pop_keys(obj, out):
obj.pop(f"d_{coord_name}", None)

# sequence of short-circuit checks
out = {}
coord_dict = {}
required = set(required) if required is not None else set()
cant_be_alone = set(cant_be_alone)
if obj is None:
return out, {}
return coord_dict, {}
if hasattr(obj, "model_dump"):
obj = obj.model_dump()
if dims is None:
dims = _get_dims(obj)
obj = dict(obj)
# Check if dims need to be updated.
new_dims = _get_dims(obj)
if new_dims and new_dims != dims:
obj["dims"] = new_dims
dims = new_dims
# this is already a dict of coord info.
if set(dims) == set(obj):
if dims and set(dims) == set(obj):
return obj, {}
obj = dict(obj)
_get_coords_from_coord_level(obj, out)
_get_coords_from_top_level(obj, out, dims)
_pop_keys(obj, out)
return out, obj
_get_coords_from_coord_level(obj, coord_dict)
_get_coords_from_top_level(obj, coord_dict, dims)
_pop_keys(obj, coord_dict)
if "dims" not in obj and dims is not None:
obj["dims"] = dims
return coord_dict, obj


def cached_method(func):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_core/test_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ def test_items(self, random_patch):
out = dict(attrs.items())
assert out == attrs.model_dump()

def test_dims_match_attrs(self, random_patch):
    """Renaming a dimension keeps the patch dims and attrs dims in agreement."""
    renamed = random_patch.rename_coords(distance="channel")
    patch_dims, attr_dims = renamed.dims, renamed.attrs.dim_tuple
    assert patch_dims == attr_dims


class TestSummaryAttrs:
"""Tests for summarizing a schema."""
Expand Down Expand Up @@ -288,6 +293,13 @@ def test_attrs_can_update(self, random_attrs):
expected = dc.get_quantity("miles")
assert dc.get_quantity(attrs.coords["distance"].units) == expected

def test_update_from_coords(self, random_patch):
    """Updating attrs with renamed coords should carry the new dims over."""
    original_attrs = random_patch.attrs
    renamed_patch = random_patch.rename_coords(distance="channel")
    updated_attrs = original_attrs.update(coords=renamed_patch.coords)
    assert updated_attrs.dim_tuple == renamed_patch.dims


class TestMergeAttrs:
"""Tests for merging patch attrs."""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_core/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ def test_patch_has_size(self, random_patch):
"""Ensure patches have same size as data."""
assert random_patch.size == random_patch.data.size

def test_new_patch_non_standard_dims(self):
    """A patch built with a non-standard dim name has matching attrs and coord dims."""
    dims = ("time", "can")
    shape = (10, 5)
    # One integer-range coordinate per dimension, sized to the data shape.
    coords = {name: np.arange(size) for name, size in zip(dims, shape)}
    patch = dc.Patch(data=np.random.rand(*shape), coords=coords, dims=dims)
    assert patch.dims == patch.attrs.dim_tuple


class TestNew:
"""Tests for `Patch.new` method."""
Expand Down

0 comments on commit 38eed62

Please sign in to comment.