From 16c112562bbc8415b59dfc241917027ba44793ed Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Fri, 29 Nov 2024 16:11:52 +0100 Subject: [PATCH] Revert "Revert "Merge pull request #25 from xdas-dev/fix/decimate-virtual-stack"" This reverts commit 3a3f6c72de4a03504ae1ac4c39d94fb0d719fb80. --- tests/test_fft.py | 11 ++++++++++ tests/test_signal.py | 51 +++++++++++++++++++++++++++++++++++++++++++ xdas/core/routines.py | 2 +- xdas/fft.py | 2 ++ xdas/signal.py | 11 +++++++--- 5 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 tests/test_fft.py diff --git a/tests/test_fft.py b/tests/test_fft.py new file mode 100644 index 0000000..371c644 --- /dev/null +++ b/tests/test_fft.py @@ -0,0 +1,11 @@ +import numpy as np + +import xdas as xd +import xdas.fft as xfft + + +class TestRFFT: + def test_with_non_dimensional(self): + da = xd.synthetics.wavelet_wavefronts() + da["latitude"] = ("distance", np.arange(da.sizes["distance"])) + xfft.rfft(da) diff --git a/tests/test_signal.py b/tests/test_signal.py index 92b40c5..51f658d 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,3 +1,6 @@ +import os +import tempfile + import numpy as np import scipy.signal as sp import xarray as xr @@ -129,3 +132,51 @@ def test_sosfiltfilt(self): da = wavelet_wavefronts() sos = sp.iirfilter(2, 0.5, btype="low", output="sos") xp.sosfiltfilt(sos, da, "time", padtype=None) + + def test_filter(self): + da = wavelet_wavefronts() + axis = da.get_axis_num("time") + fs = 1 / xdas.get_sampling_interval(da, "time") + sos = sp.butter( + 4, + [5, 10], + "band", + output="sos", + fs=fs, + ) + data = sp.sosfilt(sos, da.values, axis=axis) + expected = da.copy(data=data) + result = xp.filter( + da, + [5, 10], + btype="band", + corners=4, + zerophase=False, + dim="time", + parallel=False, + ) + assert result.equals(expected) + data = sp.sosfiltfilt(sos, da.values, axis=axis) + expected = da.copy(data=data) + result = xp.filter( + da, + [5, 10], + btype="band", + corners=4, + zerophase=True, + dim="time", + parallel=False, + ) + assert result.equals(expected) + + def test_decimate_virtual_stack(self): + da = wavelet_wavefronts() + expected = xp.decimate(da, 5, dim="time") + chunks = xdas.split(da, 5, "time") + with tempfile.TemporaryDirectory() as tmpdirname: + for i, chunk in enumerate(chunks): + chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc") + chunk.to_netcdf(chunk_path) + da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc")) + result = xp.decimate(da_virtual, 5, dim="time") + assert result.equals(expected) diff --git a/xdas/core/routines.py b/xdas/core/routines.py index a676e8d..11f1a9e 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/fft.py b/xdas/fft.py index 3cd3312..fdf5055 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -56,6 +56,7 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -111,6 +112,7 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) diff --git a/xdas/signal.py b/xdas/signal.py index f1cd199..71078c1 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -102,9 +102,9 @@ def filter(da, freq, btype, corners=4, zerophase=False, dim="last", parallel=Non fs = 1.0 / get_sampling_interval(da, dim) sos = sp.iirfilter(corners, freq, btype=btype, ftype="butter", output="sos", fs=fs) if zerophase: - func = parallelize((None, across), across, parallel)(sp.sosfilt) - else: func = parallelize((None, across), across, parallel)(sp.sosfiltfilt) + else: + func = parallelize((None, across), across, parallel)(sp.sosfilt) data = func(sos, da.values, axis=axis) return da.copy(data=data) @@ -708,10 +708,15 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N """ axis = da.get_axis_num(dim) + dim = da.dims[axis] # TODO: this fist last thing is a bad idea... across = int(axis == 0) func = parallelize(across, across, parallel)(sp.decimate) data = func(da.values, q, n, ftype, axis, zero_phase) - return da[{dim: slice(None, None, q)}].copy(data=data) + coords = da.coords.copy() + for name in coords: + if coords[name].dim == dim: + coords[name] = coords[name][::q] + return DataArray(data, coords, da.dims, da.name, da.attrs) @atomized