Skip to content

Commit

Permalink
Revert "Revert "Merge pull request #25 from xdas-dev/fix/decimate-vir…
Browse files Browse the repository at this point in the history
…tual-stack""

This reverts commit 3a3f6c7.
  • Loading branch information
atrabattoni committed Nov 29, 2024
1 parent 47ccac9 commit 16c1125
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 4 deletions.
11 changes: 11 additions & 0 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np

import xdas as xd
import xdas.fft as xfft


class TestRFFT:
def test_with_non_dimensional(self):
da = xd.synthetics.wavelet_wavefronts()
da["latitude"] = ("distance", np.arange(da.sizes["distance"]))
xfft.rfft(da)
51 changes: 51 additions & 0 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import tempfile

import numpy as np
import scipy.signal as sp
import xarray as xr
Expand Down Expand Up @@ -129,3 +132,51 @@ def test_sosfiltfilt(self):
da = wavelet_wavefronts()
sos = sp.iirfilter(2, 0.5, btype="low", output="sos")
xp.sosfiltfilt(sos, da, "time", padtype=None)

def test_filter(self):
da = wavelet_wavefronts()
axis = da.get_axis_num("time")
fs = 1 / xdas.get_sampling_interval(da, "time")
sos = sp.butter(
4,
[5, 10],
"band",
output="sos",
fs=fs,
)
data = sp.sosfilt(sos, da.values, axis=axis)
expected = da.copy(data=data)
result = xp.filter(
da,
[5, 10],
btype="band",
corners=4,
zerophase=False,
dim="time",
parallel=False,
)
assert result.equals(expected)
data = sp.sosfiltfilt(sos, da.values, axis=axis)
expected = da.copy(data=data)
result = xp.filter(
da,
[5, 10],
btype="band",
corners=4,
zerophase=True,
dim="time",
parallel=False,
)
assert result.equals(expected)

def test_decimate_virtual_stack(self):
da = wavelet_wavefronts()
expected = xp.decimate(da, 5, dim="time")
chunks = xdas.split(da, 5, "time")
with tempfile.TemporaryDirectory() as tmpdirname:
for i, chunk in enumerate(chunks):
chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc")
chunk.to_netcdf(chunk_path)
da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc"))
result = xp.decimate(da_virtual, 5, dim="time")
assert result.equals(expected)
2 changes: 1 addition & 1 deletion xdas/core/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs):
-------
fig : plotly.graph_objects.Figure
The timeline
Notes
-----
This function uses the `px.timeline` function from the `plotly.express` library.
Expand Down
2 changes: 2 additions & 0 deletions xdas/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None):
coords = {
newdim if name == olddim else name: f if name == olddim else da.coords[name]
for name in da.coords
if (da[name].dim != olddim or name == olddim)
}
dims = tuple(newdim if dim == olddim else dim for dim in da.dims)
return DataArray(data, coords, dims, da.name, da.attrs)
Expand Down Expand Up @@ -111,6 +112,7 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None):
coords = {
newdim if name == olddim else name: f if name == olddim else da.coords[name]
for name in da.coords
if (da[name].dim != olddim or name == olddim)
}
dims = tuple(newdim if dim == olddim else dim for dim in da.dims)
return DataArray(data, coords, dims, da.name, da.attrs)
11 changes: 8 additions & 3 deletions xdas/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def filter(da, freq, btype, corners=4, zerophase=False, dim="last", parallel=Non
fs = 1.0 / get_sampling_interval(da, dim)
sos = sp.iirfilter(corners, freq, btype=btype, ftype="butter", output="sos", fs=fs)
if zerophase:
func = parallelize((None, across), across, parallel)(sp.sosfilt)
else:
func = parallelize((None, across), across, parallel)(sp.sosfiltfilt)
else:
func = parallelize((None, across), across, parallel)(sp.sosfilt)
data = func(sos, da.values, axis=axis)
return da.copy(data=data)

Expand Down Expand Up @@ -708,10 +708,15 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N
"""
axis = da.get_axis_num(dim)
dim = da.dims[axis] # TODO: this fist last thing is a bad idea...
across = int(axis == 0)
func = parallelize(across, across, parallel)(sp.decimate)
data = func(da.values, q, n, ftype, axis, zero_phase)
return da[{dim: slice(None, None, q)}].copy(data=data)
coords = da.coords.copy()
for name in coords:
if coords[name].dim == dim:
coords[name] = coords[name][::q]
return DataArray(data, coords, da.dims, da.name, da.attrs)


@atomized
Expand Down

0 comments on commit 16c1125

Please sign in to comment.