add methods and tests for broadcasting
d-chambers committed May 7, 2024
1 parent eca4e63 commit e57e2a6
Showing 10 changed files with 237 additions and 55 deletions.
35 changes: 20 additions & 15 deletions dascore/core/attrs.py
@@ -28,6 +28,7 @@
from dascore.utils.mapping import FrozenDict
from dascore.utils.misc import (
_dict_list_diffs,
_merge_tuples,
all_diffs_close_enough,
get_middle_value,
iterate,
@@ -458,7 +459,7 @@ def check_dims(
patch1,
patch2,
check_behavior: WARN_LEVELS = "raise",
allow_sub_coords: bool = False,
intersection: bool = False,
) -> bool:
"""
Return True if dimensions of two patches are equal.
@@ -472,19 +473,21 @@
check_behavior
String with 'raise' will raise an error if incompatible,
'warn' will provide a warning, None will do nothing.
allow_sub_coords
If True, allow subcoords to be considered equal.
intersection
If True, allow any intersection of dimensions to pass. This is useful
when only broadcastability needs to be checked. If False, require dims
to be equal.
"""
dims1, dims2 = patch1.dims, patch2.dims
dims_ok = True
if not allow_sub_coords and patch1.dims == patch2.dims:
if not intersection and patch1.dims == patch2.dims:
return True
dset1, dset2 = set(dims1), set(dims2)
if allow_sub_coords and (dset1.issubset(dset2) or dset2.issubset(dset1)):
if intersection and (dset1 | dset2):
return True
msg = (
"Patch dimensions are not compatible for merging."
" Patch1 dims: {dims1}, Patch2 dims: {dims2}"
f" Patch1 dims: {dims1}, Patch2 dims: {dims2}"
)
warn_or_raise(msg, exception=IncompatiblePatchError, behavior=check_behavior)
return dims_ok
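
A standalone sketch of the two check modes described in the docstring above (plain Python sets; the dim names are made up, and this shows the documented intent rather than the exact library code):

# Sketch: strict equality vs. the looser overlap check used for broadcasting.
dims1 = ("time", "distance")
dims2 = ("time", "distance", "face")

strict_ok = dims1 == dims2                 # default: dims must match exactly -> False
loose_ok = bool(set(dims1) & set(dims2))   # intersection=True: any shared dim -> True
print(strict_ok, loose_ok)
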
@@ -533,7 +536,7 @@ def check_coords(
if not_equal_coords and len(shared):
msg = (
f"Patches are not compatible. The following shared coordinates "
f"are not equal {coord}"
f"are not equal: {coord}"
)
warn_or_raise(msg, exception=IncompatiblePatchError, behavior=check_behavior)
return False
@@ -544,7 +547,7 @@ def merge_compatible_coords_attrs(
patch1: PatchType,
patch2: PatchType,
attrs_to_ignore=("history", "dims"),
allow_subcoords: bool = False,
dim_intersection: bool = False,
) -> tuple[CoordManager, PatchAttrs]:
"""
Merge the coordinates and attributes of patches or raise if incompatible.
@@ -565,28 +568,30 @@
The first patch
patch2
The second patch
attr_ignore
attrs_to_ignore
A sequence of attributes to not consider in equality. Only these
attributes from the first patch are kept in outputs.
dim_intersection
If True, merge if any dimensions overlap; if False, raise unless all
dimensions match.
"""

def _merge_coords(coords1, coords2):
out = {}
cmap1, cmap2 = coords1.coord_map, coords1.coord_map
cmap1, cmap2 = coords1.coord_map, coords2.coord_map
coord_names = set(cmap1) | set(cmap2)
# fast path to update identical coordinates
if len(coord_names) == len(cmap1):
if coord_names == set(cmap1):
return coords1
if len(coord_names) == len(cmap2):
if coord_names == set(cmap2):
return coords2
# otherwise just squish coords from both managers together.
for name in coord_names:
coord = coords1 if name in coords1.coord_map else coords2
dims = coord.dim_map[name]
out[name] = (dims, coord.coord_map[name])
# Need to get coordinates that are in output, but preserve order.
dim1, dim2 = coords1.dims, coords2.dims
dims = dim1 if len(dim2) < len(dim1) else dim2
dims = _merge_tuples(coords1.dims, coords2.dims)
return dc.core.coordmanager.get_coord_manager(out, dims=dims)

def _merge_models(attrs1, attrs2, coord):
@@ -613,7 +618,7 @@ def _merge_models(attrs1, attrs2, coord):
raise IncompatiblePatchError(msg)
return combine_patch_attrs([dict1, dict2], conflicts="keep_first")

check_dims(patch1, patch2, allow_sub_coords=allow_subcoords)
check_dims(patch1, patch2, intersection=dim_intersection)
check_coords(patch1, patch2)
coord1, coord2 = patch1.coords, patch2.coords
attrs1, attrs2 = patch1.attrs, patch2.attrs
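
A sketch of how the new dim_intersection flag changes the merge, assuming merge_compatible_coords_attrs is importable from dascore.core.attrs (the module edited above) and using the append_dims method added later in this commit:

import dascore as dc
from dascore.core.attrs import merge_compatible_coords_attrs

patch1 = dc.get_example_patch()
patch2 = patch1.append_dims(face=[1, 2])  # adds a trailing "face" dimension

# With dim_intersection=True the unequal dim sets are allowed and the output
# dims are merged with _merge_tuples, preserving order from both patches.
coords, attrs = merge_compatible_coords_attrs(patch1, patch2, dim_intersection=True)
print(coords.dims)  # expected to include "face" along with the original dims
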
34 changes: 23 additions & 11 deletions dascore/core/coordmanager.py
@@ -284,7 +284,7 @@ def _divide_kwargs(kwargs):

coord_updates, coord_to_drop, coord_to_add = _divide_kwargs(kwargs)
# get coords to drop from selecting None
coord_map, dim_map = _get_coord_dim_map(coord_to_add, self.dims)
coord_map, dim_map, dims = _get_coord_dim_map(coord_to_add, self.dims)
# find coords to drop because their dimension changed.
indirect_coord_drops = _get_dim_change_drop(coord_map, dim_map)
# drop coords then call get_coords to handle adding new ones.
@@ -297,8 +297,8 @@ def _divide_kwargs(kwargs):
new = list(out[coord_name])
new[1] = new[1].update(**{attr: value})
out[coord_name] = tuple(new)
dims = tuple(x for x in self.dims if x not in coord_to_drop)

dims = tuple(x for x in dims if x not in coord_to_drop)
return get_coord_manager(out, dims=dims)

# we need this here to maintain backwards compatibility
@@ -680,6 +680,11 @@ def size(self):
"""Return the size of the patch data matrix."""
return np.prod(self.shape)

@property
def ndim(self):
"""Return the number of dimensions in the coordinage manager."""
return len(self.dims)

def validate_data(self, data):
"""Ensure data conforms to coordinates."""
data = np.array([]) if data is None else data
@@ -1025,7 +1030,7 @@ def get_coord_manager(
else:
dims = ()
coords = {} if coords is None else coords
coord_map, dim_map = _get_coord_dim_map(coords, dims)
coord_map, dim_map, dims = _get_coord_dim_map(coords, dims)
if attrs:
coord_updates, _ = separate_coord_info(attrs, dims)
updateable_coords = set(coord_updates) - set(coord_map)
@@ -1037,7 +1042,7 @@


def _get_coord_dim_map(coords, dims):
"""Get coord_map and dim_map from coord input."""
"""Get coord_map, dim_map, and new dims from coord input."""

def _get_coord(coord):
"""Get a coordinate from various inputs."""
@@ -1061,7 +1066,7 @@ def _coord_from_simple(name, coord):
out = _get_coord(coord)
return out, (name,)

def _maybe_coord_from_nested(coord):
def _maybe_coord_from_nested(name, coord, new_dims):
"""
Get coordinates from {coord_name: (dim_name, coord)} or
{coord_name: ((dim_names...,), coord)}.
@@ -1073,16 +1078,21 @@ def _maybe_coord_from_nested(coord):
)
raise CoordError(msg)
dim_names = iterate(coord[0])
# all dims must be in the input dims.
if not (d1 := set(dim_names)).issubset(d2 := set(dims)):
bad_dims = d2 - d1
# all dims must be in the input dims or a new coord.
d1, d2 = set(dim_names), set(dims)
if (not d1.issubset(d2)) and d1 != {name}:
bad_dims = d1 - d2
msg = (
f"Coordinate specified invalid dimension(s) {bad_dims}."
f" Valid dimensions are {dims}"
)
raise CoordError(msg)
# pull out any relevant info from attrs.
coord_out = _get_coord(coord[1])
# check if this adds a new dimension.
if len(dim_names) == 1 and (newdname := dim_names[0]) == name:
if newdname not in dims:
new_dims.append(newdname)
assert coord_out.shape == np.shape(coord[1])
return coord_out, dim_names

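A sketch of the behavior this hunk enables: a nested coord whose single dim name equals its own name can now introduce a brand-new dimension instead of raising CoordError. The import path is taken from this file; the coordinate values are made up:

import numpy as np
from dascore.core.coordmanager import get_coord_manager

# "face" names both the coordinate and its (new) dimension, so it is
# appended to dims rather than being rejected as an invalid dimension.
cm = get_coord_manager({"face": ("face", np.array([1, 2, 3]))}, dims=())
print(cm.dims)  # expected: ("face",)
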
@@ -1092,14 +1102,16 @@ def _maybe_coord_from_nested(coord):
# coords_dump = coords.model_dump()
# return dict(coords_dump["coord_map"]), dict(coords_dump["dim_map"])

c_map, d_map = {}, {}
c_map, d_map, new_dims = {}, {}, []
# iterate coords, get coordinate output.
for name, coord in coords.items():
if not isinstance(coord, tuple):
c_map[name], d_map[name] = _coord_from_simple(name, coord)
else:
c_map[name], d_map[name] = _maybe_coord_from_nested(coord)
return c_map, d_map
c_map[name], d_map[name] = _maybe_coord_from_nested(name, coord, new_dims)
if new_dims:
dims = tuple(list(dims) + new_dims)
return c_map, d_map, dims


def merge_coord_managers(
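
Two of the coordinate-manager changes above can be sketched with the example patch (exact dim names depend on the example patch and are an assumption):

import numpy as np
import dascore as dc

cm = dc.get_example_patch().coords
assert cm.ndim == len(cm.dims)  # new ndim property on the coordinate manager

# update() now threads new dims through _get_coord_dim_map, so a coord that
# names its own (new) dimension extends cm.dims rather than raising.
new_cm = cm.update(face=("face", np.array([1, 2])))
print(new_cm.dims)  # original dims plus "face"
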
11 changes: 10 additions & 1 deletion dascore/core/patch.py
@@ -120,6 +120,9 @@ def __mul__(self, other):
def __pow__(self, other):
return dascore.proc.apply_operator(self, other, np.power)

def __neg__(self):
return self.update(data=-self.data)

def __rich__(self):
dascore_text = get_dascore_text()
patch_text = Text("Patch ⚡", style="bold")
@@ -145,9 +148,14 @@ def dims(self) -> tuple[str, ...]:
"""Return the dimensions contained in patch."""
return self.coords.dims

@property
def ndim(self) -> int:
"""Return the number of dimensions contained in patch."""
return len(self.coords.dims)

@property
def coord_shapes(self) -> dict[str, tuple[int, ...]]:
"""Return a dict of coordinate: (shape, ...)."""
"""Return a dict of {coordinate: (shape, ...)}."""
return self.coords.coord_shapes

@property
@@ -197,6 +205,7 @@ def channel_count(self) -> int:
pipe = dascore.proc.pipe
set_dims = dascore.proc.set_dims
squeeze = dascore.proc.squeeze
append_dims = dascore.proc.append_dims
transpose = dascore.proc.transpose
snap_coords = dascore.proc.snap_coords
sort_coords = dascore.proc.sort_coords
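
The two small Patch additions above, sketched with the example patch:

import dascore as dc

patch = dc.get_example_patch()
assert patch.ndim == len(patch.dims)        # new Patch.ndim property
negated = -patch                            # new __neg__ support
assert (negated.data == -patch.data).all()
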
88 changes: 65 additions & 23 deletions dascore/proc/basic.py
@@ -14,7 +14,7 @@
from dascore.exceptions import UnitError
from dascore.units import DimensionalityError, Quantity, Unit, get_quantity
from dascore.utils.models import ArrayLike
from dascore.utils.patch import patch_function
from dascore.utils.patch import _make_dims_alike, _select_compatible, patch_function


def set_dims(self: PatchType, **kwargs: str) -> PatchType:
@@ -264,6 +264,55 @@ def transpose(self: PatchType, *dims: str) -> PatchType:
return self.new(data=new_data, coords=new_coord)


@patch_function(history=None)
def append_dims(patch: PatchType, **dim_kwargs) -> PatchType:
"""
Insert dimensions at the end of the patch.
This can be used to add dummy dimensions to the patch.
Parameters
----------
dim_kwargs
Used to pass keys (new dim names) and values (coordinate values).
Examples
--------
>>> import dascore as dc
>>> patch = dc.get_example_patch()
>>> # Add a dummy dimension called "face" to end of patch
>>> # which has a coordinate value of [1].
>>> new = patch.append_dims(face=1)
>>> # Same thing as above, but with larger coords, which broadcasts
>>> # the data to a shape appropriate to match the coordinates.
>>> new = patch.append_dims(face=[1, 2])
Notes
-----
- This aims to be simpler than numpy and xarray's expand_dims.
- Use [`Patch.transpose`](`dascore.patch.transpose`) to re-arrange dimensions.
- If the dimension already exists nothing will happen.
"""
# Ensure the input values are all arrays so they can be coords.
kwargs = {
i: (i, np.atleast_1d(v)) for i, v in dim_kwargs.items() if i not in patch.dims
}
if not kwargs:
return patch
ndim = patch.ndim
# First get data with empty dimensions
insert_inds = [x + ndim for x in range(len(kwargs))]
data = np.expand_dims(patch.data, insert_inds)
shapes = list(data.shape)
for ind, (_, coord_data) in zip(insert_inds, kwargs.values()):
shapes[ind] = len(coord_data)
data = np.broadcast_to(data, shapes)
coords = patch.coords.update(**kwargs)
return patch.update(data=data, coords=coords)


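Beyond the docstring examples above, the broadcast behavior of append_dims can be sketched as follows (the trailing position of the new dim follows from the insert indices used above; the assertion on shape order is otherwise an assumption):

import dascore as dc

patch = dc.get_example_patch()
bigger = patch.append_dims(face=[1, 2, 3])
# The new dimension is appended last and the data is broadcast along it.
assert bigger.dims == patch.dims + ("face",)
assert bigger.shape == patch.shape + (3,)
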
@patch_function()
def squeeze(self: PatchType, dim=None) -> PatchType:
"""
@@ -376,7 +425,7 @@ def apply_operator(patch: PatchType, other, operator) -> PatchType:
"""
Apply a ufunc-type operator to a patch.
This is used to implement a patch's operator overload.
This is used to implement a patch's operator overloading.
Parameters
----------
@@ -406,18 +455,26 @@ def apply_operator(patch: PatchType, other, operator) -> PatchType:
>>> # subtract one patch from another. Coords and attrs must be compatible
>>> new = apply_operator(patch, patch, np.subtract)
>>> assert np.allclose(new.data, 0)
Notes
-----
See [numpy's ufunc docs](https://numpy.org/doc/stable/reference/ufuncs.html)
"""
# Handle creating merged coords and patch stuff.
if isinstance(other, dc.Patch):
if len(other.dims) > len(patch.dims):
patch, other = other, patch
other_patch = other
# Trim patches so their shared dims overlap exactly.
patch, other_patch = _select_compatible(patch, other_patch)
# get metadata. This is done before making the patch coords compatible.
coords, attrs = merge_compatible_coords_attrs(
patch,
other,
allow_subcoords=True,
other_patch,
dim_intersection=True,
)
other = _maybe_broadcast_data(patch, other)
if other_units := get_quantity(attrs.data_units):
# Make sure dims are the same for each patch.
patch, other_patch = _make_dims_alike(patch, other_patch)
other = other_patch.data
if other_units := get_quantity(other_patch.attrs.data_units):
other = other * other_units
else:
coords, attrs = patch.coords, patch.attrs
Expand All @@ -442,21 +499,6 @@ def apply_operator(patch: PatchType, other, operator) -> PatchType:
return new


def _maybe_broadcast_data(patch, other):
"""Maybe broadcast data in two compatible patches."""
# If the shapes already match no broadcast needed.
if patch.shape == other.shape:
return other.data
assert set(other.dims).issubset(patch.dims)
# Ensure the other patch has dims ordered correctly
dims1, dims2 = patch.dims, other.dims
other_dims_order = tuple(x for x in dims1 if x in dims2)
other = other.transpose(*other_dims_order)
size_map = {d: other.shape[i] for i, d in enumerate(other.dims)}
other_shape = tuple(size_map.get(x, 1) for x in patch.dims)
return other.data.reshape(other_shape)


@patch_function()
def dropna(patch: PatchType, dim, how: Literal["any", "all"] = "any") -> PatchType:
"""
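
A sketch of the broadcasting path this rework of apply_operator enables; the helper names _select_compatible and _make_dims_alike come from this diff, and the expected output is the intended result rather than a verified one:

import numpy as np
import dascore as dc

patch = dc.get_example_patch()
expanded = patch.append_dims(face=[1, 2])

# Operating on patches whose dims only partially overlap now broadcasts:
# the smaller patch is aligned with the larger one before np.multiply runs.
product = patch * expanded
print(product.dims)   # expected to match expanded.dims
print(product.shape)  # expected to match expanded.shape
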
4 changes: 2 additions & 2 deletions dascore/proc/coords.py
@@ -155,8 +155,8 @@ def update_coords(self: PatchType, **kwargs) -> PatchType:
Parameters
----------
**kwargs
The mapping from old names to new names
dims
If not None, the new dimensions of the coordinate manager.
Examples
--------
8 changes: 8 additions & 0 deletions dascore/utils/misc.py
@@ -735,3 +735,11 @@ def check_filter_range(nyquist, low, high, filt_min, filt_max):
f"filt_min = {filt_min}, filt_max = {filt_max}"
)
raise FilterValueError(msg)


def _merge_tuples(dims1, dims2):
"""Merge tuples together, preserving order where possible."""
dims = dict.fromkeys(dims1)
dims.update(dict.fromkeys(dims2))
out = tuple(dims.keys())
return out
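
The order-preserving merge performed by this helper can be sketched directly (hypothetical dim names):

dims1 = ("time", "distance")
dims2 = ("distance", "face")

# dict.fromkeys keeps insertion order and drops duplicates, so shared dims
# keep their position from dims1 and new dims from dims2 are appended.
merged = dict.fromkeys(dims1)
merged.update(dict.fromkeys(dims2))
assert tuple(merged) == ("time", "distance", "face")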