Skip to content

Commit

Permalink
Revert "Merge pull request #25 from xdas-dev/fix/decimate-virtual-stack"
Browse files Browse the repository at this point in the history
This reverts commit f727e9c, reversing
changes made to 9220bfc.
  • Loading branch information
atrabattoni committed Nov 29, 2024
1 parent f727e9c commit 3a3f6c7
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 73 deletions.
11 changes: 0 additions & 11 deletions tests/test_fft.py

This file was deleted.

51 changes: 0 additions & 51 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
import tempfile

import numpy as np
import scipy.signal as sp
import xarray as xr
Expand Down Expand Up @@ -132,51 +129,3 @@ def test_sosfiltfilt(self):
da = wavelet_wavefronts()
sos = sp.iirfilter(2, 0.5, btype="low", output="sos")
xp.sosfiltfilt(sos, da, "time", padtype=None)

def test_filter(self):
da = wavelet_wavefronts()
axis = da.get_axis_num("time")
fs = 1 / xdas.get_sampling_interval(da, "time")
sos = sp.butter(
4,
[5, 10],
"band",
output="sos",
fs=fs,
)
data = sp.sosfilt(sos, da.values, axis=axis)
expected = da.copy(data=data)
result = xp.filter(
da,
[5, 10],
btype="band",
corners=4,
zerophase=False,
dim="time",
parallel=False,
)
assert result.equals(expected)
data = sp.sosfiltfilt(sos, da.values, axis=axis)
expected = da.copy(data=data)
result = xp.filter(
da,
[5, 10],
btype="band",
corners=4,
zerophase=True,
dim="time",
parallel=False,
)
assert result.equals(expected)

def test_decimate_virtual_stack(self):
da = wavelet_wavefronts()
expected = xp.decimate(da, 5, dim="time")
chunks = xdas.split(da, 5, "time")
with tempfile.TemporaryDirectory() as tmpdirname:
for i, chunk in enumerate(chunks):
chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc")
chunk.to_netcdf(chunk_path)
da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc"))
result = xp.decimate(da_virtual, 5, dim="time")
assert result.equals(expected)
2 changes: 1 addition & 1 deletion xdas/core/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs):
-------
fig : plotly.graph_objects.Figure
The timeline
Notes
-----
This function uses the `px.timeline` function from the `plotly.express` library.
Expand Down
2 changes: 0 additions & 2 deletions xdas/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None):
coords = {
newdim if name == olddim else name: f if name == olddim else da.coords[name]
for name in da.coords
if (da[name].dim != olddim or name == olddim)
}
dims = tuple(newdim if dim == olddim else dim for dim in da.dims)
return DataArray(data, coords, dims, da.name, da.attrs)
Expand Down Expand Up @@ -112,7 +111,6 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None):
coords = {
newdim if name == olddim else name: f if name == olddim else da.coords[name]
for name in da.coords
if (da[name].dim != olddim or name == olddim)
}
dims = tuple(newdim if dim == olddim else dim for dim in da.dims)
return DataArray(data, coords, dims, da.name, da.attrs)
11 changes: 3 additions & 8 deletions xdas/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def filter(da, freq, btype, corners=4, zerophase=False, dim="last", parallel=Non
fs = 1.0 / get_sampling_interval(da, dim)
sos = sp.iirfilter(corners, freq, btype=btype, ftype="butter", output="sos", fs=fs)
if zerophase:
func = parallelize((None, across), across, parallel)(sp.sosfiltfilt)
else:
func = parallelize((None, across), across, parallel)(sp.sosfilt)
else:
func = parallelize((None, across), across, parallel)(sp.sosfiltfilt)
data = func(sos, da.values, axis=axis)
return da.copy(data=data)

Expand Down Expand Up @@ -708,15 +708,10 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N
"""
axis = da.get_axis_num(dim)
dim = da.dims[axis] # TODO: this fist last thing is a bad idea...
across = int(axis == 0)
func = parallelize(across, across, parallel)(sp.decimate)
data = func(da.values, q, n, ftype, axis, zero_phase)
coords = da.coords.copy()
for name in coords:
if coords[name].dim == dim:
coords[name] = coords[name][::q]
return DataArray(data, coords, da.dims, da.name, da.attrs)
return da[{dim: slice(None, None, q)}].copy(data=data)


@atomized
Expand Down

0 comments on commit 3a3f6c7

Please sign in to comment.