Skip to content

Commit

Permalink
fix_correlate_shift_coord_length
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadtourei committed Jan 13, 2025
1 parent 1dc1ce4 commit e5f76ff
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
3 changes: 3 additions & 0 deletions dascore/proc/correlate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def correlate_shift(patch, dim, undo_weighting=True):
step = coord.step
new_start = -np.ceil((len(coord) - 1) / 2) * step
new_end = np.ceil((len(coord) - 1) / 2) * step
if len(coord) % 2 != 0: # Odd coord length
new_end += step
new_coord = dc.get_coord(start=new_start, stop=new_end, step=step)
assert len(new_coord) == len(coord)
cm = patch.coords
Expand Down Expand Up @@ -214,5 +216,6 @@ def correlate(
# Undo fft if this function did one, shift, and update coord.
if not input_dft:
idft = out.idft.func(out)
# breakpoint()
out = idft.correlate_shift.func(idft, fft_dim)
return out
1 change: 0 additions & 1 deletion docs/recipes/how_to_contribute.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ pip install -e ".[dev]"
```

```bash
cd dascore
pytest
```

Expand Down
28 changes: 28 additions & 0 deletions tests/test_proc/test_correlate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
class TestCorrelateShift:
"""Tests for the correlation shift function."""

@pytest.fixture(scope="class")
def random_patch_odd(self):
"""Create a random patch with odd number of time samples."""
patch = dc.get_example_patch("random_das", shape=(2, 11))
return patch

@pytest.fixture(scope="class")
def random_patch_even(self):
"""Create a random patch with even number of time samples."""
patch = dc.get_example_patch("random_das", shape=(2, 10))
return patch

def test_auto_correlation(self, random_dft_patch):
"""Perform auto correlation and undo shifting."""
dft_conj = random_dft_patch.conj()
Expand All @@ -26,6 +38,22 @@ def test_auto_correlation(self, random_dft_patch):
argmax = np.argmax(random_dft_patch.data, axis=time_ax)
assert np.all(coord_array[argmax] == dc.to_timedelta64(0))

def test_auto_correlation_odd_coord(self, random_patch_odd):
"""Ensure correlate_shift works when dim's coord length is odd."""
dft = random_patch_odd.dft(dim="time")
dft_conj = dft.conj()
dft_sq = dft * dft_conj
idft = dft_sq.idft()
assert isinstance(idft.correlate_shift(dim="time"), dc.Patch)

def test_auto_correlation_even_coord(self, random_patch_even):
"""Ensure correlate_shift works when dim's coord length is even."""
dft = random_patch_even.dft(dim="time")
dft_conj = dft.conj()
dft_sq = dft * dft_conj
idft = dft_sq.idft()
assert isinstance(idft.correlate_shift(dim="time"), dc.Patch)


class TestCorrelateInternal:
"""Tests case of intra-patch correlation function."""
Expand Down

0 comments on commit e5f76ff

Please sign in to comment.