Skip to content


add methods and tests for broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed May 7, 2024
1 parent eca4e63 commit e57e2a6
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 55 deletions.
35 changes: 20 additions & 15 deletions dascore/core/
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from dascore.utils.mapping import FrozenDict
from dascore.utils.misc import (
Expand Down Expand Up @@ -458,7 +459,7 @@ def check_dims(
check_behavior: WARN_LEVELS = "raise",
allow_sub_coords: bool = False,
intersection: bool = False,
) -> bool:
Return True if dimensions of two patches are equal.
Expand All @@ -472,19 +473,21 @@ def check_dims(
String with 'raise' will raise an error if incompatible,
'warn' will provide a warning, None will do nothing.
If True, allow subcoords to be considered equal.
If True, allow any intersection of dimensions to pass. This is useful
when only broadcastability needs to be checked. If False, require dims
to be equal.
dims1, dims2 = patch1.dims, patch2.dims
dims_ok = True
if not allow_sub_coords and patch1.dims == patch2.dims:
if not intersection and patch1.dims == patch2.dims:
return True
dset1, dset2 = set(dims1), set(dims2)
if allow_sub_coords and (dset1.issubset(dset2) or dset2.issubset(dset1)):
if intersection and (dset1 | dset2):
return True
msg = (
"Patch dimensions are not compatible for merging."
" Patch1 dims: {dims1}, Patch2 dims: {dims2}"
f" Patch1 dims: {dims1}, Patch2 dims: {dims2}"
warn_or_raise(msg, exception=IncompatiblePatchError, behavior=check_behavior)
return dims_ok
Expand Down Expand Up @@ -533,7 +536,7 @@ def check_coords(
if not_equal_coords and len(shared):
msg = (
f"Patches are not compatible. The following shared coordinates "
f"are not equal {coord}"
f"are not equal: {coord}"
warn_or_raise(msg, exception=IncompatiblePatchError, behavior=check_behavior)
return False
Expand All @@ -544,7 +547,7 @@ def merge_compatible_coords_attrs(
patch1: PatchType,
patch2: PatchType,
attrs_to_ignore=("history", "dims"),
allow_subcoords: bool = False,
dim_intersection: bool = False,
) -> tuple[CoordManager, PatchAttrs]:
Merge the coordinates and attributes of patches or raise if incompatible.
Expand All @@ -565,28 +568,30 @@ def merge_compatible_coords_attrs(
The first patch
The second patch
A sequence of attributes to not consider in equality. Only these
attributes from the first patch are kept in outputs.
If True, merge if any dimensions overlap; otherwise raise unless the
dimensions of both patches are equal.

def _merge_coords(coords1, coords2):
out = {}
cmap1, cmap2 = coords1.coord_map, coords1.coord_map
cmap1, cmap2 = coords1.coord_map, coords2.coord_map
coord_names = set(cmap1) | set(cmap2)
# fast path to update identical coordinates
if len(coord_names) == len(cmap1):
if coord_names == set(cmap1):
return coords1
if len(coord_names) == len(cmap2):
if coord_names == set(cmap2):
return coords2
# otherwise just squish coords from both managers together.
for name in coord_names:
coord = coords1 if name in coords1.coord_map else coords2
dims = coord.dim_map[name]
out[name] = (dims, coord.coord_map[name])
# Need to get coordinate that are in output, but preserve order.
dim1, dim2 = coords1.dims, coords2.dims
dims = dim1 if len(dim2) < len(dim1) else dim2
dims = _merge_tuples(coords1.dims, coords2.dims)
return dc.core.coordmanager.get_coord_manager(out, dims=dims)

def _merge_models(attrs1, attrs2, coord):
Expand All @@ -613,7 +618,7 @@ def _merge_models(attrs1, attrs2, coord):
raise IncompatiblePatchError(msg)
return combine_patch_attrs([dict1, dict2], conflicts="keep_first")

check_dims(patch1, patch2, allow_sub_coords=allow_subcoords)
check_dims(patch1, patch2, intersection=dim_intersection)
check_coords(patch1, patch2)
coord1, coord2 = patch1.coords, patch2.coords
attrs1, attrs2 = patch1.attrs, patch2.attrs
Expand Down
34 changes: 23 additions & 11 deletions dascore/core/
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _divide_kwargs(kwargs):

coord_updates, coord_to_drop, coord_to_add = _divide_kwargs(kwargs)
# get coords to drop from selecting None
coord_map, dim_map = _get_coord_dim_map(coord_to_add, self.dims)
coord_map, dim_map, dims = _get_coord_dim_map(coord_to_add, self.dims)
# find coords to drop because their dimension changed.
indirect_coord_drops = _get_dim_change_drop(coord_map, dim_map)
# drop coords then call get_coords to handle adding new ones.
Expand All @@ -297,8 +297,8 @@ def _divide_kwargs(kwargs):
new = list(out[coord_name])
new[1] = new[1].update(**{attr: value})
out[coord_name] = tuple(new)
dims = tuple(x for x in self.dims if x not in coord_to_drop)

dims = tuple(x for x in dims if x not in coord_to_drop)
return get_coord_manager(out, dims=dims)

# we need this here to maintain backwards compatibility
Expand Down Expand Up @@ -680,6 +680,11 @@ def size(self):
"""Return the size of the patch data matrix."""

def ndim(self):
"""Return the number of dimensions in the coordinate manager."""
return len(self.dims)

def validate_data(self, data):
"""Ensure data conforms to coordinates."""
data = np.array([]) if data is None else data
Expand Down Expand Up @@ -1025,7 +1030,7 @@ def get_coord_manager(
dims = ()
coords = {} if coords is None else coords
coord_map, dim_map = _get_coord_dim_map(coords, dims)
coord_map, dim_map, dims = _get_coord_dim_map(coords, dims)
if attrs:
coord_updates, _ = separate_coord_info(attrs, dims)
updateable_coords = set(coord_updates) - set(coord_map)
Expand All @@ -1037,7 +1042,7 @@ def get_coord_manager(

def _get_coord_dim_map(coords, dims):
"""Get coord_map and dim_map from coord input."""
"""Get coord_map, dim_map, and new dims from coord input."""

def _get_coord(coord):
"""Get a coordinate from various inputs."""
Expand All @@ -1061,7 +1066,7 @@ def _coord_from_simple(name, coord):
out = _get_coord(coord)
return out, (name,)

def _maybe_coord_from_nested(coord):
def _maybe_coord_from_nested(name, coord, new_dims):
Get coordinates from {coord_name: (dim_name, coord)} or
{coord_name: ((dim_names...,), coord)}.
Expand All @@ -1073,16 +1078,21 @@ def _maybe_coord_from_nested(coord):
raise CoordError(msg)
dim_names = iterate(coord[0])
# all dims must be in the input dims.
if not (d1 := set(dim_names)).issubset(d2 := set(dims)):
bad_dims = d2 - d1
# # all dims must be in the input dims or a new coord.
d1, d2 = set(dim_names), set(dims)
if (not d1.issubset(d2)) and d1 != {name}:
bad_dims = d1 - d2
msg = (
f"Coordinate specified invalid dimension(s) {bad_dims}."
f" Valid dimensions are {dims}"
raise CoordError(msg)
# pull out any relevant info from attrs.
coord_out = _get_coord(coord[1])
# check if this adds a new dimension.
if len(dim_names) == 1 and (newdname := dim_names[0]) == name:
if newdname not in dims:
assert coord_out.shape == np.shape(coord[1])
return coord_out, dim_names

Expand All @@ -1092,14 +1102,16 @@ def _maybe_coord_from_nested(coord):
# coords_dump = coords.model_dump()
# return dict(coords_dump["coord_map"]), dict(coords_dump["dim_map"])

c_map, d_map = {}, {}
c_map, d_map, new_dims = {}, {}, []
# iterate coords, get coordinate output.
for name, coord in coords.items():
if not isinstance(coord, tuple):
c_map[name], d_map[name] = _coord_from_simple(name, coord)
c_map[name], d_map[name] = _maybe_coord_from_nested(coord)
return c_map, d_map
c_map[name], d_map[name] = _maybe_coord_from_nested(name, coord, new_dims)
if new_dims:
dims = tuple(list(dims) + new_dims)
return c_map, d_map, dims

def merge_coord_managers(
Expand Down
11 changes: 10 additions & 1 deletion dascore/core/
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def __mul__(self, other):
def __pow__(self, other):
return dascore.proc.apply_operator(self, other, np.power)

def __neg__(self):
return self.update(

def __rich__(self):
dascore_text = get_dascore_text()
patch_text = Text("Patch ⚡", style="bold")
Expand All @@ -145,9 +148,14 @@ def dims(self) -> tuple[str, ...]:
"""Return the dimensions contained in patch."""
return self.coords.dims

def ndim(self) -> int:
"""Return the number of dimensions contained in patch."""
return len(self.coords.dims)

def coord_shapes(self) -> dict[str, tuple[int, ...]]:
"""Return a dict of coordinate: (shape, ...)."""
"""Return a dict of {coordinate: (shape, ...)}."""
return self.coords.coord_shapes

Expand Down Expand Up @@ -197,6 +205,7 @@ def channel_count(self) -> int:
pipe = dascore.proc.pipe
set_dims = dascore.proc.set_dims
squeeze = dascore.proc.squeeze
append_dims = dascore.proc.append_dims
transpose = dascore.proc.transpose
snap_coords = dascore.proc.snap_coords
sort_coords = dascore.proc.sort_coords
Expand Down
88 changes: 65 additions & 23 deletions dascore/proc/
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dascore.exceptions import UnitError
from dascore.units import DimensionalityError, Quantity, Unit, get_quantity
from dascore.utils.models import ArrayLike
from dascore.utils.patch import patch_function
from dascore.utils.patch import _make_dims_alike, _select_compatible, patch_function

def set_dims(self: PatchType, **kwargs: str) -> PatchType:
Expand Down Expand Up @@ -264,6 +264,55 @@ def transpose(self: PatchType, *dims: str) -> PatchType:
return, coords=new_coord)

def append_dims(patch: PatchType, **dim_kwargs) -> PatchType:
Insert dimensions at the end of the patch.
This can be used to add dummy dimensions to the patch.
Used to pass keys (new dim names) and values (coordinate values).
>>> import dascore as dc
>>> patch = dc.get_example_patch()
>>> # Add a dummy dimension called "face" to end of patch
>>> # which has a coordinate value of [1].
>>> new = patch.append_dims(face=1)
>>> # Same thing as above, but with a larger coord, which broadcasts
>>> # the data to a shape appropriate to match the coordinates.
>>> new = patch.append_dims(face=[1, 2])
- This tries to be more simple than numpy and xarray's expand_dims.
- Use [`Patch.transpose`](`dascore.patch.transpose`) to re-arrange dimensions.
- If dimension already exists nothing will happen.
# Ensure the input values are all arrays so they can be coords.
kwargs = {
i: (i, np.atleast_1d(v)) for i, v in dim_kwargs.items() if i not in patch.dims
if not kwargs:
return patch
ndim = patch.ndim
# First get data with empty dimensions
insert_inds = [x + ndim for x in range(len(kwargs))]
data = np.expand_dims(, insert_inds)
shapes = list(data.shape)
for ind, (_, coord_data) in zip(insert_inds, kwargs.values()):
shapes[ind] = len(coord_data)
data = np.broadcast_to(data, shapes)
coords = patch.coords.update(**kwargs)
return patch.update(data=data, coords=coords)

def squeeze(self: PatchType, dim=None) -> PatchType:
Expand Down Expand Up @@ -376,7 +425,7 @@ def apply_operator(patch: PatchType, other, operator) -> PatchType:
Apply a ufunc-type operator to a patch.
This is used to implement a patch's operator overload.
This is used to implement a patch's operator overloading.
Expand Down Expand Up @@ -406,18 +455,26 @@ def apply_operator(patch: PatchType, other, operator) -> PatchType:
>>> # subtract one patch from another. Coords and attrs must be compatible
>>> new = apply_operator(patch, patch, np.subtract)
>>> assert np.allclose(, 0)
See [numpy's ufunc docs](
# Handle creating merged coords and patch stuff.
if isinstance(other, dc.Patch):
if len(other.dims) > len(patch.dims):
patch, other = other, patch
other_patch = other
# Trim patches so their shared dims overlap exactly.
patch, other_patch = _select_compatible(patch, other_patch)
# get metadata. This is done before making the patch coords compatible.
coords, attrs = merge_compatible_coords_attrs(
other = _maybe_broadcast_data(patch, other)
if other_units := get_quantity(attrs.data_units):
# Make sure dims are the same for each patch.
patch, other_patch = _make_dims_alike(patch, other_patch)
other =
if other_units := get_quantity(other_patch.attrs.data_units):
other = other * other_units
coords, attrs = patch.coords, patch.attrs
Expand All @@ -442,21 +499,6 @@ def apply_operator(patch: PatchType, other, operator) -> PatchType:
return new

def _maybe_broadcast_data(patch, other):
"""Maybe broadcast data in two compatible patches."""
# If the shapes already match no broadcast needed.
if patch.shape == other.shape:
assert set(other.dims).issubset(patch.dims)
# Ensure the other patch has dims ordered correctly
dims1, dims2 = patch.dims, other.dims
other_dims_order = tuple(x for x in dims1 if x in dims2)
other = other.transpose(*other_dims_order)
size_map = {d: other.shape[i] for i, d in enumerate(other.dims)}
other_shape = tuple(size_map.get(x, 1) for x in patch.dims)

def dropna(patch: PatchType, dim, how: Literal["any", "all"] = "any") -> PatchType:
Expand Down
4 changes: 2 additions & 2 deletions dascore/proc/
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def update_coords(self: PatchType, **kwargs) -> PatchType:
The mapping from old names to new names
If not None, the new dimensions of the coordinate manager.
Expand Down
8 changes: 8 additions & 0 deletions dascore/utils/
Original file line number Diff line number Diff line change
Expand Up @@ -735,3 +735,11 @@ def check_filter_range(nyquist, low, high, filt_min, filt_max):
f"filt_min = {filt_min}, filt_max = {filt_max}"
raise FilterValueError(msg)

def _merge_tuples(dims1, dims2):
"""Merge tuples together, preserving order where possible."""
dims = dict.fromkeys(dims1)
out = tuple(dims.keys())
return out

0 comments on commit e57e2a6

Please sign in to comment.