diff --git a/tests/test_fft.py b/tests/test_fft.py deleted file mode 100644 index 371c644..0000000 --- a/tests/test_fft.py +++ /dev/null @@ -1,11 +0,0 @@ -import numpy as np - -import xdas as xd -import xdas.fft as xfft - - -class TestRFFT: - def test_with_non_dimensional(self): - da = xd.synthetics.wavelet_wavefronts() - da["latitude"] = ("distance", np.arange(da.sizes["distance"])) - xfft.rfft(da) diff --git a/tests/test_signal.py b/tests/test_signal.py index 51f658d..92b40c5 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,6 +1,3 @@ -import os -import tempfile - import numpy as np import scipy.signal as sp import xarray as xr @@ -132,51 +129,3 @@ def test_sosfiltfilt(self): da = wavelet_wavefronts() sos = sp.iirfilter(2, 0.5, btype="low", output="sos") xp.sosfiltfilt(sos, da, "time", padtype=None) - - def test_filter(self): - da = wavelet_wavefronts() - axis = da.get_axis_num("time") - fs = 1 / xdas.get_sampling_interval(da, "time") - sos = sp.butter( - 4, - [5, 10], - "band", - output="sos", - fs=fs, - ) - data = sp.sosfilt(sos, da.values, axis=axis) - expected = da.copy(data=data) - result = xp.filter( - da, - [5, 10], - btype="band", - corners=4, - zerophase=False, - dim="time", - parallel=False, - ) - assert result.equals(expected) - data = sp.sosfiltfilt(sos, da.values, axis=axis) - expected = da.copy(data=data) - result = xp.filter( - da, - [5, 10], - btype="band", - corners=4, - zerophase=True, - dim="time", - parallel=False, - ) - assert result.equals(expected) - - def test_decimate_virtual_stack(self): - da = wavelet_wavefronts() - expected = xp.decimate(da, 5, dim="time") - chunks = xdas.split(da, 5, "time") - with tempfile.TemporaryDirectory() as tmpdirname: - for i, chunk in enumerate(chunks): - chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc") - chunk.to_netcdf(chunk_path) - da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc")) - result = xp.decimate(da_virtual, 5, dim="time") - assert result.equals(expected) diff --git a/xdas/core/routines.py b/xdas/core/routines.py index 11f1a9e..a676e8d 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/fft.py b/xdas/fft.py index fdf5055..3cd3312 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -56,7 +56,6 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords - if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @@ -112,7 +111,6 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] for name in da.coords - if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) diff --git a/xdas/signal.py b/xdas/signal.py index 71078c1..f1cd199 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -102,9 +102,9 @@ def filter(da, freq, btype, corners=4, zerophase=False, dim="last", parallel=Non fs = 1.0 / get_sampling_interval(da, dim) sos = sp.iirfilter(corners, freq, btype=btype, ftype="butter", output="sos", fs=fs) if zerophase: - func = parallelize((None, across), across, parallel)(sp.sosfiltfilt) - else: func = parallelize((None, across), across, parallel)(sp.sosfilt) + else: + func = parallelize((None, across), across, parallel)(sp.sosfiltfilt) data = func(sos, da.values, axis=axis) return da.copy(data=data) @@ -708,15 +708,10 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N """ axis = da.get_axis_num(dim) - dim = da.dims[axis] # TODO: this fist last thing is a bad idea... across = int(axis == 0) func = parallelize(across, across, parallel)(sp.decimate) data = func(da.values, q, n, ftype, axis, zero_phase) - coords = da.coords.copy() - for name in coords: - if coords[name].dim == dim: - coords[name] = coords[name][::q] - return DataArray(data, coords, da.dims, da.name, da.attrs) + return da[{dim: slice(None, None, q)}].copy(data=data) @atomized