From 10acb443ecef1c64d615d45719e07bc7e7f4d769 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Mon, 2 Dec 2024 18:20:53 +0100 Subject: [PATCH] Make stft reachable from xdas.signal + more tests. --- tests/test_atoms.py | 10 +++---- tests/test_signal.py | 67 +++++++++++++++++++++++++++----------------- xdas/signal.py | 1 + xdas/spectral.py | 3 +- 4 files changed, 49 insertions(+), 32 deletions(-) diff --git a/tests/test_atoms.py b/tests/test_atoms.py index 63a8d63..7dc8338 100644 --- a/tests/test_atoms.py +++ b/tests/test_atoms.py @@ -3,7 +3,7 @@ import xdas import xdas as xd -import xdas.signal as xp +import xdas.signal as xs from xdas.atoms import ( DownSample, FIRFilter, @@ -22,8 +22,8 @@ class TestPartialAtom: def test_init(self): Sequential( [ - Partial(xp.taper, dim="time"), - Partial(xp.taper, dim="distance"), + Partial(xs.taper, dim="time"), + Partial(xs.taper, dim="distance"), Partial(np.abs), Partial(np.square), ] @@ -176,7 +176,7 @@ def test_firfilter(self): da = wavelet_wavefronts() chunks = xdas.split(da, 6, "time") taps = sp.firwin(11, 0.4, pass_zero="lowpass") - expected = xp.lfilter(taps, 1.0, da, "time") + expected = xs.lfilter(taps, 1.0, da, "time") expected["time"] -= np.timedelta64(20, "ms") * 5 atom = FIRFilter(11, 10.0, "lowpass", dim="time") result = atom(da) @@ -194,7 +194,7 @@ def test_resample_poly(self): da = wavelet_wavefronts() chunks = xdas.split(da, 6, "time") - expected = xp.resample_poly(da, 5, 2, "time") + expected = xs.resample_poly(da, 5, 2, "time") atom = ResamplePoly(125, maxfactor=10, dim="time") result = atom(da) result_chunked = xdas.concatenate( diff --git a/tests/test_signal.py b/tests/test_signal.py index 6ef55b4..be421a0 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -3,8 +3,8 @@ import xarray as xr import xdas -import xdas.signal as xp -from xdas.synthetics import randn_wavefronts, wavelet_wavefronts +import xdas.signal as xs +from xdas.synthetics import wavelet_wavefronts class TestSignal: @@ -28,8 +28,8 @@ def test_get_sample_spacing(self): }, }, ) - assert xp.get_sampling_interval(da, "time") == 0.008 - assert xp.get_sampling_interval(da, "distance") == 5.0 + assert xs.get_sampling_interval(da, "time") == 0.008 + assert xs.get_sampling_interval(da, "distance") == 5.0 def test_deterend(self): n = 100 @@ -37,7 +37,7 @@ def test_deterend(self): s = d * np.arange(n) da = xr.DataArray(np.arange(n), {"time": s}) da = xdas.DataArray.from_xarray(da) - da = xp.detrend(da) + da = xs.detrend(da) assert np.allclose(da.values, np.zeros(n)) def test_differentiate(self): @@ -46,7 +46,7 @@ def test_differentiate(self): s = (d / 2) + d * np.arange(n) da = xr.DataArray(np.ones(n), {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.differentiate(da, midpoints=True) + da = xs.differentiate(da, midpoints=True) assert np.allclose(da.values, np.zeros(n - 1)) def test_integrate(self): @@ -55,7 +55,7 @@ def test_integrate(self): s = (d / 2) + d * np.arange(n) da = xr.DataArray(np.ones(n), {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.integrate(da, midpoints=True) + da = xs.integrate(da, midpoints=True) assert np.allclose(da.values, da["distance"].values) def test_segment_mean_removal(self): @@ -69,7 +69,7 @@ def test_segment_mean_removal(self): da.loc[{"distance": slice(limits[0], limits[1])}] = 1.0 da.loc[{"distance": slice(limits[1], limits[2])}] = 2.0 da = xdas.DataArray.from_xarray(da) - da = xp.segment_mean_removal(da, limits) + da = xs.segment_mean_removal(da, limits) assert np.allclose(da.values, 0) def test_sliding_window_removal(self): @@ -80,55 +80,55 @@ def test_sliding_window_removal(self): data = np.ones(n) da = xr.DataArray(data, {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.sliding_mean_removal(da, 0.1 * n * d) + da = xs.sliding_mean_removal(da, 0.1 * n * d) assert np.allclose(da.values, 0) def test_medfilt(self): da = wavelet_wavefronts() - result1 = xp.medfilt(da, {"distance": 3}) - result2 = xp.medfilt(da, {"time": 1, "distance": 3}) + result1 = xs.medfilt(da, {"distance": 3}) + result2 = xs.medfilt(da, {"time": 1, "distance": 3}) assert result1.equals(result2) da.data = np.zeros(da.shape) - assert da.equals(xp.medfilt(da, {"time": 7, "distance": 3})) + assert da.equals(xs.medfilt(da, {"time": 7, "distance": 3})) def test_hilbert(self): da = wavelet_wavefronts() - result = xp.hilbert(da, dim="time") + result = xs.hilbert(da, dim="time") assert np.allclose(da.values, np.real(result.values)) def test_resample(self): da = wavelet_wavefronts() - result = xp.resample(da, 100, dim="time", window="hamming", domain="time") + result = xs.resample(da, 100, dim="time", window="hamming", domain="time") assert result.sizes["time"] == 100 def test_resample_poly(self): da = wavelet_wavefronts() - result = xp.resample_poly(da, 2, 5, dim="time") + result = xs.resample_poly(da, 2, 5, dim="time") assert result.sizes["time"] == 120 def test_lfilter(self): da = wavelet_wavefronts() b, a = sp.iirfilter(4, 0.5, btype="low") - result1 = xp.lfilter(b, a, da, "time") - result2, zf = xp.lfilter(b, a, da, "time", zi=...) + result1 = xs.lfilter(b, a, da, "time") + result2, zf = xs.lfilter(b, a, da, "time", zi=...) assert result1.equals(result2) def test_filtfilt(self): da = wavelet_wavefronts() b, a = sp.iirfilter(2, 0.5, btype="low") - xp.filtfilt(b, a, da, "time", padtype=None) + xs.filtfilt(b, a, da, "time", padtype=None) def test_sosfilter(self): da = wavelet_wavefronts() sos = sp.iirfilter(4, 0.5, btype="low", output="sos") - result1 = xp.sosfilt(sos, da, "time") - result2, zf = xp.sosfilt(sos, da, "time", zi=...) + result1 = xs.sosfilt(sos, da, "time") + result2, zf = xs.sosfilt(sos, da, "time", zi=...) assert result1.equals(result2) def test_sosfiltfilt(self): da = wavelet_wavefronts() sos = sp.iirfilter(2, 0.5, btype="low", output="sos") - xp.sosfiltfilt(sos, da, "time", padtype=None) + xs.sosfiltfilt(sos, da, "time", padtype=None) def test_filter(self): da = wavelet_wavefronts() @@ -143,7 +143,7 @@ def test_filter(self): ) data = sp.sosfilt(sos, da.values, axis=axis) expected = da.copy(data=data) - result = xp.filter( + result = xs.filter( da, [5, 10], btype="band", @@ -155,7 +155,7 @@ def test_filter(self): assert result.equals(expected) data = sp.sosfiltfilt(sos, da.values, axis=axis) expected = da.copy(data=data) - result = xp.filter( + result = xs.filter( da, [5, 10], btype="band", @@ -169,8 +169,6 @@ def test_filter(self): class TestSTFT: def test_stft(self): - from xdas.spectral import stft - starttime = np.datetime64("2023-01-01T00:00:00") endtime = starttime + 9999 * np.timedelta64(10, "ms") da = xdas.DataArray( @@ -180,10 +178,27 @@ def test_stft(self): "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, }, ) - result = stft( + xs.stft( da, nperseg=100, noverlap=50, window="hamming", dim={"time": "frequency"}, ) + + def test_signal(self): + fs = 10e3 + N = 1e5 + fc = 3e3 + amp = 2 * np.sqrt(2) + time = np.arange(N) / float(fs) + data = amp * np.sin(2 * np.pi * fc * time) + da = xdas.DataArray( + data=data, + coords={"time": time}, + ) + result = xs.stft( + da, nperseg=1000, noverlap=500, window="hann", dim={"time": "frequency"} + ) + idx = int(np.abs(np.square(result)).mean("time").argmax("frequency").values) + assert result["frequency"][idx].values == fc diff --git a/xdas/signal.py b/xdas/signal.py index 1fac292..a062c23 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -5,6 +5,7 @@ from .core.coordinates import Coordinate, get_sampling_interval from .core.dataarray import DataArray from .parallel import parallelize +from .spectral import stft @atomized diff --git a/xdas/spectral.py b/xdas/spectral.py index d9b508a..55e4089 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -2,7 +2,8 @@ from scipy.fft import fft, fftfreq, fftshift, rfft, rfftfreq from scipy.signal import get_window -from . import DataArray, get_sampling_interval +from .core.coordinates import get_sampling_interval +from .core.dataarray import DataArray def stft(