From 334f45a69838855b946d741b2357034bc6e6a084 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 15 Nov 2024 13:14:18 +0100 Subject: [PATCH] backup --- tests/test_signal.py | 24 +++++++++++++- xdas/core/routines.py | 2 +- xdas/spectral.py | 74 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 xdas/spectral.py diff --git a/tests/test_signal.py b/tests/test_signal.py index 18b561b..6f2a6e6 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -4,7 +4,7 @@ import xdas import xdas.signal as xp -from xdas.synthetics import wavelet_wavefronts +from xdas.synthetics import randn_wavefronts, wavelet_wavefronts class TestSignal: @@ -130,6 +130,28 @@ def test_sosfiltfilt(self): sos = sp.iirfilter(2, 0.5, btype="low", output="sos") xp.sosfiltfilt(sos, da, "time", padtype=None) + +class TestSTFT: + def test_stft(self): + from xdas.spectral import stft + + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 9999 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.zeros((10000, 11)), + coords={ + "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, + }, + ) + result = stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + ) + def test_filter(self): da = wavelet_wavefronts() axis = da.get_axis_num("time") diff --git a/xdas/core/routines.py b/xdas/core/routines.py index a676e8d..11f1a9e 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/spectral.py b/xdas/spectral.py new file mode 100644 index 0000000..f834d16 --- /dev/null +++ b/xdas/spectral.py @@ -0,0 +1,74 @@ +import numpy as np +from scipy.fft import fft, fftfreq, rfft, rfftfreq +from scipy.signal import get_window + +from . import DataArray, get_sampling_interval + + +def stft( + da, + window="hann", + nperseg=256, + noverlap=None, + nfft=None, + return_onesided=True, + dim={"last": "sprectrum"}, + scaling="spectrum", + # parallel=None, +): + if noverlap is None: + noverlap = nperseg // 2 + if nfft is None: + nfft = nperseg + win = get_window(window, nperseg) + input_dim, output_dim = next(iter(dim.items())) + axis = da.get_axis_num(input_dim) + dt = get_sampling_interval(da, input_dim) + if scaling == "density": + scale = 1.0 / ((win * win).sum() / dt) + elif scaling == "spectrum": + scale = 1.0 / win.sum() ** 2 + else: + raise ValueError("Scaling must be 'density' or 'spectrum'") + scale = np.sqrt(scale) + if return_onesided: + freqs = rfftfreq(nfft, dt) + else: + freqs = fftfreq(nfft, dt) + freqs = {"tie_indices": [0, nfft - 1], "tie_values": [freqs[0], freqs[-1]]} + + def func(x): + if nperseg == 1 and noverlap == 0: + result = x[..., np.newaxis] + else: + step = nperseg - noverlap + result = np.lib.stride_tricks.sliding_window_view( + x, window_shape=nperseg, axis=axis, writeable=True + ) + slc = [slice(None)] * result.ndim + slc[axis] = slice(None, None, step) + result = result[tuple(slc)] + result = win * result + if return_onesided: + result = rfft(result, n=nfft) + else: + result = fft(result, n=nfft) + result *= scale + return result + + data = func(da.values) + + dt = get_sampling_interval(da, input_dim, cast=False) + t0 = da.coords[input_dim].values[0] + starttime = t0 + (nperseg / 2) * dt + endtime = t0 + (da.shape[-1] - nperseg / 2) * dt + time = {"tie_indices": [0, da.shape[-1] - 1], "tie_values": [starttime, endtime]} + + coords = da.coords.copy() + coords[input_dim] = time + coords[output_dim] = freqs + + result = DataArray(data, coords) + + dims = dim + (output_dim,) + return result.transpose(*dims)