Skip to content

Commit

Permalink
Make stft reachable from xdas.signal + more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
atrabattoni committed Dec 2, 2024
1 parent 205256e commit 10acb44
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 32 deletions.
10 changes: 5 additions & 5 deletions tests/test_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import xdas
import xdas as xd
import xdas.signal as xp
import xdas.signal as xs
from xdas.atoms import (
DownSample,
FIRFilter,
Expand All @@ -22,8 +22,8 @@ class TestPartialAtom:
def test_init(self):
Sequential(
[
Partial(xp.taper, dim="time"),
Partial(xp.taper, dim="distance"),
Partial(xs.taper, dim="time"),
Partial(xs.taper, dim="distance"),
Partial(np.abs),
Partial(np.square),
]
Expand Down Expand Up @@ -176,7 +176,7 @@ def test_firfilter(self):
da = wavelet_wavefronts()
chunks = xdas.split(da, 6, "time")
taps = sp.firwin(11, 0.4, pass_zero="lowpass")
expected = xp.lfilter(taps, 1.0, da, "time")
expected = xs.lfilter(taps, 1.0, da, "time")
expected["time"] -= np.timedelta64(20, "ms") * 5
atom = FIRFilter(11, 10.0, "lowpass", dim="time")
result = atom(da)
Expand All @@ -194,7 +194,7 @@ def test_resample_poly(self):
da = wavelet_wavefronts()
chunks = xdas.split(da, 6, "time")

expected = xp.resample_poly(da, 5, 2, "time")
expected = xs.resample_poly(da, 5, 2, "time")
atom = ResamplePoly(125, maxfactor=10, dim="time")
result = atom(da)
result_chunked = xdas.concatenate(
Expand Down
67 changes: 41 additions & 26 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import xarray as xr

import xdas
import xdas.signal as xp
from xdas.synthetics import randn_wavefronts, wavelet_wavefronts
import xdas.signal as xs
from xdas.synthetics import wavelet_wavefronts


class TestSignal:
Expand All @@ -28,16 +28,16 @@ def test_get_sample_spacing(self):
},
},
)
assert xp.get_sampling_interval(da, "time") == 0.008
assert xp.get_sampling_interval(da, "distance") == 5.0
assert xs.get_sampling_interval(da, "time") == 0.008
assert xs.get_sampling_interval(da, "distance") == 5.0

def test_deterend(self):
n = 100
d = 5.0
s = d * np.arange(n)
da = xr.DataArray(np.arange(n), {"time": s})
da = xdas.DataArray.from_xarray(da)
da = xp.detrend(da)
da = xs.detrend(da)
assert np.allclose(da.values, np.zeros(n))

def test_differentiate(self):
Expand All @@ -46,7 +46,7 @@ def test_differentiate(self):
s = (d / 2) + d * np.arange(n)
da = xr.DataArray(np.ones(n), {"distance": s})
da = xdas.DataArray.from_xarray(da)
da = xp.differentiate(da, midpoints=True)
da = xs.differentiate(da, midpoints=True)
assert np.allclose(da.values, np.zeros(n - 1))

def test_integrate(self):
Expand All @@ -55,7 +55,7 @@ def test_integrate(self):
s = (d / 2) + d * np.arange(n)
da = xr.DataArray(np.ones(n), {"distance": s})
da = xdas.DataArray.from_xarray(da)
da = xp.integrate(da, midpoints=True)
da = xs.integrate(da, midpoints=True)
assert np.allclose(da.values, da["distance"].values)

def test_segment_mean_removal(self):
Expand All @@ -69,7 +69,7 @@ def test_segment_mean_removal(self):
da.loc[{"distance": slice(limits[0], limits[1])}] = 1.0
da.loc[{"distance": slice(limits[1], limits[2])}] = 2.0
da = xdas.DataArray.from_xarray(da)
da = xp.segment_mean_removal(da, limits)
da = xs.segment_mean_removal(da, limits)
assert np.allclose(da.values, 0)

def test_sliding_window_removal(self):
Expand All @@ -80,55 +80,55 @@ def test_sliding_window_removal(self):
data = np.ones(n)
da = xr.DataArray(data, {"distance": s})
da = xdas.DataArray.from_xarray(da)
da = xp.sliding_mean_removal(da, 0.1 * n * d)
da = xs.sliding_mean_removal(da, 0.1 * n * d)
assert np.allclose(da.values, 0)

def test_medfilt(self):
da = wavelet_wavefronts()
result1 = xp.medfilt(da, {"distance": 3})
result2 = xp.medfilt(da, {"time": 1, "distance": 3})
result1 = xs.medfilt(da, {"distance": 3})
result2 = xs.medfilt(da, {"time": 1, "distance": 3})
assert result1.equals(result2)
da.data = np.zeros(da.shape)
assert da.equals(xp.medfilt(da, {"time": 7, "distance": 3}))
assert da.equals(xs.medfilt(da, {"time": 7, "distance": 3}))

def test_hilbert(self):
da = wavelet_wavefronts()
result = xp.hilbert(da, dim="time")
result = xs.hilbert(da, dim="time")
assert np.allclose(da.values, np.real(result.values))

def test_resample(self):
da = wavelet_wavefronts()
result = xp.resample(da, 100, dim="time", window="hamming", domain="time")
result = xs.resample(da, 100, dim="time", window="hamming", domain="time")
assert result.sizes["time"] == 100

def test_resample_poly(self):
da = wavelet_wavefronts()
result = xp.resample_poly(da, 2, 5, dim="time")
result = xs.resample_poly(da, 2, 5, dim="time")
assert result.sizes["time"] == 120

def test_lfilter(self):
da = wavelet_wavefronts()
b, a = sp.iirfilter(4, 0.5, btype="low")
result1 = xp.lfilter(b, a, da, "time")
result2, zf = xp.lfilter(b, a, da, "time", zi=...)
result1 = xs.lfilter(b, a, da, "time")
result2, zf = xs.lfilter(b, a, da, "time", zi=...)
assert result1.equals(result2)

def test_filtfilt(self):
da = wavelet_wavefronts()
b, a = sp.iirfilter(2, 0.5, btype="low")
xp.filtfilt(b, a, da, "time", padtype=None)
xs.filtfilt(b, a, da, "time", padtype=None)

def test_sosfilter(self):
da = wavelet_wavefronts()
sos = sp.iirfilter(4, 0.5, btype="low", output="sos")
result1 = xp.sosfilt(sos, da, "time")
result2, zf = xp.sosfilt(sos, da, "time", zi=...)
result1 = xs.sosfilt(sos, da, "time")
result2, zf = xs.sosfilt(sos, da, "time", zi=...)
assert result1.equals(result2)

def test_sosfiltfilt(self):
da = wavelet_wavefronts()
sos = sp.iirfilter(2, 0.5, btype="low", output="sos")
xp.sosfiltfilt(sos, da, "time", padtype=None)
xs.sosfiltfilt(sos, da, "time", padtype=None)

def test_filter(self):
da = wavelet_wavefronts()
Expand All @@ -143,7 +143,7 @@ def test_filter(self):
)
data = sp.sosfilt(sos, da.values, axis=axis)
expected = da.copy(data=data)
result = xp.filter(
result = xs.filter(
da,
[5, 10],
btype="band",
Expand All @@ -155,7 +155,7 @@ def test_filter(self):
assert result.equals(expected)
data = sp.sosfiltfilt(sos, da.values, axis=axis)
expected = da.copy(data=data)
result = xp.filter(
result = xs.filter(
da,
[5, 10],
btype="band",
Expand All @@ -169,8 +169,6 @@ def test_filter(self):

class TestSTFT:
def test_stft(self):
from xdas.spectral import stft

starttime = np.datetime64("2023-01-01T00:00:00")
endtime = starttime + 9999 * np.timedelta64(10, "ms")
da = xdas.DataArray(
Expand All @@ -180,10 +178,27 @@ def test_stft(self):
"distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]},
},
)
result = stft(
xs.stft(
da,
nperseg=100,
noverlap=50,
window="hamming",
dim={"time": "frequency"},
)

def test_signal(self):
fs = 10e3
N = 1e5
fc = 3e3
amp = 2 * np.sqrt(2)
time = np.arange(N) / float(fs)
data = amp * np.sin(2 * np.pi * fc * time)
da = xdas.DataArray(
data=data,
coords={"time": time},
)
result = xs.stft(
da, nperseg=1000, noverlap=500, window="hann", dim={"time": "frequency"}
)
idx = int(np.abs(np.square(result)).mean("time").argmax("frequency").values)
assert result["frequency"][idx].values == fc
1 change: 1 addition & 0 deletions xdas/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .core.coordinates import Coordinate, get_sampling_interval
from .core.dataarray import DataArray
from .parallel import parallelize
from .spectral import stft


@atomized
Expand Down
3 changes: 2 additions & 1 deletion xdas/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from scipy.fft import fft, fftfreq, fftshift, rfft, rfftfreq
from scipy.signal import get_window

from . import DataArray, get_sampling_interval
from .core.coordinates import get_sampling_interval
from .core.dataarray import DataArray


def stft(
Expand Down

0 comments on commit 10acb44

Please sign in to comment.