Skip to content

Commit

Permalink
Stacking (#340)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Eileen R Martin <eileenrmartin@mines.edu>
  • Loading branch information
d-chambers and eileenrmartin authored Jan 11, 2024
1 parent c0f415d commit 875902a
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 35 deletions.
3 changes: 3 additions & 0 deletions dascore/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def map(self, func, iterables, **kwargs):
# Level of progress bar
PROGRESS_LEVELS = Literal["standard", "basic", None]

# Options for handling specific warnings
WARN_LEVELS = Literal["warn", "raise", None]

# A map from the unit name to the code used in numpy.timedelta64
NUMPY_TIME_UNIT_MAPPING = {
"hour": "h",
Expand Down
108 changes: 80 additions & 28 deletions dascore/core/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dascore.constants import (
VALID_DATA_CATEGORIES,
VALID_DATA_TYPES,
WARN_LEVELS,
PatchType,
attr_conflict_description,
max_lens,
Expand All @@ -32,6 +33,7 @@
iterate,
separate_coord_info,
to_str,
warn_or_raise,
)
from dascore.utils.models import (
CommaSeparatedStr,
Expand Down Expand Up @@ -449,6 +451,82 @@ def _handle_other_attrs(mod_dict_list):
return cls(**mod_dict_list[0])


def check_dims(patch1, patch2, check_behavior: WARN_LEVELS = "raise") -> bool:
"""
Return True if dimensions of two patches are equal.
Parameters
----------
patch1
first patch
patch2
second patch
check_behavior
String with 'raise' will raise an error if incompatible,
'warn' will provide a warning, None will do nothing.
"""
dims1 = patch1.dims
dims2 = patch2.dims
if dims1 == dims2:
return True
msg = (
"Patches are not compatible because their dimensions are not equal."
f" Patch1 dims: {dims1}, Patch2 dims: {dims2}"
)
warn_or_raise(msg, exception=IncompatiblePatchError, behavior=check_behavior)
return False


def check_coords(
patch1, patch2, check_behavior: WARN_LEVELS = "raise", dim_to_ignore=None
) -> bool:
"""
Return True if the coordinates of two patches are compatible, else False.
Parameters
----------
patch1
patch 1
patch2
patch 2
check_behavior
String with 'raise' will raise an error if incompatible,
'warn' will provide a warning.
dim_to_ignore
None by default (all coordinates must be identical).
String specifying a dimension that differences in values,
but not shape, are allowed.
"""
cm1 = patch1.coords
cm2 = patch2.coords
cset1, cset2 = set(cm1.coord_map), set(cm2.coord_map)
shared = cset1 & cset2
not_equal_coords = []
for coord in shared:
coord1 = cm1.coord_map[coord]
coord2 = cm2.coord_map[coord]
if coord1 == coord2:
# Straightforward case, coords are identical.
continue
elif coord == dim_to_ignore:
# If dimension that's ok to ignore value differences,
# check whether shape is the same.
if coord1.shape == coord2.shape:
continue
else:
not_equal_coords.append(coord)
else:
not_equal_coords.append(coord)
if not_equal_coords and len(shared):
msg = (
f"Patches are not compatible. The following shared coordinates "
f"are not equal {coord}"
)
warn_or_raise(msg, exception=IncompatiblePatchError, behavior=check_behavior)
return False
return True


def merge_compatible_coords_attrs(
patch1: PatchType, patch2: PatchType, attrs_to_ignore=("history",)
) -> tuple[CoordManager, PatchAttrs]:
Expand Down Expand Up @@ -476,32 +554,6 @@ def merge_compatible_coords_attrs(
attributes from the first patch are kept in outputs.
"""

def _check_dims(dims1, dims2):
if dims1 == dims2:
return
msg = (
"Patches are not compatible because their dimensions are not equal."
f" Patch1 dims: {dims1}, Patch2 dims: {dims2}"
)
raise IncompatiblePatchError(msg)

def _check_coords(cm1, cm2):
cset1, cset2 = set(cm1.coord_map), set(cm2.coord_map)
shared = cset1 & cset2
not_equal_coords = []
for coord in shared:
coord1 = cm1.coord_map[coord]
coord2 = cm2.coord_map[coord]
if coord1 == coord2:
continue
not_equal_coords.append(coord)
if not_equal_coords:
msg = (
f"Patches are not compatible. The following shared coordinates "
f"are not equal {coord}"
)
raise IncompatiblePatchError(msg)

def _merge_coords(coords1, coords2):
out = {}
coord_names = set(coords1.coord_map) & set(coords2.coord_map)
Expand Down Expand Up @@ -539,10 +591,10 @@ def _merge_models(attrs1, attrs2, coord):
raise IncompatiblePatchError(msg)
return combine_patch_attrs([dict1, dict2], conflicts="keep_first")

_check_dims(patch1.dims, patch2.dims)
check_dims(patch1, patch2)
check_coords(patch1, patch2)
coord1, coord2 = patch1.coords, patch2.coords
attrs1, attrs2 = patch1.attrs, patch2.attrs
_check_coords(coord1, coord2)
coord_out = _merge_coords(coord1, coord2)
attrs = _merge_models(attrs1, attrs2, coord_out)
return coord_out, attrs
Expand Down
10 changes: 6 additions & 4 deletions dascore/core/coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,16 @@ def _divide_kwargs(kwargs):
indirect_coord_drops = _get_dim_change_drop(coord_map, dim_map)
# drop coords then call get_coords to handle adding new ones.
coords, _ = self.drop_coords(*(coord_to_drop + indirect_coord_drops))
out = coords._get_dim_array_dict()
out = coords._get_dim_array_dict(keep_coord=True)
out.update({i: v for i, v in kwargs.items() if i not in coord_to_drop})
# update based on keywords
for item, value in coord_updates.items():
coord_name, attr = item.split("_")
new = list(out[coord_name])
coord = get_coord(data=new[1])
new[1] = coord.update(**{attr: value})
new[1] = new[1].update(**{attr: value})
out[coord_name] = tuple(new)
dims = tuple(x for x in self.dims if x not in coord_to_drop)

return get_coord_manager(out, dims=dims)

# we need this here to maintain backwards compatibility
Expand Down Expand Up @@ -660,7 +660,9 @@ def validate_data(self, data):
raise CoordDataError(msg)
return data

def _get_dim_array_dict(self, keep_coord=False):
def _get_dim_array_dict(
self, keep_coord=False
) -> dict[tuple[str], ArrayLike | BaseCoord]:
"""
Get the coord map in the form:
{coord_name = ((dims,), array)}.
Expand Down
60 changes: 59 additions & 1 deletion dascore/core/spool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
import dascore as dc
from dascore.constants import (
PROGRESS_LEVELS,
WARN_LEVELS,
ExecutorType,
PatchType,
attr_conflict_description,
numeric_types,
timeable_types,
)
from dascore.exceptions import InvalidSpoolError, ParameterError
from dascore.core.attrs import check_coords, check_dims
from dascore.core.patch import Patch
from dascore.exceptions import InvalidSpoolError, ParameterError, PatchDimError
from dascore.utils.chunk import ChunkManager
from dascore.utils.display import get_dascore_text, get_nice_text
from dascore.utils.docs import compose_docstring
Expand Down Expand Up @@ -304,6 +307,61 @@ def map(
**kwargs,
)

def stack(self, dim_vary=None, check_behavior: WARN_LEVELS = "warn") -> PatchType:
"""
Stack (add) all patches compatible with first patch together.
Parameters
----------
dim_vary
The name of the dimension which can be different in values
(but not shape) and patches still added together.
check_behavior
Indicates what to do when an incompatible patch is found in the
spool. `None` will silently skip any incompatible patches,
'warn' will issue a warning and then skip incompatible patches,
'raise' will raise an
[`IncompatiblePatchError`](`dascore.exceptions.IncompatiblePatchError`)
if any incompatible patches are found.
Examples
--------
>>> import dascore as dc
>>> # add a spool with equal sized patches but progressing time dim
>>> spool = dc.get_example_spool()
>>> stacked_patch = spool.stack(dim_vary='time')
"""
# check the dims/coords of first patch (considered to be standard for rest)
init_patch = self[0]
stack_arr = np.zeros_like(init_patch.data)

# ensure dim_vary is in dims
if dim_vary is not None and dim_vary not in init_patch.dims:
msg = f"Dimension {dim_vary} is not in first patch."
raise PatchDimError(msg)

for p in self:
# check dimensions of patch compared to init_patch
dims_ok = check_dims(init_patch, p, check_behavior)
coords_ok = check_coords(init_patch, p, check_behavior, dim_vary)
# actually do the stacking of data
if dims_ok and coords_ok:
stack_arr = stack_arr + p.data

# create attributes for the stack with adjusted history
stack_attrs = init_patch.attrs
new_history = list(init_patch.attrs.history)
new_history.append("stack")
stack_attrs = stack_attrs.update(history=new_history)

# create coords array for the stack
stack_coords = init_patch.coords
if dim_vary: # adjust dim_vary to start at 0 for junk dimension indicator
coord_to_change = stack_coords.coord_map[dim_vary]
new_dim = coord_to_change.update_limits(min=0)
stack_coords = stack_coords.update_coords(**{dim_vary: new_dim})
return Patch(stack_arr, stack_coords, init_patch.dims, stack_attrs)


class DataFrameSpool(BaseSpool):
"""An abstract class for spools whose contents are managed by a dataframe."""
Expand Down
28 changes: 28 additions & 0 deletions dascore/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from scipy.special import factorial

import dascore as dc
from dascore.constants import WARN_LEVELS
from dascore.exceptions import (
FilterValueError,
MissingOptionalDependencyError,
Expand Down Expand Up @@ -94,6 +95,33 @@ def suppress_warnings(category=Warning):
return None


def warn_or_raise(
msg: str,
exception: type[Exception] = Exception,
warning: type[Warning] = UserWarning,
behavior: WARN_LEVELS = "warn",
):
"""
A helper function to issues a warning, raise an exception or do nothing.
Parameters
----------
msg
The message to attach to warning or exception.
exception
The exception class to raise.
warning
The type of warning to use. Must be a subclass of Warning.
behavior
If None, do nothing. If
"""
if not behavior:
return
if behavior == "raise":
raise exception(msg)
warnings.warn(msg, warning)


class MethodNameSpace(metaclass=_NameSpaceMeta):
"""A namespace for class methods."""

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def terra15_das_unfinished_path() -> Path:
return out


@pytest.fixture(scope="session")
@pytest.fixture(scope="class")
@register_func(SPOOL_FIXTURES)
def random_spool() -> SpoolType:
"""Init a random array."""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_core/test_coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,13 @@ def test_dissociate(self, basic_coord_manager):
dissociated = new.update(new_time=(None, new_time))
assert dissociated.dim_map["new_time"] == ()

def test_unchanged_coords(self, coord_manager_with_units):
"""Ensure coordinates not updated are left unchanged."""
cm = coord_manager_with_units
new_time = cm.coord_map["time"].update(min=0)
new = cm.update(time=new_time)
assert new.coord_map["distance"] == cm.coord_map["distance"]


class TestSqueeze:
"""Tests for squeezing degenerate dimensions."""
Expand Down
Loading

0 comments on commit 875902a

Please sign in to comment.