From 334f45a69838855b946d741b2357034bc6e6a084 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 15 Nov 2024 13:14:18 +0100 Subject: [PATCH 01/20] backup --- tests/test_signal.py | 24 +++++++++++++- xdas/core/routines.py | 2 +- xdas/spectral.py | 74 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 xdas/spectral.py diff --git a/tests/test_signal.py b/tests/test_signal.py index 18b561b..6f2a6e6 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -4,7 +4,7 @@ import xdas import xdas.signal as xp -from xdas.synthetics import wavelet_wavefronts +from xdas.synthetics import randn_wavefronts, wavelet_wavefronts class TestSignal: @@ -130,6 +130,28 @@ def test_sosfiltfilt(self): sos = sp.iirfilter(2, 0.5, btype="low", output="sos") xp.sosfiltfilt(sos, da, "time", padtype=None) + +class TestSTFT: + def test_stft(self): + from xdas.spectral import stft + + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 9999 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.zeros((10000, 11)), + coords={ + "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, + }, + ) + result = stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + ) + def test_filter(self): da = wavelet_wavefronts() axis = da.get_axis_num("time") diff --git a/xdas/core/routines.py b/xdas/core/routines.py index a676e8d..11f1a9e 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/spectral.py b/xdas/spectral.py new file mode 100644 index 0000000..f834d16 --- /dev/null +++ b/xdas/spectral.py @@ -0,0 +1,74 @@ +import numpy as np +from scipy.fft import fft, fftfreq, rfft, rfftfreq +from scipy.signal import get_window + +from . import DataArray, get_sampling_interval + + +def stft( + da, + window="hann", + nperseg=256, + noverlap=None, + nfft=None, + return_onesided=True, + dim={"last": "sprectrum"}, + scaling="spectrum", + # parallel=None, +): + if noverlap is None: + noverlap = nperseg // 2 + if nfft is None: + nfft = nperseg + win = get_window(window, nperseg) + input_dim, output_dim = next(iter(dim.items())) + axis = da.get_axis_num(input_dim) + dt = get_sampling_interval(da, input_dim) + if scaling == "density": + scale = 1.0 / ((win * win).sum() / dt) + elif scaling == "spectrum": + scale = 1.0 / win.sum() ** 2 + else: + raise ValueError("Scaling must be 'density' or 'spectrum'") + scale = np.sqrt(scale) + if return_onesided: + freqs = rfftfreq(nfft, dt) + else: + freqs = fftfreq(nfft, dt) + freqs = {"tie_indices": [0, nfft - 1], "tie_values": [freqs[0], freqs[-1]]} + + def func(x): + if nperseg == 1 and noverlap == 0: + result = x[..., np.newaxis] + else: + step = nperseg - noverlap + result = np.lib.stride_tricks.sliding_window_view( + x, window_shape=nperseg, axis=axis, writeable=True + ) + slc = [slice(None)] * result.ndim + slc[axis] = slice(None, None, step) + result = result[tuple(slc)] + result = win * result + if return_onesided: + result = rfft(result, n=nfft) + else: + result = fft(result, n=nfft) + result *= scale + return result + + data = func(da.values) + + dt = get_sampling_interval(da, input_dim, cast=False) + t0 = da.coords[input_dim].values[0] + starttime = t0 + (nperseg / 2) * dt + endtime = t0 + (da.shape[-1] - nperseg / 2) * dt + time = {"tie_indices": [0, da.shape[-1] - 1], "tie_values": [starttime, endtime]} + + coords = da.coords.copy() + coords[input_dim] = time + coords[output_dim] = freqs + + result = DataArray(data, coords) + + dims = dim + (output_dim,) + return result.transpose(*dims) From 8b7106f52e022d7e57506c21346c85c192e7d0ab Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 15:11:30 +0100 Subject: [PATCH 02/20] Add failing test. --- tests/test_signal.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_signal.py b/tests/test_signal.py index 18b561b..bfaca8a 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -5,6 +5,8 @@ import xdas import xdas.signal as xp from xdas.synthetics import wavelet_wavefronts +import tempfile +import os class TestSignal: @@ -164,4 +166,16 @@ def test_filter(self): dim="time", parallel=False, ) - assert result.equals(expected) \ No newline at end of file + assert result.equals(expected) + + def test_decimate_virtual_stack(self): + da = wavelet_wavefronts() + expected = xp.decimate(da, 5, dim="time") + chunks = xdas.split(da, 5, "time") + with tempfile.TemporaryDirectory() as tmpdirname: + for i, chunk in enumerate(chunks): + chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc") + chunk.to_netcdf(chunk_path) + da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc")) + result = xp.decimate(da_virtual, 5, dim="time") + assert result.equals(expected) From 104ec2f8f56c45d6e2a53294e6fd9321125332dd Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 15:25:13 +0100 Subject: [PATCH 03/20] Fix decimation of virtual stacks. --- xdas/signal.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xdas/signal.py b/xdas/signal.py index 1fac292..71078c1 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -708,10 +708,15 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N """ axis = da.get_axis_num(dim) + dim = da.dims[axis] # TODO: this fist last thing is a bad idea... across = int(axis == 0) func = parallelize(across, across, parallel)(sp.decimate) data = func(da.values, q, n, ftype, axis, zero_phase) - return da[{dim: slice(None, None, q)}].copy(data=data) + coords = da.coords.copy() + for name in coords: + if coords[name].dim == dim: + coords[name] = coords[name][::q] + return DataArray(data, coords, da.dims, da.name, da.attrs) @atomized From fa5c0c7fdf47d2a0a2ab5b0ce24b8129bd47e69b Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 15:25:27 +0100 Subject: [PATCH 04/20] Format. --- tests/test_fft.py | 2 ++ tests/test_signal.py | 5 +++-- xdas/core/routines.py | 2 +- xdas/fft.py | 6 ++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/test_fft.py b/tests/test_fft.py index 928f155..371c644 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -1,7 +1,9 @@ import numpy as np + import xdas as xd import xdas.fft as xfft + class TestRFFT: def test_with_non_dimensional(self): da = xd.synthetics.wavelet_wavefronts() diff --git a/tests/test_signal.py b/tests/test_signal.py index bfaca8a..51f658d 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,3 +1,6 @@ +import os +import tempfile + import numpy as np import scipy.signal as sp import xarray as xr @@ -5,8 +8,6 @@ import xdas import xdas.signal as xp from xdas.synthetics import wavelet_wavefronts -import tempfile -import os class TestSignal: diff --git a/xdas/core/routines.py b/xdas/core/routines.py index a676e8d..11f1a9e 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/fft.py b/xdas/fft.py index 2599ae4..fdf5055 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -55,7 +55,8 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): data = func(da.values) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -110,7 +111,8 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): data = func(da.values, n, axis, norm) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) From 3a3f6c72de4a03504ae1ac4c39d94fb0d719fb80 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 15:44:18 +0100 Subject: [PATCH 05/20] Revert "Merge pull request #25 from xdas-dev/fix/decimate-virtual-stack" This reverts commit f727e9c37893d5ab8d214c6063192d44f273dafa, reversing changes made to 9220bfc15e163d55bed408aad1f69a8b1b0ccbec. --- tests/test_fft.py | 11 ---------- tests/test_signal.py | 51 ------------------------------------------- xdas/core/routines.py | 2 +- xdas/fft.py | 2 -- xdas/signal.py | 11 +++------- 5 files changed, 4 insertions(+), 73 deletions(-) delete mode 100644 tests/test_fft.py diff --git a/tests/test_fft.py b/tests/test_fft.py deleted file mode 100644 index 371c644..0000000 --- a/tests/test_fft.py +++ /dev/null @@ -1,11 +0,0 @@ -import numpy as np - -import xdas as xd -import xdas.fft as xfft - - -class TestRFFT: - def test_with_non_dimensional(self): - da = xd.synthetics.wavelet_wavefronts() - da["latitude"] = ("distance", np.arange(da.sizes["distance"])) - xfft.rfft(da) diff --git a/tests/test_signal.py b/tests/test_signal.py index 51f658d..92b40c5 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,6 +1,3 @@ -import os -import tempfile - import numpy as np import scipy.signal as sp import xarray as xr @@ -132,51 +129,3 @@ def test_sosfiltfilt(self): da = wavelet_wavefronts() sos = sp.iirfilter(2, 0.5, btype="low", output="sos") xp.sosfiltfilt(sos, da, "time", padtype=None) - - def test_filter(self): - da = wavelet_wavefronts() - axis = da.get_axis_num("time") - fs = 1 / xdas.get_sampling_interval(da, "time") - sos = sp.butter( - 4, - [5, 10], - "band", - output="sos", - fs=fs, - ) - data = sp.sosfilt(sos, da.values, axis=axis) - expected = da.copy(data=data) - result = xp.filter( - da, - [5, 10], - btype="band", - corners=4, - zerophase=False, - dim="time", - parallel=False, - ) - assert result.equals(expected) - data = sp.sosfiltfilt(sos, da.values, axis=axis) - expected = da.copy(data=data) - result = xp.filter( - da, - [5, 10], - btype="band", - corners=4, - zerophase=True, - dim="time", - parallel=False, - ) - assert result.equals(expected) - - def test_decimate_virtual_stack(self): - da = wavelet_wavefronts() - expected = xp.decimate(da, 5, dim="time") - chunks = xdas.split(da, 5, "time") - with tempfile.TemporaryDirectory() as tmpdirname: - for i, chunk in enumerate(chunks): - chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc") - chunk.to_netcdf(chunk_path) - da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc")) - result = xp.decimate(da_virtual, 5, dim="time") - assert result.equals(expected) diff --git a/xdas/core/routines.py b/xdas/core/routines.py index 11f1a9e..a676e8d 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/fft.py b/xdas/fft.py index fdf5055..3cd3312 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -56,7 +56,6 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords - if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -112,7 +111,6 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords - if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) diff --git a/xdas/signal.py b/xdas/signal.py index 71078c1..f1cd199 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -102,9 +102,9 @@ def filter(da, freq, btype, corners=4, zerophase=False, dim="last", parallel=Non fs = 1.0 / get_sampling_interval(da, dim) sos = sp.iirfilter(corners, freq, btype=btype, ftype="butter", output="sos", fs=fs) if zerophase: - func = parallelize((None, across), across, parallel)(sp.sosfiltfilt) - else: func = parallelize((None, across), across, parallel)(sp.sosfilt) + else: + func = parallelize((None, across), across, parallel)(sp.sosfiltfilt) data = func(sos, da.values, axis=axis) return da.copy(data=data) @@ -708,15 +708,10 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N """ axis = da.get_axis_num(dim) - dim = da.dims[axis] # TODO: this fist last thing is a bad idea... across = int(axis == 0) func = parallelize(across, across, parallel)(sp.decimate) data = func(da.values, q, n, ftype, axis, zero_phase) - coords = da.coords.copy() - for name in coords: - if coords[name].dim == dim: - coords[name] = coords[name][::q] - return DataArray(data, coords, da.dims, da.name, da.attrs) + return da[{dim: slice(None, None, q)}].copy(data=data) @atomized From fa3d209fd1f0c82a1ba5fe423e998bbd6ab547bd Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 16:01:45 +0100 Subject: [PATCH 06/20] Revert "Format." This reverts commit fa5c0c7fdf47d2a0a2ab5b0ce24b8129bd47e69b. --- tests/test_fft.py | 9 +++++++++ tests/test_signal.py | 2 ++ xdas/fft.py | 4 ++-- 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 tests/test_fft.py diff --git a/tests/test_fft.py b/tests/test_fft.py new file mode 100644 index 0000000..928f155 --- /dev/null +++ b/tests/test_fft.py @@ -0,0 +1,9 @@ +import numpy as np +import xdas as xd +import xdas.fft as xfft + +class TestRFFT: + def test_with_non_dimensional(self): + da = xd.synthetics.wavelet_wavefronts() + da["latitude"] = ("distance", np.arange(da.sizes["distance"])) + xfft.rfft(da) diff --git a/tests/test_signal.py b/tests/test_signal.py index 92b40c5..a61dc0d 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -5,6 +5,8 @@ import xdas import xdas.signal as xp from xdas.synthetics import wavelet_wavefronts +import tempfile +import os class TestSignal: diff --git a/xdas/fft.py b/xdas/fft.py index 3cd3312..2599ae4 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -55,7 +55,7 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): data = func(da.values) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords + for name in da.coords if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -110,7 +110,7 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): data = func(da.values, n, axis, norm) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords + for name in da.coords if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) From 47ccac9be0629d3d4ce868c8dc1131799c36d507 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 16:11:18 +0100 Subject: [PATCH 07/20] Revert "Revert "Format."" This reverts commit fa3d209fd1f0c82a1ba5fe423e998bbd6ab547bd. --- tests/test_fft.py | 9 --------- tests/test_signal.py | 2 -- xdas/fft.py | 4 ++-- 3 files changed, 2 insertions(+), 13 deletions(-) delete mode 100644 tests/test_fft.py diff --git a/tests/test_fft.py b/tests/test_fft.py deleted file mode 100644 index 928f155..0000000 --- a/tests/test_fft.py +++ /dev/null @@ -1,9 +0,0 @@ -import numpy as np -import xdas as xd -import xdas.fft as xfft - -class TestRFFT: - def test_with_non_dimensional(self): - da = xd.synthetics.wavelet_wavefronts() - da["latitude"] = ("distance", np.arange(da.sizes["distance"])) - xfft.rfft(da) diff --git a/tests/test_signal.py b/tests/test_signal.py index a61dc0d..92b40c5 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -5,8 +5,6 @@ import xdas import xdas.signal as xp from xdas.synthetics import wavelet_wavefronts -import tempfile -import os class TestSignal: diff --git a/xdas/fft.py b/xdas/fft.py index 2599ae4..3cd3312 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -55,7 +55,7 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): data = func(da.values) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -110,7 +110,7 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): data = func(da.values, n, axis, norm) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) From 16c112562bbc8415b59dfc241917027ba44793ed Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 16:11:52 +0100 Subject: [PATCH 08/20] Revert "Revert "Merge pull request #25 from xdas-dev/fix/decimate-virtual-stack"" This reverts commit 3a3f6c72de4a03504ae1ac4c39d94fb0d719fb80. --- tests/test_fft.py | 11 ++++++++++ tests/test_signal.py | 51 +++++++++++++++++++++++++++++++++++++++++++ xdas/core/routines.py | 2 +- xdas/fft.py | 2 ++ xdas/signal.py | 11 +++++++--- 5 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 tests/test_fft.py diff --git a/tests/test_fft.py b/tests/test_fft.py new file mode 100644 index 0000000..371c644 --- /dev/null +++ b/tests/test_fft.py @@ -0,0 +1,11 @@ +import numpy as np + +import xdas as xd +import xdas.fft as xfft + + +class TestRFFT: + def test_with_non_dimensional(self): + da = xd.synthetics.wavelet_wavefronts() + da["latitude"] = ("distance", np.arange(da.sizes["distance"])) + xfft.rfft(da) diff --git a/tests/test_signal.py b/tests/test_signal.py index 92b40c5..51f658d 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,3 +1,6 @@ +import os +import tempfile + import numpy as np import scipy.signal as sp import xarray as xr @@ -129,3 +132,51 @@ def test_sosfiltfilt(self): da = wavelet_wavefronts() sos = sp.iirfilter(2, 0.5, btype="low", output="sos") xp.sosfiltfilt(sos, da, "time", padtype=None) + + def test_filter(self): + da = wavelet_wavefronts() + axis = da.get_axis_num("time") + fs = 1 / xdas.get_sampling_interval(da, "time") + sos = sp.butter( + 4, + [5, 10], + "band", + output="sos", + fs=fs, + ) + data = sp.sosfilt(sos, da.values, axis=axis) + expected = da.copy(data=data) + result = xp.filter( + da, + [5, 10], + btype="band", + corners=4, + zerophase=False, + dim="time", + parallel=False, + ) + assert result.equals(expected) + data = sp.sosfiltfilt(sos, da.values, axis=axis) + expected = da.copy(data=data) + result = xp.filter( + da, + [5, 10], + btype="band", + corners=4, + zerophase=True, + dim="time", + parallel=False, + ) + assert result.equals(expected) + + def test_decimate_virtual_stack(self): + da = wavelet_wavefronts() + expected = xp.decimate(da, 5, dim="time") + chunks = xdas.split(da, 5, "time") + with tempfile.TemporaryDirectory() as tmpdirname: + for i, chunk in enumerate(chunks): + chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc") + chunk.to_netcdf(chunk_path) + da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc")) + result = xp.decimate(da_virtual, 5, dim="time") + assert result.equals(expected) diff --git a/xdas/core/routines.py b/xdas/core/routines.py index a676e8d..11f1a9e 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/fft.py b/xdas/fft.py index 3cd3312..fdf5055 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -56,6 +56,7 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -111,6 +112,7 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) diff --git a/xdas/signal.py b/xdas/signal.py index f1cd199..71078c1 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -102,9 +102,9 @@ def filter(da, freq, btype, corners=4, zerophase=False, dim="last", parallel=Non fs = 1.0 / get_sampling_interval(da, dim) sos = sp.iirfilter(corners, freq, btype=btype, ftype="butter", output="sos", fs=fs) if zerophase: - func = parallelize((None, across), across, parallel)(sp.sosfilt) - else: func = parallelize((None, across), across, parallel)(sp.sosfiltfilt) + else: + func = parallelize((None, across), across, parallel)(sp.sosfilt) data = func(sos, da.values, axis=axis) return da.copy(data=data) @@ -708,10 +708,15 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N """ axis = da.get_axis_num(dim) + dim = da.dims[axis] # TODO: this fist last thing is a bad idea... across = int(axis == 0) func = parallelize(across, across, parallel)(sp.decimate) data = func(da.values, q, n, ftype, axis, zero_phase) - return da[{dim: slice(None, None, q)}].copy(data=data) + coords = da.coords.copy() + for name in coords: + if coords[name].dim == dim: + coords[name] = coords[name][::q] + return DataArray(data, coords, da.dims, da.name, da.attrs) @atomized From 0403c9ff155e56171e391df0341af45cc45d8abf Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 16:19:16 +0100 Subject: [PATCH 09/20] Add failling test for sel with overlaps. --- tests/test_dataarray.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_dataarray.py b/tests/test_dataarray.py index c3a5c2a..2779cf4 100644 --- a/tests/test_dataarray.py +++ b/tests/test_dataarray.py @@ -159,6 +159,19 @@ def test_sel(self): result = da.sel(distance=0, method="nearest", drop=True) assert "distance" not in result.coords + def test_better_error_when_sel_with_overlaps(self): + da = DataArray( + np.arange(80).reshape(20, 4), + { + "time": { + "tie_values": [0.0, 0.5, 0.4, 1.0], + "tie_indices": [0, 9, 10, 19], + }, + "distance": [0.0, 10.0, 20.0, 30.0], + }, + ) + da.sel(time=slice(0.1, 0.6)) + def test_isel(self): da = wavelet_wavefronts() result = da.isel(first=0) From 19be6b41af808b5f9ab85c699f189e7efa6bf0d8 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 16:33:37 +0100 Subject: [PATCH 10/20] Add nicer message about overlap related errors. --- tests/test_dataarray.py | 3 ++- xdas/core/coordinates.py | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_dataarray.py b/tests/test_dataarray.py index 2779cf4..16b828c 100644 --- a/tests/test_dataarray.py +++ b/tests/test_dataarray.py @@ -170,7 +170,8 @@ def test_better_error_when_sel_with_overlaps(self): "distance": [0.0, 10.0, 20.0, 30.0], }, ) - da.sel(time=slice(0.1, 0.6)) + with pytest.raises(ValueError, match="overlaps were found"): + da.sel(time=slice(0.1, 0.6)) def test_isel(self): da = wavelet_wavefronts() diff --git a/xdas/core/coordinates.py b/xdas/core/coordinates.py index 58c5837..57e6405 100644 --- a/xdas/core/coordinates.py +++ b/xdas/core/coordinates.py @@ -769,7 +769,20 @@ def get_indexer(self, value, method=None): value = np.datetime64(value) else: value = np.asarray(value) - return inverse(value, self.tie_indices, self.tie_values, method) + try: + indexer = inverse(value, self.tie_indices, self.tie_values, method) + except ValueError as e: + if str(e) == "fp must be strictly increasing": + raise ValueError( + "overlaps were found in the coordinate. If this is due to some " + "jitter in the tie values, consider smoothing the coordinate by " + "including some tolerance. This can be done by " + "`da[dim] = da[dim].simplify(tolerance)`, or by specifying a " + "tolerance when opening multiple files." + ) + else: + raise e + return indexer def slice_indexer(self, start=None, stop=None, step=None, endpoint=True): if start is not None: From f9009840fd2232f0afcabba11bed100d62207170 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Mon, 2 Dec 2024 15:35:40 +0100 Subject: [PATCH 11/20] Add ifft with test and doc. --- docs/api/fft.md | 1 + tests/test_fft.py | 18 +++++++++++++++ xdas/fft.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) diff --git a/docs/api/fft.md b/docs/api/fft.md index f6b0567..00b7201 100644 --- a/docs/api/fft.md +++ b/docs/api/fft.md @@ -11,5 +11,6 @@ :toctree: ../_autosummary fft + ifft rfft ``` \ No newline at end of file diff --git a/tests/test_fft.py b/tests/test_fft.py index 371c644..267af28 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -9,3 +9,21 @@ def test_with_non_dimensional(self): da = xd.synthetics.wavelet_wavefronts() da["latitude"] = ("distance", np.arange(da.sizes["distance"])) xfft.rfft(da) + + +class TestIFFT: + def test_base(self): + expected = xd.synthetics.wavelet_wavefronts() + result = xfft.ifft( + xfft.fft(expected, dim={"time": "frequency"}), dim={"frequency": "time"} + ) + assert np.allclose(np.real(result).values, expected.values) + assert np.allclose(np.imag(result).values, 0) + for name in result.coords: + if name == "time": + ref = expected["time"].values + ref = (ref - ref[0]) / np.timedelta64(1, "s") + ref += result["time"][0].values + assert np.allclose(result["time"].values, ref) + else: + assert result[name].equals(expected[name]) diff --git a/xdas/fft.py b/xdas/fft.py index fdf5055..443b1f8 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -116,3 +116,60 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) + + +@atomized +def ifft(da, n=None, dim={"last": "time"}, norm=None, parallel=None): + """ + Compute the inverse discrete Fourier Transform along a given dimension. + + This function computes the inverse of the one-dimensional n-point discrete Fourier + transform computed by fft. In other words, ifft(fft(a)) == a to within numerical + accuracy. + + Parameters + ---------- + da: DataArray + The data array to process, should be complex. + n: int, optional + Length of transformed dimension of the output. If n is smaller than the length + of the input, the input is cropped. If it is larger, the input is padded with + zeros. If n is not given, the length of the input along the dimension specified + by `dim` is used. + dim: {str: str}, optional + A mapping indicating as a key the dimension along which to compute the IFFT, and + as value the new name of the dimension. Default to {"last": "spectrum"}. + norm: {“backward”, “ortho”, “forward”}, optional + Normalization mode (see `numpy.fft`). Default is "backward". Indicates which + direction of the forward/backward pair of transforms is scaled and with what + normalization factor. + + Returns + ------- + DataArray: + The transformed input with an updated dimension name and values. + + Notes + ----- + - To perform a multidimensional inverse fourrier transform, repeat this function on + the desired dimensions. + + """ + ((olddim, newdim),) = dim.items() + olddim = da.dims[da.get_axis_num(olddim)] + if n is None: + n = da.sizes[olddim] + axis = da.get_axis_num(olddim) + d = get_sampling_interval(da, olddim) + f = np.fft.ifftshift(np.fft.fftfreq(n, d)) + func = lambda x: np.fft.ifft(np.fft.ifftshift(x, axis), n, axis, norm) + across = int(axis == 0) + func = parallelize(across, across, parallel)(func) + data = func(da.values) + coords = { + newdim if name == olddim else name: f if name == olddim else da.coords[name] + for name in da.coords + if (da[name].dim != olddim or name == olddim) + } + dims = tuple(newdim if dim == olddim else dim for dim in da.dims) + return DataArray(data, coords, dims, da.name, da.attrs) From 658d12387290e3e18e87ed0b21cd9cba61de3f6b Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Mon, 2 Dec 2024 16:02:22 +0100 Subject: [PATCH 12/20] Add irfft with test and doc. --- docs/api/fft.md | 1 + tests/test_fft.py | 24 ++++++++++++++++--- xdas/fft.py | 60 +++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 80 insertions(+), 5 deletions(-) diff --git a/docs/api/fft.md b/docs/api/fft.md index 00b7201..066e696 100644 --- a/docs/api/fft.md +++ b/docs/api/fft.md @@ -13,4 +13,5 @@ fft ifft rfft + irfft ``` \ No newline at end of file diff --git a/tests/test_fft.py b/tests/test_fft.py index 267af28..206e98f 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -11,11 +11,12 @@ def test_with_non_dimensional(self): xfft.rfft(da) -class TestIFFT: - def test_base(self): +class TestIncerse: + def test_standard(self): expected = xd.synthetics.wavelet_wavefronts() result = xfft.ifft( - xfft.fft(expected, dim={"time": "frequency"}), dim={"frequency": "time"} + xfft.fft(expected, dim={"time": "frequency"}), + dim={"frequency": "time"}, ) assert np.allclose(np.real(result).values, expected.values) assert np.allclose(np.imag(result).values, 0) @@ -27,3 +28,20 @@ def test_base(self): assert np.allclose(result["time"].values, ref) else: assert result[name].equals(expected[name]) + + def test_real(self): + expected = xd.synthetics.wavelet_wavefronts() + result = xfft.irfft( + xfft.rfft(expected, dim={"time": "frequency"}), + expected.sizes["time"], + dim={"frequency": "time"}, + ) + assert np.allclose(result.values, expected.values) + for name in result.coords: + if name == "time": + ref = expected["time"].values + ref = (ref - ref[0]) / np.timedelta64(1, "s") + ref += result["time"][0].values + assert np.allclose(result["time"].values, ref) + else: + assert result[name].equals(expected[name]) diff --git a/xdas/fft.py b/xdas/fft.py index 443b1f8..704ab65 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -63,7 +63,7 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): @atomized -def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): +def rfft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): """ Compute the discrete Fourier Transform for real inputs along a given dimension. @@ -138,7 +138,7 @@ def ifft(da, n=None, dim={"last": "time"}, norm=None, parallel=None): by `dim` is used. dim: {str: str}, optional A mapping indicating as a key the dimension along which to compute the IFFT, and - as value the new name of the dimension. Default to {"last": "spectrum"}. + as value the new name of the dimension. Default to {"last": "time"}. norm: {“backward”, “ortho”, “forward”}, optional Normalization mode (see `numpy.fft`). Default is "backward". Indicates which direction of the forward/backward pair of transforms is scaled and with what @@ -173,3 +173,59 @@ def ifft(da, n=None, dim={"last": "time"}, norm=None, parallel=None): } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) + + +@atomized +def irfft(da, n=None, dim={"last": "time"}, norm=None, parallel=None): + """ + Compute the discrete Fourier Transform for real inputs along a given dimension. + + This function computes the one-dimensional n-point discrete Fourier Transform (DFT) + or real-valued inputs with the efficient Fast Fourier Transform (FFT) algorithm. + + Parameters + ---------- + da: DataArray + The data array to process, can be complex. + n: int, optional + Length of transformed dimension of the output. If n is smaller than the length + of the input, the input is cropped. If it is larger, the input is padded with + zeros. If n is not given, the length of the input along the dimension specified + by `dim` is used. + dim: {str: str}, optional + A mapping indicating as a key the dimension along which to compute the FFT, and + as value the new name of the dimension. Default to {"last": "time"}. + norm: {“backward”, “ortho”, “forward”}, optional + Normalization mode (see `numpy.fft`). Default is "backward". Indicates which + direction of the forward/backward pair of transforms is scaled and with what + normalization factor. + + Returns + ------- + DataArray: + The transformed input with an updated dimension name and values. The length of + the transformed dimension is (n/2)+1 if n is even or (n+1)/2 if n is odd. + + Notes + ----- + To perform a multidimensional fourrier transform, repeat this function on the + desired dimensions. + + """ + ((olddim, newdim),) = dim.items() + olddim = da.dims[da.get_axis_num(olddim)] + if n is None: + n = da.sizes[olddim] + axis = da.get_axis_num(olddim) + d = get_sampling_interval(da, olddim) + across = int(axis == 0) + func = parallelize(across, across, parallel)(np.fft.irfft) + f = np.fft.fftshift(np.fft.fftfreq(n, d)) + data = func(da.values, n, axis, norm) + coords = { + newdim if name == olddim else name: f if name == olddim else da.coords[name] + for name in da.coords + if (da[name].dim != olddim or name == olddim) + } + dims = tuple(newdim if dim == olddim else dim for dim in da.dims) + return DataArray(data, coords, dims, da.name, da.attrs) From 205256ec9d4444c2996ddbce113135b095729729 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Mon, 2 Dec 2024 18:05:35 +0100 Subject: [PATCH 13/20] First working stft. --- tests/test_fft.py | 2 ++ tests/test_signal.py | 46 ++++++++++++++++++++++---------------------- xdas/fft.py | 6 ++++-- xdas/spectral.py | 18 ++++++++--------- 4 files changed, 38 insertions(+), 34 deletions(-) diff --git a/tests/test_fft.py b/tests/test_fft.py index 928f155..371c644 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -1,7 +1,9 @@ import numpy as np + import xdas as xd import xdas.fft as xfft + class TestRFFT: def test_with_non_dimensional(self): da = xd.synthetics.wavelet_wavefronts() diff --git a/tests/test_signal.py b/tests/test_signal.py index 6f2a6e6..6ef55b4 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -130,28 +130,6 @@ def test_sosfiltfilt(self): sos = sp.iirfilter(2, 0.5, btype="low", output="sos") xp.sosfiltfilt(sos, da, "time", padtype=None) - -class TestSTFT: - def test_stft(self): - from xdas.spectral import stft - - starttime = np.datetime64("2023-01-01T00:00:00") - endtime = starttime + 9999 * np.timedelta64(10, "ms") - da = xdas.DataArray( - data=np.zeros((10000, 11)), - coords={ - "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, - "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, - }, - ) - result = stft( - da, - nperseg=100, - noverlap=50, - window="hamming", - dim={"time": "frequency"}, - ) - def test_filter(self): da = wavelet_wavefronts() axis = da.get_axis_num("time") @@ -186,4 +164,26 @@ def test_filter(self): dim="time", parallel=False, ) - assert result.equals(expected) \ No newline at end of file + assert result.equals(expected) + + +class TestSTFT: + def test_stft(self): + from xdas.spectral import stft + + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 9999 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.zeros((10000, 11)), + coords={ + "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, + }, + ) + result = stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + ) diff --git a/xdas/fft.py b/xdas/fft.py index 2599ae4..fdf5055 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -55,7 +55,8 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): data = func(da.values) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -110,7 +111,8 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): data = func(da.values, n, axis, norm) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) diff --git a/xdas/spectral.py b/xdas/spectral.py index f834d16..d9b508a 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -1,5 +1,5 @@ import numpy as np -from scipy.fft import fft, fftfreq, rfft, rfftfreq +from scipy.fft import fft, fftfreq, fftshift, rfft, rfftfreq from scipy.signal import get_window from . import DataArray, get_sampling_interval @@ -34,8 +34,8 @@ def stft( if return_onesided: freqs = rfftfreq(nfft, dt) else: - freqs = fftfreq(nfft, dt) - freqs = {"tie_indices": [0, nfft - 1], "tie_values": [freqs[0], freqs[-1]]} + freqs = fftshift(fftfreq(nfft, dt)) + freqs = {"tie_indices": [0, len(freqs) - 1], "tie_values": [freqs[0], freqs[-1]]} def func(x): if nperseg == 1 and noverlap == 0: @@ -61,14 +61,14 @@ def func(x): dt = get_sampling_interval(da, input_dim, cast=False) t0 = da.coords[input_dim].values[0] starttime = t0 + (nperseg / 2) * dt - endtime = t0 + (da.shape[-1] - nperseg / 2) * dt - time = {"tie_indices": [0, da.shape[-1] - 1], "tie_values": [starttime, endtime]} + endtime = t0 + (data.shape[axis] - nperseg / 2) * dt + time = { + "tie_indices": [0, data.shape[axis] - 1], + "tie_values": [starttime, endtime], + } coords = da.coords.copy() coords[input_dim] = time coords[output_dim] = freqs - result = DataArray(data, coords) - - dims = dim + (output_dim,) - return result.transpose(*dims) + return DataArray(data, coords) From 10acb443ecef1c64d615d45719e07bc7e7f4d769 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Mon, 2 Dec 2024 18:20:53 +0100 Subject: [PATCH 14/20] Make stft reachable from xdas.signal + more tests. --- tests/test_atoms.py | 10 +++---- tests/test_signal.py | 67 +++++++++++++++++++++++++++----------------- xdas/signal.py | 1 + xdas/spectral.py | 3 +- 4 files changed, 49 insertions(+), 32 deletions(-) diff --git a/tests/test_atoms.py b/tests/test_atoms.py index 63a8d63..7dc8338 100644 --- a/tests/test_atoms.py +++ b/tests/test_atoms.py @@ -3,7 +3,7 @@ import xdas import xdas as xd -import xdas.signal as xp +import xdas.signal as xs from xdas.atoms import ( DownSample, FIRFilter, @@ -22,8 +22,8 @@ class TestPartialAtom: def test_init(self): Sequential( [ - Partial(xp.taper, dim="time"), - Partial(xp.taper, dim="distance"), + Partial(xs.taper, dim="time"), + Partial(xs.taper, dim="distance"), Partial(np.abs), Partial(np.square), ] @@ -176,7 +176,7 @@ def test_firfilter(self): da = wavelet_wavefronts() chunks = xdas.split(da, 6, "time") taps = sp.firwin(11, 0.4, pass_zero="lowpass") - expected = xp.lfilter(taps, 1.0, da, "time") + expected = xs.lfilter(taps, 1.0, da, "time") expected["time"] -= np.timedelta64(20, "ms") * 5 atom = FIRFilter(11, 10.0, "lowpass", dim="time") result = atom(da) @@ -194,7 +194,7 @@ def test_resample_poly(self): da = wavelet_wavefronts() chunks = xdas.split(da, 6, "time") - expected = xp.resample_poly(da, 5, 2, "time") + expected = xs.resample_poly(da, 5, 2, "time") atom = ResamplePoly(125, maxfactor=10, dim="time") result = atom(da) result_chunked = xdas.concatenate( diff --git a/tests/test_signal.py b/tests/test_signal.py index 6ef55b4..be421a0 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -3,8 +3,8 @@ import xarray as xr import xdas -import xdas.signal as xp -from xdas.synthetics import randn_wavefronts, wavelet_wavefronts +import xdas.signal as xs +from xdas.synthetics import wavelet_wavefronts class TestSignal: @@ -28,8 +28,8 @@ def test_get_sample_spacing(self): }, }, ) - assert xp.get_sampling_interval(da, "time") == 0.008 - assert xp.get_sampling_interval(da, "distance") == 5.0 + assert xs.get_sampling_interval(da, "time") == 0.008 + assert xs.get_sampling_interval(da, "distance") == 5.0 def test_deterend(self): n = 100 @@ -37,7 +37,7 @@ def test_deterend(self): s = d * np.arange(n) da = xr.DataArray(np.arange(n), {"time": s}) da = xdas.DataArray.from_xarray(da) - da = xp.detrend(da) + da = xs.detrend(da) assert np.allclose(da.values, np.zeros(n)) def test_differentiate(self): @@ -46,7 +46,7 @@ def test_differentiate(self): s = (d / 2) + d * np.arange(n) da = xr.DataArray(np.ones(n), {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.differentiate(da, midpoints=True) + da = xs.differentiate(da, midpoints=True) assert np.allclose(da.values, np.zeros(n - 1)) def test_integrate(self): @@ -55,7 +55,7 @@ def test_integrate(self): s = (d / 2) + d * np.arange(n) da = xr.DataArray(np.ones(n), {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.integrate(da, midpoints=True) + da = xs.integrate(da, midpoints=True) assert np.allclose(da.values, da["distance"].values) def test_segment_mean_removal(self): @@ -69,7 +69,7 @@ def test_segment_mean_removal(self): da.loc[{"distance": slice(limits[0], limits[1])}] = 1.0 da.loc[{"distance": slice(limits[1], limits[2])}] = 2.0 da = xdas.DataArray.from_xarray(da) - da = xp.segment_mean_removal(da, limits) + da = xs.segment_mean_removal(da, limits) assert np.allclose(da.values, 0) def test_sliding_window_removal(self): @@ -80,55 +80,55 @@ def test_sliding_window_removal(self): data = np.ones(n) da = xr.DataArray(data, {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.sliding_mean_removal(da, 0.1 * n * d) + da = xs.sliding_mean_removal(da, 0.1 * n * d) assert np.allclose(da.values, 0) def test_medfilt(self): da = wavelet_wavefronts() - result1 = xp.medfilt(da, {"distance": 3}) - result2 = xp.medfilt(da, {"time": 1, "distance": 3}) + result1 = xs.medfilt(da, {"distance": 3}) + result2 = xs.medfilt(da, {"time": 1, "distance": 3}) assert result1.equals(result2) da.data = np.zeros(da.shape) - assert da.equals(xp.medfilt(da, {"time": 7, "distance": 3})) + assert da.equals(xs.medfilt(da, {"time": 7, "distance": 3})) def test_hilbert(self): da = wavelet_wavefronts() - result = xp.hilbert(da, dim="time") + result = xs.hilbert(da, dim="time") assert np.allclose(da.values, np.real(result.values)) def test_resample(self): da = wavelet_wavefronts() - result = xp.resample(da, 100, dim="time", window="hamming", domain="time") + result = xs.resample(da, 100, dim="time", window="hamming", domain="time") assert result.sizes["time"] == 100 def test_resample_poly(self): da = wavelet_wavefronts() - result = xp.resample_poly(da, 2, 5, dim="time") + result = xs.resample_poly(da, 2, 5, dim="time") assert result.sizes["time"] == 120 def test_lfilter(self): da = wavelet_wavefronts() b, a = sp.iirfilter(4, 0.5, btype="low") - result1 = xp.lfilter(b, a, da, "time") - result2, zf = xp.lfilter(b, a, da, "time", zi=...) + result1 = xs.lfilter(b, a, da, "time") + result2, zf = xs.lfilter(b, a, da, "time", zi=...) assert result1.equals(result2) def test_filtfilt(self): da = wavelet_wavefronts() b, a = sp.iirfilter(2, 0.5, btype="low") - xp.filtfilt(b, a, da, "time", padtype=None) + xs.filtfilt(b, a, da, "time", padtype=None) def test_sosfilter(self): da = wavelet_wavefronts() sos = sp.iirfilter(4, 0.5, btype="low", output="sos") - result1 = xp.sosfilt(sos, da, "time") - result2, zf = xp.sosfilt(sos, da, "time", zi=...) + result1 = xs.sosfilt(sos, da, "time") + result2, zf = xs.sosfilt(sos, da, "time", zi=...) assert result1.equals(result2) def test_sosfiltfilt(self): da = wavelet_wavefronts() sos = sp.iirfilter(2, 0.5, btype="low", output="sos") - xp.sosfiltfilt(sos, da, "time", padtype=None) + xs.sosfiltfilt(sos, da, "time", padtype=None) def test_filter(self): da = wavelet_wavefronts() @@ -143,7 +143,7 @@ def test_filter(self): ) data = sp.sosfilt(sos, da.values, axis=axis) expected = da.copy(data=data) - result = xp.filter( + result = xs.filter( da, [5, 10], btype="band", @@ -155,7 +155,7 @@ def test_filter(self): assert result.equals(expected) data = sp.sosfiltfilt(sos, da.values, axis=axis) expected = da.copy(data=data) - result = xp.filter( + result = xs.filter( da, [5, 10], btype="band", @@ -169,8 +169,6 @@ def test_filter(self): class TestSTFT: def test_stft(self): - from xdas.spectral import stft - starttime = np.datetime64("2023-01-01T00:00:00") endtime = starttime + 9999 * np.timedelta64(10, "ms") da = xdas.DataArray( @@ -180,10 +178,27 @@ def test_stft(self): "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, }, ) - result = stft( + xs.stft( da, nperseg=100, noverlap=50, window="hamming", dim={"time": "frequency"}, ) + + def test_signal(self): + fs = 10e3 + N = 1e5 + fc = 3e3 + amp = 2 * np.sqrt(2) + time = np.arange(N) / float(fs) + data = amp * np.sin(2 * np.pi * fc * time) + da = xdas.DataArray( + data=data, + coords={"time": time}, + ) + result = xs.stft( + da, nperseg=1000, noverlap=500, window="hann", dim={"time": "frequency"} + ) + idx = int(np.abs(np.square(result)).mean("time").argmax("frequency").values) + assert result["frequency"][idx].values == fc diff --git a/xdas/signal.py b/xdas/signal.py index 1fac292..a062c23 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -5,6 +5,7 @@ from .core.coordinates import Coordinate, get_sampling_interval from .core.dataarray import DataArray from .parallel import parallelize +from .spectral import stft @atomized diff --git a/xdas/spectral.py b/xdas/spectral.py index d9b508a..55e4089 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -2,7 +2,8 @@ from scipy.fft import fft, fftfreq, fftshift, rfft, rfftfreq from scipy.signal import get_window -from . import DataArray, get_sampling_interval +from .core.coordinates import get_sampling_interval +from .core.dataarray import DataArray def stft( From 8c7ee47755e24a28ee522cf3a45f1b9c09c2b488 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Mon, 2 Dec 2024 18:36:45 +0100 Subject: [PATCH 15/20] Provide dims in stft output construction to avoir errors. --- xdas/spectral.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xdas/spectral.py b/xdas/spectral.py index 55e4089..1aafbcc 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -72,4 +72,6 @@ def func(x): coords[input_dim] = time coords[output_dim] = freqs - return DataArray(data, coords) + dims = da.dims + (output_dim,) + + return DataArray(data, coords, dims) From ece415ed88e058bd8d35432be33c838ee56dfb3e Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Tue, 3 Dec 2024 10:06:10 +0100 Subject: [PATCH 16/20] Add parallelization to STFT. --- tests/test_signal.py | 28 ++++++++++++++++++++++++++++ xdas/spectral.py | 5 ++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/test_signal.py b/tests/test_signal.py index be421a0..ca9aad5 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -202,3 +202,31 @@ def test_signal(self): ) idx = int(np.abs(np.square(result)).mean("time").argmax("frequency").values) assert result["frequency"][idx].values == fc + + def test_parrallel(self): + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 9999 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.random.rand(10000, 11), + coords={ + "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, + }, + ) + serial = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + parallel=False, + ) + parallel = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + parallel=True, + ) + assert serial.equals(parallel) diff --git a/xdas/spectral.py b/xdas/spectral.py index 1aafbcc..3907d5d 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -4,6 +4,7 @@ from .core.coordinates import get_sampling_interval from .core.dataarray import DataArray +from .parallel import parallelize def stft( @@ -15,7 +16,7 @@ def stft( return_onesided=True, dim={"last": "sprectrum"}, scaling="spectrum", - # parallel=None, + parallel=None, ): if noverlap is None: noverlap = nperseg // 2 @@ -57,6 +58,8 @@ def func(x): result *= scale return result + across = int(axis == 0) + func = parallelize(across, across, parallel)(func) data = func(da.values) dt = get_sampling_interval(da, input_dim, cast=False) From 3b32b8a016eac1f6ae90ed30111ed4bb517967d5 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Tue, 3 Dec 2024 10:16:08 +0100 Subject: [PATCH 17/20] import xdas.signal as xs. --- docs/user-guide/atoms.md | 4 ++-- docs/user-guide/convert-displacement.md | 8 +++---- tests/test_signal.py | 4 ++-- tests/test_xarray.py | 6 ++--- xdas/atoms/core.py | 14 +++++------ xdas/signal.py | 32 ++++++++++++------------- 6 files changed, 34 insertions(+), 34 deletions(-) diff --git a/docs/user-guide/atoms.md b/docs/user-guide/atoms.md index ae233fd..efcb411 100644 --- a/docs/user-guide/atoms.md +++ b/docs/user-guide/atoms.md @@ -23,12 +23,12 @@ There are three "flavours" declaring the atoms that can be used to compose a seq ```{code-cell} import numpy as np import xdas -import xdas.signal as xp +import xdas.signal as xs from xdas.atoms import Partial, Sequential, IIRFilter sequence = Sequential( [ - xp.taper(..., dim="time"), + xs.taper(..., dim="time"), Partial(np.square), IIRFilter(order=4, cutoff=1.5, btype="highpass", dim="time"), ] diff --git a/docs/user-guide/convert-displacement.md b/docs/user-guide/convert-displacement.md index 398f0ed..5ae49b6 100644 --- a/docs/user-guide/convert-displacement.md +++ b/docs/user-guide/convert-displacement.md @@ -28,11 +28,11 @@ strain_rate.plot(yincrease=False, vmin=-0.5, vmax=0.5); Then convert strain rate to deformation and then to displacement. ```{code-cell} -import xdas.signal as xp +import xdas.signal as xs -strain = xp.integrate(strain_rate, dim="time") -deformation = xp.integrate(strain, dim="distance") -displacement = xp.sliding_mean_removal(deformation, wlen=2000.0, dim="distance") +strain = xs.integrate(strain_rate, dim="time") +deformation = xs.integrate(strain, dim="distance") +displacement = xs.sliding_mean_removal(deformation, wlen=2000.0, dim="distance") displacement.plot(yincrease=False, vmin=-0.5, vmax=0.5); ``` diff --git a/tests/test_signal.py b/tests/test_signal.py index 2e0f8d0..7736873 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -171,14 +171,14 @@ def test_filter(self): def test_decimate_virtual_stack(self): da = wavelet_wavefronts() - expected = xp.decimate(da, 5, dim="time") + expected = xs.decimate(da, 5, dim="time") chunks = xdas.split(da, 5, "time") with tempfile.TemporaryDirectory() as tmpdirname: for i, chunk in enumerate(chunks): chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc") chunk.to_netcdf(chunk_path) da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc")) - result = xp.decimate(da_virtual, 5, dim="time") + result = xs.decimate(da_virtual, 5, dim="time") assert result.equals(expected) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 88ccefa..382a4fe 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -1,6 +1,6 @@ import numpy as np -import xdas.core.methods as xp +import xdas.core.methods as xm from xdas.core.dataarray import DataArray from xdas.synthetics import wavelet_wavefronts @@ -8,7 +8,7 @@ class TestXarray: def test_returns_dataarray(self): da = wavelet_wavefronts() - for name, func in xp.HANDLED_METHODS.items(): + for name, func in xm.HANDLED_METHODS.items(): if callable(func): if name in [ "percentile", @@ -31,7 +31,7 @@ def test_returns_dataarray(self): def test_mean(self): da = wavelet_wavefronts() - result = xp.mean(da, "time") + result = xm.mean(da, "time") result_method = da.mean("time") expected = np.mean(da, 0) assert result.equals(expected) diff --git a/xdas/atoms/core.py b/xdas/atoms/core.py index df403d2..3375ccd 100644 --- a/xdas/atoms/core.py +++ b/xdas/atoms/core.py @@ -194,15 +194,15 @@ class Sequential(Atom, list): Examples -------- >>> from xdas.atoms import Partial, Sequential - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> import numpy as np Basic usage: >>> seq = Sequential( ... [ - ... Partial(xp.taper, dim="time"), - ... Partial(xp.lfilter, [1.0], [0.5], ..., dim="time", zi=...), + ... Partial(xs.taper, dim="time"), + ... Partial(xs.lfilter, [1.0], [0.5], ..., dim="time", zi=...), ... Partial(np.square), ... ], ... name="Low frequency energy", @@ -217,7 +217,7 @@ class Sequential(Atom, list): >>> seq = Sequential( ... [ - ... Partial(xp.decimate, 16, dim="distance"), + ... Partial(xs.decimate, 16, dim="distance"), ... seq, ... ] ... ) @@ -330,12 +330,12 @@ class Partial(Atom): -------- >>> import numpy as np >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.atoms import Partial Examples of a stateless atom: - >>> Partial(xp.decimate, 2, dim="time") + >>> Partial(xs.decimate, 2, dim="time") decimate(..., 2, dim=time) >>> Partial(np.square) @@ -344,7 +344,7 @@ class Partial(Atom): Examples of a stateful atom with input data as second argument: >>> sos = sp.iirfilter(4, 0.1, btype="lowpass", output="sos") - >>> Partial(xp.sosfilt, sos, ..., dim="time", zi=...) + >>> Partial(xs.sosfilt, sos, ..., dim="time", zi=...) sosfilt(, ..., dim=time) [stateful] """ diff --git a/xdas/signal.py b/xdas/signal.py index cd892b4..5d15552 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -142,11 +142,11 @@ def hilbert(da, N=None, dim="last", parallel=None): -------- In this example we use the Hilbert transform to determine the analytic signal. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.hilbert(da, dim="time") + >>> xs.hilbert(da, dim="time") [[ 0.0497+0.1632j -0.0635+0.0125j ... 0.1352-0.3107j -0.2832-0.0126j] [-0.1096-0.0335j 0.124 +0.0257j ... -0.0444+0.2409j 0.1378-0.2702j] @@ -203,11 +203,11 @@ def resample(da, num, dim="last", window=None, domain="time", parallel=None): A synthetic dataarray is resample from 300 to 100 samples along the time dimension. The 'hamming' window is used. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.resample(da, 100, dim='time', window='hamming', domain='time') + >>> xs.resample(da, 100, dim='time', window='hamming', domain='time') [[ 0.039988 0.04855 -0.08251 ... 0.02539 -0.055219 -0.006693] [-0.032913 -0.016732 0.033743 ... 0.028534 -0.037685 0.032918] @@ -294,11 +294,11 @@ def resample_poly( with an original shape of 300 in time. The choosed window is a 'hamming' window. The dataarray is synthetic data. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.resample_poly(da, 2, 5, dim='time') + >>> xs.resample_poly(da, 2, 5, dim='time') [[-0.006378 0.012767 -0.002068 ... -0.033461 0.002603 -0.027478] [ 0.008851 -0.037799 0.009595 ... 0.053291 -0.0396 0.026909] @@ -377,12 +377,12 @@ def lfilter(b, a, da, dim="last", zi=None, parallel=None): Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> b, a = sp.iirfilter(4, 0.5, btype="low") - >>> xp.lfilter(b, a, da, dim='time') + >>> xs.lfilter(b, a, da, dim='time') [[ 0.004668 -0.005968 0.007386 ... -0.0138 0.01271 -0.026618] [ 0.008372 -0.01222 0.022552 ... -0.041387 0.046667 -0.093521] @@ -485,12 +485,12 @@ def filtfilt( Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> b, a = sp.iirfilter(4, 0.5, btype="low") - >>> xp.lfilter(b, a, da, dim='time') + >>> xs.lfilter(b, a, da, dim='time') [[ 0.004668 -0.005968 0.007386 ... -0.0138 0.01271 -0.026618] [ 0.008372 -0.01222 0.022552 ... -0.041387 0.046667 -0.093521] @@ -554,12 +554,12 @@ def sosfilt(sos, da, dim="last", zi=None, parallel=None): Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> sos = sp.iirfilter(4, 0.5, btype="low", output="sos") - >>> xp.sosfilt(sos, da, dim='time') + >>> xs.sosfilt(sos, da, dim='time') [[ 0.004668 -0.005968 0.007386 ... -0.0138 0.01271 -0.026618] [ 0.008372 -0.01222 0.022552 ... -0.041387 0.046667 -0.093521] @@ -642,12 +642,12 @@ def sosfiltfilt(sos, da, dim="last", padtype="odd", padlen=None, parallel=None): Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> sos = sp.iirfilter(4, 0.5, btype="low", output="sos") - >>> xp.sosfiltfilt(sos, da, dim='time') + >>> xs.sosfiltfilt(sos, da, dim='time') [[ 0.04968 -0.063651 0.078731 ... -0.146869 0.135149 -0.283111] [-0.01724 0.018588 -0.037267 ... 0.025092 -0.107095 0.127912] @@ -904,11 +904,11 @@ def medfilt(da, kernel_dim): # TODO: parallelize A median filter is applied to some synthetic dataarray with a median window size of 7 along the time dimension and 5 along the space dimension. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.medfilt(da, {"time": 7, "distance": 5}) + >>> xs.medfilt(da, {"time": 7, "distance": 5}) [[ 0. 0. 0. ... 0. 0. 0. ] [ 0. 0. 0. ... 0. 0. 0. ] From 71c2020776631d6cdf6f1a13e1d6305d4ed483de Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Tue, 3 Dec 2024 10:44:53 +0100 Subject: [PATCH 18/20] Compare xdas stft implementation with scipy one. --- tests/test_signal.py | 50 +++++++++++++++++++++++++++++++++++--------- xdas/spectral.py | 12 +++++------ 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/tests/test_signal.py b/tests/test_signal.py index 7736873..4ae1b8a 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -182,25 +182,55 @@ def test_decimate_virtual_stack(self): assert result.equals(expected) - class TestSTFT: - def test_stft(self): + def test_compare_with_scipy(self): starttime = np.datetime64("2023-01-01T00:00:00") endtime = starttime + 9999 * np.timedelta64(10, "ms") da = xdas.DataArray( - data=np.zeros((10000, 11)), + data=np.random.rand(10000, 11), coords={ "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, }, ) - xs.stft( - da, - nperseg=100, - noverlap=50, - window="hamming", - dim={"time": "frequency"}, - ) + for scaling in ["spectrum", "psd"]: + for return_onesided in [True, False]: + for nfft in [None, 128]: + result = xs.stft( + da, + window="hamming", + nperseg=100, + noverlap=50, + nfft=nfft, + return_onesided=return_onesided, + dim={"time": "frequency"}, + scaling=scaling, + ) + f, t, Zxx = sp.stft( + da.values, + fs=1 / xs.get_sampling_interval(da, "time"), + window="hamming", + nperseg=100, + noverlap=50, + nfft=nfft, + return_onesided=return_onesided, + boundary=None, + axis=0, + scaling=scaling, + ) + if return_onesided: + assert np.allclose(result.values, np.transpose(Zxx, (2, 1, 0))) + else: + assert np.allclose( + result.values, + np.fft.fftshift(np.transpose(Zxx, (2, 1, 0)), axes=-1), + ) + assert np.allclose(result["frequency"].values, np.sort(f)) + assert np.allclose( + (result["time"].values - da["time"][0].values) + / np.timedelta64(1, "s"), + t, + ) def test_signal(self): fs = 10e3 diff --git a/xdas/spectral.py b/xdas/spectral.py index 3907d5d..de00544 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -26,12 +26,12 @@ def stft( input_dim, output_dim = next(iter(dim.items())) axis = da.get_axis_num(input_dim) dt = get_sampling_interval(da, input_dim) - if scaling == "density": - scale = 1.0 / ((win * win).sum() / dt) - elif scaling == "spectrum": + if scaling == "spectrum": scale = 1.0 / win.sum() ** 2 + elif scaling == "psd": + scale = 1.0 / ((win * win).sum() / dt) else: - raise ValueError("Scaling must be 'density' or 'spectrum'") + raise ValueError("Scaling must be 'spectrum' or 'psd'") scale = np.sqrt(scale) if return_onesided: freqs = rfftfreq(nfft, dt) @@ -54,7 +54,7 @@ def func(x): if return_onesided: result = rfft(result, n=nfft) else: - result = fft(result, n=nfft) + result = fftshift(fft(result, n=nfft), axes=-1) result *= scale return result @@ -65,7 +65,7 @@ def func(x): dt = get_sampling_interval(da, input_dim, cast=False) t0 = da.coords[input_dim].values[0] starttime = t0 + (nperseg / 2) * dt - endtime = t0 + (data.shape[axis] - nperseg / 2) * dt + endtime = starttime + (data.shape[axis] - 1) * (nperseg - noverlap) * dt time = { "tie_indices": [0, data.shape[axis] - 1], "tie_values": [starttime, endtime], From 146a54d506e14e283911e72b2acd60745349287c Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Tue, 3 Dec 2024 13:41:00 +0100 Subject: [PATCH 19/20] Drop non-dimensional coordinates for now in stft. --- tests/test_signal.py | 38 +++++++++++++++++++++++++++++++++++++- xdas/spectral.py | 10 ++++++++-- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/tests/test_signal.py b/tests/test_signal.py index 4ae1b8a..e14c41a 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -215,6 +215,7 @@ def test_compare_with_scipy(self): nfft=nfft, return_onesided=return_onesided, boundary=None, + padded=False, axis=0, scaling=scaling, ) @@ -231,8 +232,9 @@ def test_compare_with_scipy(self): / np.timedelta64(1, "s"), t, ) + assert result["distance"].equals(da["distance"]) - def test_signal(self): + def test_retrieve_frequency_peak(self): fs = 10e3 N = 1e5 fc = 3e3 @@ -276,3 +278,37 @@ def test_parrallel(self): parallel=True, ) assert serial.equals(parallel) + + def test_last_dimension_with_non_dimensional_coordinates(self): + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 99 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.random.rand(100, 1001), + coords={ + "time": {"tie_indices": [0, 99], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 1000], "tie_values": [0.0, 10_000.0]}, + "channel": ("distance", np.arange(1001)), + }, + ) + result = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"distance": "wavenumber"}, + ) + f, t, Zxx = sp.stft( + da.values, + fs=1 / xs.get_sampling_interval(da, "distance"), + window="hamming", + nperseg=100, + noverlap=50, + boundary=None, + padded=False, + axis=1, + ) + assert np.allclose(result.values, np.transpose(Zxx, (0, 2, 1))) + assert result["time"].equals(da["time"]) + assert np.allclose(result["distance"].values, t) + assert np.allclose(result["wavenumber"].values, np.sort(f)) + assert "channel" not in result.coords # TODO: keep non-dimensional coordinates diff --git a/xdas/spectral.py b/xdas/spectral.py index de00544..a709547 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -71,8 +71,14 @@ def func(x): "tie_values": [starttime, endtime], } - coords = da.coords.copy() - coords[input_dim] = time + coords = {} + for name in da.coords: + if name == input_dim: + coords[input_dim] = time + elif name == output_dim: + coords[output_dim] = freqs + elif da[name].dim != input_dim: # TODO: keep non-dimensional coordinates + coords[name] = da.coords[name] coords[output_dim] = freqs dims = da.dims + (output_dim,) From d85cdc4ef02dee5fd3c7815f51f47c554f3857e1 Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Tue, 3 Dec 2024 13:50:39 +0100 Subject: [PATCH 20/20] Add stft documentation. --- docs/api/signal.md | 11 +++++++++- xdas/spectral.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/docs/api/signal.md b/docs/api/signal.md index 01bfc4e..a3e4530 100644 --- a/docs/api/signal.md +++ b/docs/api/signal.md @@ -34,4 +34,13 @@ sosfilt sosfiltfilt medfilt -``` \ No newline at end of file +``` + +## Spectral analysisi + +```{eval-rst} +.. autosummary:: + :toctree: ../_autosummary + + stft +``` diff --git a/xdas/spectral.py b/xdas/spectral.py index a709547..dc7e6a4 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -18,6 +18,57 @@ def stft( scaling="spectrum", parallel=None, ): + """ + Compute the Short-Time Fourier Transform (STFT) of a data array. + + Parameters + ---------- + da : DataArray + Input data array. + window : str or tuple or array_like, optional + Desired window to use. If a string or tuple, it is passed to + `scipy.signal.get_window` to generate the window values, which are + DFT-even by default. See `scipy.signal.get_window` for a list of + windows and required parameters. If an array, it will be used + directly as the window and its length must be `nperseg`. + nperseg : int, optional + Length of each segment. Defaults to 256. + noverlap : int, optional + Number of points to overlap between segments. If None, `noverlap` + defaults to `nperseg // 2`. Defaults to None. + nfft : int, optional + Length of the FFT used, if a zero padded FFT is desired. If None, + the FFT length is `nperseg`. Defaults to None. + return_onesided : bool, optional + If True, return a one-sided spectrum for real data. If False return + a two-sided spectrum. Defaults to True. + dim : dict, optional + Dictionary specifying the input and output dimensions. Defaults to + {"last": "spectrum"}. + scaling : {'spectrum', 'psd'}, optional + Selects between computing the power spectral density ('psd') where + `scale` is 1 / (sum of window squared) and computing the spectrum + ('spectrum') where `scale` is 1 / (sum of window). Defaults to + 'spectrum'. + parallel : optional + Parallelization option. Defaults to None. + + Returns + ------- + DataArray + STFT of `da`. + + Notes + ----- + The STFT represents a signal in the time-frequency domain by computing + discrete Fourier transforms (DFT) over short overlapping segments of + the signal. + + See Also + -------- + scipy.signal.stft : Compute the Short-Time Fourier Transform (STFT). + + """ if noverlap is None: noverlap = nperseg // 2 if nfft is None: