Skip to content

Commit

Permalink
fix issue 475 and 474
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Dec 30, 2024
1 parent 4f48bd6 commit a2b569e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
13 changes: 10 additions & 3 deletions dascore/utils/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ def get_intervals(
if overlap > length:
msg = "Cant chunk when overlap is greater than chunk size"
raise ParameterError(msg)
if (stop - start) < length and not keep_partials:
# If the step is known, it must be added to the raw (stop - start) span to
# get the total duration. See #474.
_raw_duration = stop - start
duration = _raw_duration + step if step is not None else _raw_duration
if duration < length and not keep_partials:
msg = "Cant chunk when data interval is less than chunk size. "
raise ChunkError(msg)
# reference with no overlap
Expand All @@ -93,12 +97,13 @@ def get_intervals(
ends = reference + length - step
starts = reference
# trim end to not surpass stop
if ends[-1] > stop:
bad_ends = ends > stop
if bad_ends.any():
if not keep_partials:
ends_filt = ends <= stop
ends, starts = ends[ends_filt], starts[ends_filt]
else:
ends[-1] = stop
ends[bad_ends] = stop
return np.stack([starts, ends]).T


Expand Down Expand Up @@ -263,6 +268,8 @@ def _get_chunk_overlap_inds(self, src1, src2, chu1, chu2):
"""Get an index mapping from source to chunk."""
# For each value in chu1, locate its position in src1. side="right" returns
# the insertion point after any equal entries, so subtracting 1 yields the
# index of the last src1 entry that is <= the chunk value.
# (np.searchsorted assumes src1 is sorted ascending.)
chunk_starts = np.searchsorted(src1, chu1, side="right") - 1
# For each value in chu2, side="left" returns the index of the first src2
# entry that is >= the chunk value.
chunk_ends = np.searchsorted(src2, chu2, side="left")
# Ensure no chunks run off the end of the source.
# NOTE(review): the bound is len(src1) although the search above ran over
# src2 -- presumably both arrays have the same length; confirm with callers.
# NOTE(review): asserts are removed under `python -O`, so this guard does not
# run in optimized mode.
assert np.all(chunk_ends < len(src1)), "Invalid chunk range found"
# add 1 to end so it is an exclusive end range
return np.stack([chunk_starts, chunk_ends + 1], axis=1)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_core/test_patch_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,23 @@ def test_too_big_overlap_raises(self, diverse_spool):
with pytest.raises(ParameterError, match=msg):
diverse_spool.chunk(time=10, overlap=11)

def test_issue_474(self, random_spool):
    """Chunking by the coord-reported time range must reproduce the patch."""
    # Regression test for #474.
    expected = random_spool.chunk(time=...)[0]
    # Use exactly the duration the time coordinate reports for that patch.
    time_span = expected.coords.coord_range("time")
    rechunked = random_spool.chunk(time=time_span)
    first_rechunked = rechunked[0]
    assert expected.equals(first_rechunked)

def test_issue_475(self, diverse_spool):
    """A spool chunked with partial chunks kept must still be mergeable."""
    # Regression test for #475.
    partial_chunks = diverse_spool.chunk(time=3, overlap=1, keep_partial=True)
    remerged = partial_chunks.chunk(time=None)
    assert isinstance(remerged, dc.BaseSpool)
    # The merged result must contain at least one patch.
    assert len(remerged)


class TestChunkMerge:
"""Tests for merging patches together using chunk method."""
Expand Down

0 comments on commit a2b569e

Please sign in to comment.