From e5f76ff203a39c859561ba361fb72171f2acf935 Mon Sep 17 00:00:00 2001 From: ahmadtourei Date: Mon, 13 Jan 2025 12:49:37 -0700 Subject: [PATCH] fix_correlate_shift_coord_length --- dascore/proc/correlate.py | 3 +++ docs/recipes/how_to_contribute.qmd | 1 - tests/test_proc/test_correlate.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/dascore/proc/correlate.py b/dascore/proc/correlate.py index 7fafad06..bf80c0d6 100644 --- a/dascore/proc/correlate.py +++ b/dascore/proc/correlate.py @@ -73,6 +73,8 @@ def correlate_shift(patch, dim, undo_weighting=True): step = coord.step new_start = -np.ceil((len(coord) - 1) / 2) * step new_end = np.ceil((len(coord) - 1) / 2) * step + if len(coord) % 2 != 0: # Odd coord length + new_end += step new_coord = dc.get_coord(start=new_start, stop=new_end, step=step) assert len(new_coord) == len(coord) cm = patch.coords @@ -214,5 +216,6 @@ def correlate( # Undo fft if this function did one, shift, and update coord. if not input_dft: idft = out.idft.func(out) + # breakpoint() out = idft.correlate_shift.func(idft, fft_dim) return out diff --git a/docs/recipes/how_to_contribute.qmd b/docs/recipes/how_to_contribute.qmd index 87ed3ec3..5a13f438 100644 --- a/docs/recipes/how_to_contribute.qmd +++ b/docs/recipes/how_to_contribute.qmd @@ -29,7 +29,6 @@ pip install -e ".[dev]" ``` ```bash -cd dascore pytest ``` diff --git a/tests/test_proc/test_correlate.py b/tests/test_proc/test_correlate.py index 8b559ecc..f32e5966 100644 --- a/tests/test_proc/test_correlate.py +++ b/tests/test_proc/test_correlate.py @@ -12,6 +12,18 @@ class TestCorrelateShift: """Tests for the correlation shift function.""" + @pytest.fixture(scope="class") + def random_patch_odd(self): + """Create a random patch with odd number of time samples.""" + patch = dc.get_example_patch("random_das", shape=(2, 11)) + return patch + + @pytest.fixture(scope="class") + def random_patch_even(self): + """Create a random patch with even number of time samples.""" + patch = dc.get_example_patch("random_das", shape=(2, 10)) + return patch + def test_auto_correlation(self, random_dft_patch): """Perform auto correlation and undo shifting.""" dft_conj = random_dft_patch.conj() @@ -26,6 +38,22 @@ def test_auto_correlation(self, random_dft_patch): argmax = np.argmax(random_dft_patch.data, axis=time_ax) assert np.all(coord_array[argmax] == dc.to_timedelta64(0)) + def test_auto_correlation_odd_coord(self, random_patch_odd): + """Ensure correlate_shift works when dim's coord length is odd.""" + dft = random_patch_odd.dft(dim="time") + dft_conj = dft.conj() + dft_sq = dft * dft_conj + idft = dft_sq.idft() + assert isinstance(idft.correlate_shift(dim="time"), dc.Patch) + + def test_auto_correlation_even_coord(self, random_patch_even): + """Ensure correlate_shift works when dim's coord length is even.""" + dft = random_patch_even.dft(dim="time") + dft_conj = dft.conj() + dft_sq = dft * dft_conj + idft = dft_sq.idft() + assert isinstance(idft.correlate_shift(dim="time"), dc.Patch) + class TestCorrelateInternal: """Tests case of intra-patch correlation function."""