Skip to content

Commit

Permalink
fix_correlate_shift_coord_length (#485)
Browse files Browse the repository at this point in the history
* fix_correlate_shift_coord_length

* remove_commented_breakpoint

* try drop uv

* lint

* add test deps

* handle the units

---------

Co-authored-by: Derrick Chambers <chambers.ja.derrick@gmail.com>
  • Loading branch information
ahmadtourei and d-chambers authored Jan 16, 2025
1 parent 1dc1ce4 commit 74c0d16
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 16 deletions.
15 changes: 8 additions & 7 deletions .github/workflows/run_min_dep_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ jobs:
install-package: false
environment-file: './.github/min_deps_environment.yml'

# Then switch over to uv. We can use this exclusively once we drop pytables.
- name: Install uv
uses: astral-sh/setup-uv@v3
- name: Install dascore (min deps)
shell: bash -l {0}
run: pip install -e .[test]

# Runs test suite and calculates coverage
- name: run test suite
shell: bash -l {0}
run: uv run --extra test --python ${{ matrix.python-version }} pytest -s --cov dascore --cov-append --cov-report=xml
shell: bash -el {0}
run: ./.github/test_code.sh

# Runs examples in docstrings
- name: test docstrings
shell: bash -l {0}
run: uv run --extra test --python ${{ matrix.python-version }} pytest dascore --doctest-modules
shell: bash -el {0}
run: ./.github/test_code.sh doctest
8 changes: 7 additions & 1 deletion dascore/proc/correlate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ def correlate_shift(patch, dim, undo_weighting=True):
step = coord.step
new_start = -np.ceil((len(coord) - 1) / 2) * step
new_end = np.ceil((len(coord) - 1) / 2) * step
new_coord = dc.get_coord(start=new_start, stop=new_end, step=step)
# if len(coord) % 2 != 0: # Odd coord length
# new_end += step
_new_coord = dc.get_coord(
start=new_start, stop=new_end, step=step, units=coord.units
)
new_coord = _new_coord.change_length(len(coord))
# new_coord = dc.get_coord(start=new_start, stop=new_end, step=step)
assert len(new_coord) == len(coord)
cm = patch.coords
new_cm = cm.update(**{dim: new_coord}).rename_coord(**{dim: f"lag_{dim}"})
Expand Down
13 changes: 6 additions & 7 deletions dascore/proc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,10 @@ def notch_filter(patch: PatchType, q, **kwargs) -> PatchType:
Used to specify the dimension(s) and associated frequency and/or wavelength
(or equivalent values) for the filter.
Notes
-----
See [scipy.signal.iirnotch]
(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirnotch.html)
for more information.
See Also
--------
[scipy.signal.iirnotch](https://docs.scipy.org/doc/scipy/reference
/generated/scipy.signal.iirnotch.html).
Examples
--------
Expand All @@ -286,10 +285,10 @@ def notch_filter(patch: PatchType, q, **kwargs) -> PatchType:
>>> filtered = pa.notch_filter(time=60, q=30)
>>> # Apply a notch filter along distance axis to remove 5 m wavelength
>>> filtered = pa.notch_filter(distance=0.2, q=10)
>>> filtered = pa.notch_filter(distance=0.2, q=30)
>>> # Apply a notch filter along both time and distance axes
>>> filtered = pa.notch_filter(time=60, distance=0.2, q=40)
>>> filtered = pa.notch_filter(time=60, distance=0.2, q=30)
>>> # Optionally, units can be specified for a more expressive API.
>>> from dascore.units import m, ft, s, Hz
Expand Down
1 change: 0 additions & 1 deletion docs/recipes/how_to_contribute.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ pip install -e ".[dev]"
```

```bash
cd dascore
pytest
```

Expand Down
28 changes: 28 additions & 0 deletions tests/test_proc/test_correlate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
class TestCorrelateShift:
"""Tests for the correlation shift function."""

@pytest.fixture(scope="class")
def random_patch_odd(self):
"""Create a random patch with odd number of time samples."""
patch = dc.get_example_patch("random_das", shape=(2, 11))
return patch

@pytest.fixture(scope="class")
def random_patch_even(self):
"""Create a random patch with even number of time samples."""
patch = dc.get_example_patch("random_das", shape=(2, 10))
return patch

def test_auto_correlation(self, random_dft_patch):
"""Perform auto correlation and undo shifting."""
dft_conj = random_dft_patch.conj()
Expand All @@ -26,6 +38,22 @@ def test_auto_correlation(self, random_dft_patch):
argmax = np.argmax(random_dft_patch.data, axis=time_ax)
assert np.all(coord_array[argmax] == dc.to_timedelta64(0))

def test_auto_correlation_odd_coord(self, random_patch_odd):
"""Ensure correlate_shift works when dim's coord length is odd."""
dft = random_patch_odd.dft(dim="time")
dft_conj = dft.conj()
dft_sq = dft * dft_conj
idft = dft_sq.idft()
assert isinstance(idft.correlate_shift(dim="time"), dc.Patch)

def test_auto_correlation_even_coord(self, random_patch_even):
"""Ensure correlate_shift works when dim's coord length is even."""
dft = random_patch_even.dft(dim="time")
dft_conj = dft.conj()
dft_sq = dft * dft_conj
idft = dft_sq.idft()
assert isinstance(idft.correlate_shift(dim="time"), dc.Patch)


class TestCorrelateInternal:
"""Tests case of intra-patch correlation function."""
Expand Down

0 comments on commit 74c0d16

Please sign in to comment.