Skip to content

Commit

Permalink
backup
Browse files Browse the repository at this point in the history
  • Loading branch information
atrabattoni committed Nov 15, 2024
1 parent caf8d1a commit 334f45a
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 2 deletions.
24 changes: 23 additions & 1 deletion tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import xdas
import xdas.signal as xp
from xdas.synthetics import wavelet_wavefronts
from xdas.synthetics import randn_wavefronts, wavelet_wavefronts


class TestSignal:
Expand Down Expand Up @@ -130,6 +130,28 @@ def test_sosfiltfilt(self):
sos = sp.iirfilter(2, 0.5, btype="low", output="sos")
xp.sosfiltfilt(sos, da, "time", padtype=None)


class TestSTFT:
def test_stft(self):
from xdas.spectral import stft

starttime = np.datetime64("2023-01-01T00:00:00")
endtime = starttime + 9999 * np.timedelta64(10, "ms")
da = xdas.DataArray(
data=np.zeros((10000, 11)),
coords={
"time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]},
"distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]},
},
)
result = stft(
da,
nperseg=100,
noverlap=50,
window="hamming",
dim={"time": "frequency"},
)

def test_filter(self):
da = wavelet_wavefronts()
axis = da.get_axis_num("time")
Expand Down
2 changes: 1 addition & 1 deletion xdas/core/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs):
-------
fig : plotly.graph_objects.Figure
The timeline
Notes
-----
This function uses the `px.timeline` function from the `plotly.express` library.
Expand Down
74 changes: 74 additions & 0 deletions xdas/spectral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
from scipy.fft import fft, fftfreq, rfft, rfftfreq
from scipy.signal import get_window

from . import DataArray, get_sampling_interval


def stft(
da,
window="hann",
nperseg=256,
noverlap=None,
nfft=None,
return_onesided=True,
dim={"last": "sprectrum"},
scaling="spectrum",
# parallel=None,
):
if noverlap is None:
noverlap = nperseg // 2
if nfft is None:
nfft = nperseg
win = get_window(window, nperseg)
input_dim, output_dim = next(iter(dim.items()))
axis = da.get_axis_num(input_dim)
dt = get_sampling_interval(da, input_dim)
if scaling == "density":
scale = 1.0 / ((win * win).sum() / dt)
elif scaling == "spectrum":
scale = 1.0 / win.sum() ** 2
else:
raise ValueError("Scaling must be 'density' or 'spectrum'")
scale = np.sqrt(scale)
if return_onesided:
freqs = rfftfreq(nfft, dt)
else:
freqs = fftfreq(nfft, dt)
freqs = {"tie_indices": [0, nfft - 1], "tie_values": [freqs[0], freqs[-1]]}

def func(x):
if nperseg == 1 and noverlap == 0:
result = x[..., np.newaxis]
else:
step = nperseg - noverlap
result = np.lib.stride_tricks.sliding_window_view(
x, window_shape=nperseg, axis=axis, writeable=True
)
slc = [slice(None)] * result.ndim
slc[axis] = slice(None, None, step)
result = result[tuple(slc)]
result = win * result
if return_onesided:
result = rfft(result, n=nfft)
else:
result = fft(result, n=nfft)
result *= scale
return result

data = func(da.values)

dt = get_sampling_interval(da, input_dim, cast=False)
t0 = da.coords[input_dim].values[0]
starttime = t0 + (nperseg / 2) * dt
endtime = t0 + (da.shape[-1] - nperseg / 2) * dt
time = {"tie_indices": [0, da.shape[-1] - 1], "tie_values": [starttime, endtime]}

coords = da.coords.copy()
coords[input_dim] = time
coords[output_dim] = freqs

result = DataArray(data, coords)

dims = dim + (output_dim,)
return result.transpose(*dims)

0 comments on commit 334f45a

Please sign in to comment.