From 8ded4e342021eda192c1b0e1ea3672c4b9e8d7cc Mon Sep 17 00:00:00 2001 From: derrick chambers Date: Fri, 26 Apr 2024 14:39:33 -0600 Subject: [PATCH] fix synchronization issue with attrs.dims and coords.dims --- dascore/core/coordmanager.py | 2 +- dascore/proc/units.py | 5 +++- dascore/utils/misc.py | 51 ++++++++++++++++++++++++----------- tests/test_core/test_attrs.py | 12 +++++++++ tests/test_core/test_patch.py | 7 +++++ 5 files changed, 60 insertions(+), 17 deletions(-) diff --git a/dascore/core/coordmanager.py b/dascore/core/coordmanager.py index eb3bc3a1..c08a3a49 100644 --- a/dascore/core/coordmanager.py +++ b/dascore/core/coordmanager.py @@ -900,7 +900,7 @@ def keys(self): """Return the keys (coordinates) in the coord manager.""" return self.coord_map.keys() - def to_summary_dict(self) -> dict[str, CoordSummary]: + def to_summary_dict(self) -> dict[str, CoordSummary | tuple[str, ...]]: """Convert the contents of the coordinate manager to a summary dict.""" dim_map = self.dim_map out = {} diff --git a/dascore/proc/units.py b/dascore/proc/units.py index d12e9dc7..07e7bedc 100644 --- a/dascore/proc/units.py +++ b/dascore/proc/units.py @@ -12,7 +12,10 @@ def _update_attrs_coord_units(patch: dc.Patch, data_units, coords): """Update attributes with new units.""" attrs = patch.attrs # set data units - attrs = attrs.update(data_units=data_units, coords=coords.to_summary_dict()) + attrs = attrs.update( + data_units=data_units, + coords=coords.to_summary_dict(), + ) return attrs diff --git a/dascore/utils/misc.py b/dascore/utils/misc.py index a1600706..94c905fd 100644 --- a/dascore/utils/misc.py +++ b/dascore/utils/misc.py @@ -452,16 +452,32 @@ def separate_coord_info( coord_dict and attrs_dict. """ - def _meets_required(coord_dict): - """Return True coord dict meets the minimum required keys.""" + def _meets_required(coord_dict, strict=True): + """ + Return True coord dict meets the minimum required keys. + + coord_dict represents potential coordinate fields. + + Strict ensures all required values exist. + """ if not coord_dict: return False if not required and (set(coord_dict) - cant_be_alone): return True - return set(coord_dict).issuperset(required) + if required or not strict: + return set(coord_dict).issuperset(required) + return False def _get_dims(obj): """Try to ascertain dims from keys in obj.""" + # check first for coord manager + if isinstance(obj, dict) and hasattr(obj.get("coords", None), "dims"): + return obj["coords"].dims + + # This object already has dims, just honor it. + if dims := obj.get("dims", None): + return tuple(dims.split(",")) if isinstance(dims, str) else dims + potential_keys = defaultdict(set) for key in obj: if not is_valid_coord_str(key): @@ -481,7 +497,7 @@ def _get_coords_from_top_level(obj, out, dims): warnings.warn(msg, DeprecationWarning, stacklevel=3) potential_coord["step"] = obj[bad_name] - if _meets_required(potential_coord): + if _meets_required(potential_coord, strict=False): out[dim] = potential_coord def _get_coords_from_coord_level(obj, out): @@ -494,7 +510,7 @@ def _get_coords_from_coord_level(obj, out): value = value.to_summary() if hasattr(value, "model_dump"): value = value.model_dump() - if _meets_required(value): + if _meets_required(value, strict=False): out[key] = value def _pop_keys(obj, out): @@ -509,23 +525,28 @@ def _pop_keys(obj, out): obj.pop(f"d_{coord_name}", None) # sequence of short-circuit checks - out = {} + coord_dict = {} required = set(required) if required is not None else set() cant_be_alone = set(cant_be_alone) if obj is None: - return out, {} + return coord_dict, {} if hasattr(obj, "model_dump"): obj = obj.model_dump() - if dims is None: - dims = _get_dims(obj) + obj = dict(obj) + # Check if dims need to be updated. + new_dims = _get_dims(obj) + if new_dims and new_dims != dims: + obj["dims"] = new_dims + dims = new_dims # this is already a dict of coord info. - if set(dims) == set(obj): + if dims and set(dims) == set(obj): return obj, {} - obj = dict(obj) - _get_coords_from_coord_level(obj, out) - _get_coords_from_top_level(obj, out, dims) - _pop_keys(obj, out) - return out, obj + _get_coords_from_coord_level(obj, coord_dict) + _get_coords_from_top_level(obj, coord_dict, dims) + _pop_keys(obj, coord_dict) + if "dims" not in obj and dims is not None: + obj["dims"] = dims + return coord_dict, obj def cached_method(func): diff --git a/tests/test_core/test_attrs.py b/tests/test_core/test_attrs.py index 3dabfdcd..55eb7101 100644 --- a/tests/test_core/test_attrs.py +++ b/tests/test_core/test_attrs.py @@ -197,6 +197,11 @@ def test_items(self, random_patch): out = dict(attrs.items()) assert out == attrs.model_dump() + def test_dims_match_attrs(self, random_patch): + """Ensure the dims from patch attrs matches patch dims.""" + pat = random_patch.rename_coords(distance="channel") + assert pat.dims == pat.attrs.dim_tuple + class TestSummaryAttrs: """Tests for summarizing a schema.""" @@ -288,6 +293,13 @@ def test_attrs_can_update(self, random_attrs): expected = dc.get_quantity("miles") assert dc.get_quantity(attrs.coords["distance"].units) == expected + def test_update_from_coords(self, random_patch): + """Ensure attrs.update updates dims from coords.""" + attrs = random_patch.attrs + new_patch = random_patch.rename_coords(distance="channel") + new = attrs.update(coords=new_patch.coords) + assert new.dim_tuple == new_patch.dims + class TestMergeAttrs: """Tests for merging patch attrs.""" diff --git a/tests/test_core/test_patch.py b/tests/test_core/test_patch.py index 63d77ad9..af68ab61 100644 --- a/tests/test_core/test_patch.py +++ b/tests/test_core/test_patch.py @@ -207,6 +207,13 @@ def test_patch_has_size(self, random_patch): """Ensure patches have same size as data.""" assert random_patch.size == random_patch.data.size + def test_new_patch_non_standard_dims(self): + """Ensure a non-standard dimension has matching dims in attrs and coords.""" + data = np.random.rand(10, 5) + coords = {"time": np.arange(10), "can": np.arange(5)} + patch = dc.Patch(data=data, coords=coords, dims=("time", "can")) + assert patch.dims == patch.attrs.dim_tuple + class TestNew: """Tests for `Patch.new` method."""