From a2b569e9fb90edd0c5923bf781113db2a97d82ff Mon Sep 17 00:00:00 2001 From: derrick chambers Date: Mon, 30 Dec 2024 12:26:33 -0800 Subject: [PATCH] fix issue 475 and 474 --- dascore/utils/chunk.py | 13 ++++++++++--- tests/test_core/test_patch_chunk.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/dascore/utils/chunk.py b/dascore/utils/chunk.py index 671e6df4..ca7bb98e 100644 --- a/dascore/utils/chunk.py +++ b/dascore/utils/chunk.py @@ -78,7 +78,11 @@ def get_intervals( if overlap > length: msg = "Cant chunk when overlap is greater than chunk size" raise ParameterError(msg) - if (stop - start) < length and not keep_partials: + # If the step is known, we need to account for it in the total duration + # See 474. + _raw_duration = stop - start + duration = _raw_duration + step if step is not None else _raw_duration + if duration < length and not keep_partials: msg = "Cant chunk when data interval is less than chunk size. " raise ChunkError(msg) # reference with no overlap @@ -93,12 +97,13 @@ def get_intervals( ends = reference + length - step starts = reference # trim end to not surpass stop - if ends[-1] > stop: + bad_ends = ends > stop + if bad_ends.any(): if not keep_partials: ends_filt = ends <= stop ends, starts = ends[ends_filt], starts[ends_filt] else: - ends[-1] = stop + ends[bad_ends] = stop return np.stack([starts, ends]).T @@ -263,6 +268,8 @@ def _get_chunk_overlap_inds(self, src1, src2, chu1, chu2): """Get an index mapping from source to chunk.""" chunk_starts = np.searchsorted(src1, chu1, side="right") - 1 chunk_ends = np.searchsorted(src2, chu2, side="left") + # Ensure no chunks run off the end of the source. + assert np.all(chunk_ends < len(src1)), "Invalid chunk range found" # add 1 to end so it is an exclusive end range return np.stack([chunk_starts, chunk_ends + 1], axis=1) diff --git a/tests/test_core/test_patch_chunk.py b/tests/test_core/test_patch_chunk.py index 541826ff..8bccf7bb 100644 --- a/tests/test_core/test_patch_chunk.py +++ b/tests/test_core/test_patch_chunk.py @@ -166,6 +166,23 @@ def test_too_big_overlap_raises(self, diverse_spool): with pytest.raises(ParameterError, match=msg): diverse_spool.chunk(time=10, overlap=11) + def test_issue_474(self, random_spool): + """Ensure spools can be chunked with the duration reported by coord.""" + # See #474 + patch1 = random_spool.chunk(time=...)[0] + duration = patch1.coords.coord_range("time") + merged2 = random_spool.chunk(time=duration) + patch2 = merged2[0] + assert patch1.equals(patch2) + + def test_issue_475(self, diverse_spool): + """Ensure the partially chunked spool can be merged.""" + # See #475 + spool = diverse_spool.chunk(time=3, overlap=1, keep_partial=True) + merged_spool = spool.chunk(time=None) + assert isinstance(merged_spool, dc.BaseSpool) + assert len(merged_spool) + class TestChunkMerge: """Tests for merging patches together using chunk method."""