Skip to content

Commit

Permalink
clarify spool concat new dim behavior (#407)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: ahmadtourei <ahmadtourei@gmail.com>
  • Loading branch information
d-chambers and ahmadtourei authored Jul 12, 2024
1 parent 17c0237 commit fa08013
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 8 deletions.
11 changes: 8 additions & 3 deletions dascore/utils/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,9 @@ def concatenate_patches(
"""
Concatenate the patches together.
Only patches which are compatible with the first patch are concatenated
together.
Parameters
----------
{check_bev}
Expand All @@ -846,18 +849,20 @@ def concatenate_patches(
>>> spool_concat = spool.concatenate(time=None)
>>> assert len(spool_concat) == 1
>>>
>>> # Concatenate patches along a new dimension
>>> # Concatenate patches along a new dimension.
>>> # Note: This will only include the first patch if existing
>>> # dimensions are not identical.
>>> spool_concat = spool.concatenate(wave_rank=None)
>>> assert "wave_rank" in spool_concat[0].dims
>>>
>>> # concatenate patches in groups of 3.
>>> # Concatenate patches in groups of 3.
>>> big_spool = dc.spool([patch] * 12)
>>> spool_concat = big_spool.concatenate(time=3)
>>> assert len(spool_concat) == 4
Notes
-----
- [`Spool.chunk `](`dascore.BaseSpool.chunk`) performs a similar operation
- [`Spool.chunk`](`dascore.BaseSpool.chunk`) performs a similar operation
but accounts for coordinate values.
- See also the
[concatenate section of the spool tutorial](`docs/tutorial/spool`#concatenate)
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorial/spool.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ merged_spool = spool.chunk(time=None)
```

# concatenate
Similar to `chunk`, [`Spool.concatenate`](`dascore.BaseSpool.concatenate`) is used to combine patches together. However, `concatenate doesn't account for coordinate values along the concatenation axis, and can even be used to create new patch dimensions.
Similar to `chunk`, [`Spool.concatenate`](`dascore.BaseSpool.concatenate`) is used to combine patches together. However, `concatenate` doesn't account for coordinate values along the concatenation axis, and can even be used to create new patch dimensions.

:::{.callout-warning}
Unlike [`chunk`](`dascore.BaseSpool.chunk`), not all `Spool` types implement [`concatenate`](`dascore.BaseSpool.concatenate`).
Expand Down
50 changes: 46 additions & 4 deletions tests/test_utils/test_patch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,29 @@ def test_new_dim(self, random_patch):
assert len(out) == 1
patch = out[0]
assert "zoolou" in patch.dims
coord = patch.get_coord("zoolou")
assert len(coord) == len(patches)

def test_concat_chunk_to_new_dimension(self, random_patch):
    """
    Verify an integer chunk size is honored along a newly created dimension.

    A size of 1 only adds the new dimension to each patch (the spool length
    is unchanged); a size of 2 joins the patches pairwise.
    """
    spool = dc.spool([random_patch] * 6)
    for group_size in (1, 2):
        grouped = spool.concatenate(new_dim=group_size)
        # Each output patch absorbs `group_size` input patches.
        assert len(grouped) == len(spool) // group_size
        for patch in grouped:
            assert len(patch.get_coord("new_dim")) == group_size

def test_spool_up(self, random_patch):
"""Ensure a patch is returned in the wrapper is used."""
"""Ensure a patch is returned if the wrapper is used."""
func = _spool_up(concatenate_patches)
out = func([random_patch] * 3, time=None)
assert isinstance(out, dc.BaseSpool)
Expand All @@ -347,7 +367,10 @@ def test_new_dim_spool(self, random_patch):
"""Ensure a patch with new dim can be retrieved from spool."""
spool = dc.spool([random_patch, random_patch])
spool_concat = spool.concatenate(wave_rank=None)
assert "wave_rank" in spool_concat[0].dims
assert len(spool_concat) == 1
patch = spool_concat[0]
assert "wave_rank" in patch.dims
assert len(patch.get_coord("wave_rank")) == len(spool)

def test_patch_with_gap(self, random_patch):
"""Ensure a patch with a time gap still concats."""
Expand All @@ -356,7 +379,6 @@ def test_patch_with_gap(self, random_patch):
one_hour = dc.to_timedelta64(3600)
patch2 = random_patch.update_coords(time_min=time.max() + one_hour)
spool = dc.spool([random_patch, patch2])

# chunk rightfully wouldn't merge these patches, but concatenate will.
merged = spool.concatenate(time=None)
assert len(merged) == 1
Expand Down Expand Up @@ -390,7 +412,7 @@ def test_concat_different_sizes(self, random_patch):
old_times = np.concatenate([p1.get_array("time"), p2.get_array("time")])
assert np.all(new_time == old_times)

def test_concatenate_normal_with_non_dim(self, spool_with_non_coords):
def test_concat_normal_with_non_dim(self, spool_with_non_coords):
"""Ensure normal and non-dim patches can be concatenated together."""
old_arrays = [x.get_array("time") for x in spool_with_non_coords]
old_array = np.concatenate(old_arrays)
Expand All @@ -408,6 +430,26 @@ def test_concatenate_normal_with_non_dim(self, spool_with_non_coords):
nearly_eq = old_array == new_array
assert np.all(both_nan | nearly_eq)

def test_concat_dropped_coord(self, random_spool):
    """
    Check that patches with a dropped coordinate can still be concatenated.

    Every patch is transformed with ``dft("time")`` and its coords are
    updated with ``time=None`` before the spool is concatenated along a new
    "time_min" dimension, whose coordinate is then filled from the spool
    contents.
    """

    def _dft_without_time(patch):
        """Apply dft along time, then update the coords with time=None."""
        transformed = patch.dft("time")
        return transformed.update(coords=transformed.coords.update(time=None))

    dft_spool = dc.spool([_dft_without_time(patch) for patch in random_spool])
    time_mins = dft_spool.get_contents()["time_min"]
    concatenated = dft_spool.concatenate(time_min=None)[0]
    # Use the per-patch time_min values as the coordinate of the new dimension.
    concatenated = concatenated.update(
        coords=concatenated.coords.update(time_min=time_mins)
    )
    assert concatenated.shape[-1] == len(random_spool)
    assert "time_min" in concatenated.dims


class TestStackPatches:
"""Tests for stacking (adding) spool content."""
Expand Down

0 comments on commit fa08013

Please sign in to comment.