diff --git a/docs/api/fft.md b/docs/api/fft.md index f6b0567..066e696 100644 --- a/docs/api/fft.md +++ b/docs/api/fft.md @@ -11,5 +11,7 @@ :toctree: ../_autosummary fft + ifft rfft + irfft ``` \ No newline at end of file diff --git a/docs/api/signal.md b/docs/api/signal.md index 01bfc4e..a3e4530 100644 --- a/docs/api/signal.md +++ b/docs/api/signal.md @@ -34,4 +34,13 @@ sosfilt sosfiltfilt medfilt -``` \ No newline at end of file +``` + +## Spectral analysisi + +```{eval-rst} +.. autosummary:: + :toctree: ../_autosummary + + stft +``` diff --git a/docs/user-guide/atoms.md b/docs/user-guide/atoms.md index ae233fd..efcb411 100644 --- a/docs/user-guide/atoms.md +++ b/docs/user-guide/atoms.md @@ -23,12 +23,12 @@ There are three "flavours" declaring the atoms that can be used to compose a seq ```{code-cell} import numpy as np import xdas -import xdas.signal as xp +import xdas.signal as xs from xdas.atoms import Partial, Sequential, IIRFilter sequence = Sequential( [ - xp.taper(..., dim="time"), + xs.taper(..., dim="time"), Partial(np.square), IIRFilter(order=4, cutoff=1.5, btype="highpass", dim="time"), ] diff --git a/docs/user-guide/convert-displacement.md b/docs/user-guide/convert-displacement.md index 398f0ed..5ae49b6 100644 --- a/docs/user-guide/convert-displacement.md +++ b/docs/user-guide/convert-displacement.md @@ -28,11 +28,11 @@ strain_rate.plot(yincrease=False, vmin=-0.5, vmax=0.5); Then convert strain rate to deformation and then to displacement. ```{code-cell} -import xdas.signal as xp +import xdas.signal as xs -strain = xp.integrate(strain_rate, dim="time") -deformation = xp.integrate(strain, dim="distance") -displacement = xp.sliding_mean_removal(deformation, wlen=2000.0, dim="distance") +strain = xs.integrate(strain_rate, dim="time") +deformation = xs.integrate(strain, dim="distance") +displacement = xs.sliding_mean_removal(deformation, wlen=2000.0, dim="distance") displacement.plot(yincrease=False, vmin=-0.5, vmax=0.5); ``` diff --git a/tests/test_atoms.py b/tests/test_atoms.py index 63a8d63..7dc8338 100644 --- a/tests/test_atoms.py +++ b/tests/test_atoms.py @@ -3,7 +3,7 @@ import xdas import xdas as xd -import xdas.signal as xp +import xdas.signal as xs from xdas.atoms import ( DownSample, FIRFilter, @@ -22,8 +22,8 @@ class TestPartialAtom: def test_init(self): Sequential( [ - Partial(xp.taper, dim="time"), - Partial(xp.taper, dim="distance"), + Partial(xs.taper, dim="time"), + Partial(xs.taper, dim="distance"), Partial(np.abs), Partial(np.square), ] @@ -176,7 +176,7 @@ def test_firfilter(self): da = wavelet_wavefronts() chunks = xdas.split(da, 6, "time") taps = sp.firwin(11, 0.4, pass_zero="lowpass") - expected = xp.lfilter(taps, 1.0, da, "time") + expected = xs.lfilter(taps, 1.0, da, "time") expected["time"] -= np.timedelta64(20, "ms") * 5 atom = FIRFilter(11, 10.0, "lowpass", dim="time") result = atom(da) @@ -194,7 +194,7 @@ def test_resample_poly(self): da = wavelet_wavefronts() chunks = xdas.split(da, 6, "time") - expected = xp.resample_poly(da, 5, 2, "time") + expected = xs.resample_poly(da, 5, 2, "time") atom = ResamplePoly(125, maxfactor=10, dim="time") result = atom(da) result_chunked = xdas.concatenate( diff --git a/tests/test_dataarray.py b/tests/test_dataarray.py index c3a5c2a..16b828c 100644 --- a/tests/test_dataarray.py +++ b/tests/test_dataarray.py @@ -159,6 +159,20 @@ def test_sel(self): result = da.sel(distance=0, method="nearest", drop=True) assert "distance" not in result.coords + def test_better_error_when_sel_with_overlaps(self): + da = DataArray( + np.arange(80).reshape(20, 4), + { + "time": { + "tie_values": [0.0, 0.5, 0.4, 1.0], + "tie_indices": [0, 9, 10, 19], + }, + "distance": [0.0, 10.0, 20.0, 30.0], + }, + ) + with pytest.raises(ValueError, match="overlaps were found"): + da.sel(time=slice(0.1, 0.6)) + def test_isel(self): da = wavelet_wavefronts() result = da.isel(first=0) diff --git a/tests/test_fft.py b/tests/test_fft.py index 928f155..206e98f 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -1,9 +1,47 @@ import numpy as np + import xdas as xd import xdas.fft as xfft + class TestRFFT: def test_with_non_dimensional(self): da = xd.synthetics.wavelet_wavefronts() da["latitude"] = ("distance", np.arange(da.sizes["distance"])) xfft.rfft(da) + + +class TestIncerse: + def test_standard(self): + expected = xd.synthetics.wavelet_wavefronts() + result = xfft.ifft( + xfft.fft(expected, dim={"time": "frequency"}), + dim={"frequency": "time"}, + ) + assert np.allclose(np.real(result).values, expected.values) + assert np.allclose(np.imag(result).values, 0) + for name in result.coords: + if name == "time": + ref = expected["time"].values + ref = (ref - ref[0]) / np.timedelta64(1, "s") + ref += result["time"][0].values + assert np.allclose(result["time"].values, ref) + else: + assert result[name].equals(expected[name]) + + def test_real(self): + expected = xd.synthetics.wavelet_wavefronts() + result = xfft.irfft( + xfft.rfft(expected, dim={"time": "frequency"}), + expected.sizes["time"], + dim={"frequency": "time"}, + ) + assert np.allclose(result.values, expected.values) + for name in result.coords: + if name == "time": + ref = expected["time"].values + ref = (ref - ref[0]) / np.timedelta64(1, "s") + ref += result["time"][0].values + assert np.allclose(result["time"].values, ref) + else: + assert result[name].equals(expected[name]) diff --git a/tests/test_signal.py b/tests/test_signal.py index 18b561b..e14c41a 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,9 +1,12 @@ +import os +import tempfile + import numpy as np import scipy.signal as sp import xarray as xr import xdas -import xdas.signal as xp +import xdas.signal as xs from xdas.synthetics import wavelet_wavefronts @@ -28,8 +31,8 @@ def test_get_sample_spacing(self): }, }, ) - assert xp.get_sampling_interval(da, "time") == 0.008 - assert xp.get_sampling_interval(da, "distance") == 5.0 + assert xs.get_sampling_interval(da, "time") == 0.008 + assert xs.get_sampling_interval(da, "distance") == 5.0 def test_deterend(self): n = 100 @@ -37,7 +40,7 @@ def test_deterend(self): s = d * np.arange(n) da = xr.DataArray(np.arange(n), {"time": s}) da = xdas.DataArray.from_xarray(da) - da = xp.detrend(da) + da = xs.detrend(da) assert np.allclose(da.values, np.zeros(n)) def test_differentiate(self): @@ -46,7 +49,7 @@ def test_differentiate(self): s = (d / 2) + d * np.arange(n) da = xr.DataArray(np.ones(n), {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.differentiate(da, midpoints=True) + da = xs.differentiate(da, midpoints=True) assert np.allclose(da.values, np.zeros(n - 1)) def test_integrate(self): @@ -55,7 +58,7 @@ def test_integrate(self): s = (d / 2) + d * np.arange(n) da = xr.DataArray(np.ones(n), {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.integrate(da, midpoints=True) + da = xs.integrate(da, midpoints=True) assert np.allclose(da.values, da["distance"].values) def test_segment_mean_removal(self): @@ -69,7 +72,7 @@ def test_segment_mean_removal(self): da.loc[{"distance": slice(limits[0], limits[1])}] = 1.0 da.loc[{"distance": slice(limits[1], limits[2])}] = 2.0 da = xdas.DataArray.from_xarray(da) - da = xp.segment_mean_removal(da, limits) + da = xs.segment_mean_removal(da, limits) assert np.allclose(da.values, 0) def test_sliding_window_removal(self): @@ -80,55 +83,55 @@ def test_sliding_window_removal(self): data = np.ones(n) da = xr.DataArray(data, {"distance": s}) da = xdas.DataArray.from_xarray(da) - da = xp.sliding_mean_removal(da, 0.1 * n * d) + da = xs.sliding_mean_removal(da, 0.1 * n * d) assert np.allclose(da.values, 0) def test_medfilt(self): da = wavelet_wavefronts() - result1 = xp.medfilt(da, {"distance": 3}) - result2 = xp.medfilt(da, {"time": 1, "distance": 3}) + result1 = xs.medfilt(da, {"distance": 3}) + result2 = xs.medfilt(da, {"time": 1, "distance": 3}) assert result1.equals(result2) da.data = np.zeros(da.shape) - assert da.equals(xp.medfilt(da, {"time": 7, "distance": 3})) + assert da.equals(xs.medfilt(da, {"time": 7, "distance": 3})) def test_hilbert(self): da = wavelet_wavefronts() - result = xp.hilbert(da, dim="time") + result = xs.hilbert(da, dim="time") assert np.allclose(da.values, np.real(result.values)) def test_resample(self): da = wavelet_wavefronts() - result = xp.resample(da, 100, dim="time", window="hamming", domain="time") + result = xs.resample(da, 100, dim="time", window="hamming", domain="time") assert result.sizes["time"] == 100 def test_resample_poly(self): da = wavelet_wavefronts() - result = xp.resample_poly(da, 2, 5, dim="time") + result = xs.resample_poly(da, 2, 5, dim="time") assert result.sizes["time"] == 120 def test_lfilter(self): da = wavelet_wavefronts() b, a = sp.iirfilter(4, 0.5, btype="low") - result1 = xp.lfilter(b, a, da, "time") - result2, zf = xp.lfilter(b, a, da, "time", zi=...) + result1 = xs.lfilter(b, a, da, "time") + result2, zf = xs.lfilter(b, a, da, "time", zi=...) assert result1.equals(result2) def test_filtfilt(self): da = wavelet_wavefronts() b, a = sp.iirfilter(2, 0.5, btype="low") - xp.filtfilt(b, a, da, "time", padtype=None) + xs.filtfilt(b, a, da, "time", padtype=None) def test_sosfilter(self): da = wavelet_wavefronts() sos = sp.iirfilter(4, 0.5, btype="low", output="sos") - result1 = xp.sosfilt(sos, da, "time") - result2, zf = xp.sosfilt(sos, da, "time", zi=...) + result1 = xs.sosfilt(sos, da, "time") + result2, zf = xs.sosfilt(sos, da, "time", zi=...) assert result1.equals(result2) def test_sosfiltfilt(self): da = wavelet_wavefronts() sos = sp.iirfilter(2, 0.5, btype="low", output="sos") - xp.sosfiltfilt(sos, da, "time", padtype=None) + xs.sosfiltfilt(sos, da, "time", padtype=None) def test_filter(self): da = wavelet_wavefronts() @@ -143,7 +146,7 @@ def test_filter(self): ) data = sp.sosfilt(sos, da.values, axis=axis) expected = da.copy(data=data) - result = xp.filter( + result = xs.filter( da, [5, 10], btype="band", @@ -155,7 +158,7 @@ def test_filter(self): assert result.equals(expected) data = sp.sosfiltfilt(sos, da.values, axis=axis) expected = da.copy(data=data) - result = xp.filter( + result = xs.filter( da, [5, 10], btype="band", @@ -164,4 +167,148 @@ def test_filter(self): dim="time", parallel=False, ) - assert result.equals(expected) \ No newline at end of file + assert result.equals(expected) + + def test_decimate_virtual_stack(self): + da = wavelet_wavefronts() + expected = xs.decimate(da, 5, dim="time") + chunks = xdas.split(da, 5, "time") + with tempfile.TemporaryDirectory() as tmpdirname: + for i, chunk in enumerate(chunks): + chunk_path = os.path.join(tmpdirname, f"chunk_{i}.nc") + chunk.to_netcdf(chunk_path) + da_virtual = xdas.open_mfdataarray(os.path.join(tmpdirname, "chunk_*.nc")) + result = xs.decimate(da_virtual, 5, dim="time") + assert result.equals(expected) + + +class TestSTFT: + def test_compare_with_scipy(self): + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 9999 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.random.rand(10000, 11), + coords={ + "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, + }, + ) + for scaling in ["spectrum", "psd"]: + for return_onesided in [True, False]: + for nfft in [None, 128]: + result = xs.stft( + da, + window="hamming", + nperseg=100, + noverlap=50, + nfft=nfft, + return_onesided=return_onesided, + dim={"time": "frequency"}, + scaling=scaling, + ) + f, t, Zxx = sp.stft( + da.values, + fs=1 / xs.get_sampling_interval(da, "time"), + window="hamming", + nperseg=100, + noverlap=50, + nfft=nfft, + return_onesided=return_onesided, + boundary=None, + padded=False, + axis=0, + scaling=scaling, + ) + if return_onesided: + assert np.allclose(result.values, np.transpose(Zxx, (2, 1, 0))) + else: + assert np.allclose( + result.values, + np.fft.fftshift(np.transpose(Zxx, (2, 1, 0)), axes=-1), + ) + assert np.allclose(result["frequency"].values, np.sort(f)) + assert np.allclose( + (result["time"].values - da["time"][0].values) + / np.timedelta64(1, "s"), + t, + ) + assert result["distance"].equals(da["distance"]) + + def test_retrieve_frequency_peak(self): + fs = 10e3 + N = 1e5 + fc = 3e3 + amp = 2 * np.sqrt(2) + time = np.arange(N) / float(fs) + data = amp * np.sin(2 * np.pi * fc * time) + da = xdas.DataArray( + data=data, + coords={"time": time}, + ) + result = xs.stft( + da, nperseg=1000, noverlap=500, window="hann", dim={"time": "frequency"} + ) + idx = int(np.abs(np.square(result)).mean("time").argmax("frequency").values) + assert result["frequency"][idx].values == fc + + def test_parrallel(self): + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 9999 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.random.rand(10000, 11), + coords={ + "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, + }, + ) + serial = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + parallel=False, + ) + parallel = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + parallel=True, + ) + assert serial.equals(parallel) + + def test_last_dimension_with_non_dimensional_coordinates(self): + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 99 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.random.rand(100, 1001), + coords={ + "time": {"tie_indices": [0, 99], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 1000], "tie_values": [0.0, 10_000.0]}, + "channel": ("distance", np.arange(1001)), + }, + ) + result = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"distance": "wavenumber"}, + ) + f, t, Zxx = sp.stft( + da.values, + fs=1 / xs.get_sampling_interval(da, "distance"), + window="hamming", + nperseg=100, + noverlap=50, + boundary=None, + padded=False, + axis=1, + ) + assert np.allclose(result.values, np.transpose(Zxx, (0, 2, 1))) + assert result["time"].equals(da["time"]) + assert np.allclose(result["distance"].values, t) + assert np.allclose(result["wavenumber"].values, np.sort(f)) + assert "channel" not in result.coords # TODO: keep non-dimensional coordinates diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 88ccefa..382a4fe 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -1,6 +1,6 @@ import numpy as np -import xdas.core.methods as xp +import xdas.core.methods as xm from xdas.core.dataarray import DataArray from xdas.synthetics import wavelet_wavefronts @@ -8,7 +8,7 @@ class TestXarray: def test_returns_dataarray(self): da = wavelet_wavefronts() - for name, func in xp.HANDLED_METHODS.items(): + for name, func in xm.HANDLED_METHODS.items(): if callable(func): if name in [ "percentile", @@ -31,7 +31,7 @@ def test_returns_dataarray(self): def test_mean(self): da = wavelet_wavefronts() - result = xp.mean(da, "time") + result = xm.mean(da, "time") result_method = da.mean("time") expected = np.mean(da, 0) assert result.equals(expected) diff --git a/xdas/atoms/core.py b/xdas/atoms/core.py index df403d2..3375ccd 100644 --- a/xdas/atoms/core.py +++ b/xdas/atoms/core.py @@ -194,15 +194,15 @@ class Sequential(Atom, list): Examples -------- >>> from xdas.atoms import Partial, Sequential - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> import numpy as np Basic usage: >>> seq = Sequential( ... [ - ... Partial(xp.taper, dim="time"), - ... Partial(xp.lfilter, [1.0], [0.5], ..., dim="time", zi=...), + ... Partial(xs.taper, dim="time"), + ... Partial(xs.lfilter, [1.0], [0.5], ..., dim="time", zi=...), ... Partial(np.square), ... ], ... name="Low frequency energy", @@ -217,7 +217,7 @@ class Sequential(Atom, list): >>> seq = Sequential( ... [ - ... Partial(xp.decimate, 16, dim="distance"), + ... Partial(xs.decimate, 16, dim="distance"), ... seq, ... ] ... ) @@ -330,12 +330,12 @@ class Partial(Atom): -------- >>> import numpy as np >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.atoms import Partial Examples of a stateless atom: - >>> Partial(xp.decimate, 2, dim="time") + >>> Partial(xs.decimate, 2, dim="time") decimate(..., 2, dim=time) >>> Partial(np.square) @@ -344,7 +344,7 @@ class Partial(Atom): Examples of a stateful atom with input data as second argument: >>> sos = sp.iirfilter(4, 0.1, btype="lowpass", output="sos") - >>> Partial(xp.sosfilt, sos, ..., dim="time", zi=...) + >>> Partial(xs.sosfilt, sos, ..., dim="time", zi=...) sosfilt(, ..., dim=time) [stateful] """ diff --git a/xdas/core/coordinates.py b/xdas/core/coordinates.py index 58c5837..57e6405 100644 --- a/xdas/core/coordinates.py +++ b/xdas/core/coordinates.py @@ -769,7 +769,20 @@ def get_indexer(self, value, method=None): value = np.datetime64(value) else: value = np.asarray(value) - return inverse(value, self.tie_indices, self.tie_values, method) + try: + indexer = inverse(value, self.tie_indices, self.tie_values, method) + except ValueError as e: + if str(e) == "fp must be strictly increasing": + raise ValueError( + "overlaps were found in the coordinate. If this is due to some " + "jitter in the tie values, consider smoothing the coordinate by " + "including some tolerance. This can be done by " + "`da[dim] = da[dim].simplify(tolerance)`, or by specifying a " + "tolerance when opening multiple files." + ) + else: + raise e + return indexer def slice_indexer(self, start=None, stop=None, step=None, endpoint=True): if start is not None: diff --git a/xdas/core/routines.py b/xdas/core/routines.py index a676e8d..11f1a9e 100644 --- a/xdas/core/routines.py +++ b/xdas/core/routines.py @@ -967,7 +967,7 @@ def plot_availability(obj, dim="first", **kwargs): ------- fig : plotly.graph_objects.Figure The timeline - + Notes ----- This function uses the `px.timeline` function from the `plotly.express` library. diff --git a/xdas/fft.py b/xdas/fft.py index 2599ae4..704ab65 100644 --- a/xdas/fft.py +++ b/xdas/fft.py @@ -55,14 +55,15 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): data = func(da.values) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) @atomized -def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): +def rfft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None): """ Compute the discrete Fourier Transform for real inputs along a given dimension. @@ -110,7 +111,121 @@ def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None): data = func(da.values, n, axis, norm) coords = { newdim if name == olddim else name: f if name == olddim else da.coords[name] - for name in da.coords if (da[name].dim != olddim or name == olddim) + for name in da.coords + if (da[name].dim != olddim or name == olddim) + } + dims = tuple(newdim if dim == olddim else dim for dim in da.dims) + return DataArray(data, coords, dims, da.name, da.attrs) + + +@atomized +def ifft(da, n=None, dim={"last": "time"}, norm=None, parallel=None): + """ + Compute the inverse discrete Fourier Transform along a given dimension. + + This function computes the inverse of the one-dimensional n-point discrete Fourier + transform computed by fft. In other words, ifft(fft(a)) == a to within numerical + accuracy. + + Parameters + ---------- + da: DataArray + The data array to process, should be complex. + n: int, optional + Length of transformed dimension of the output. If n is smaller than the length + of the input, the input is cropped. If it is larger, the input is padded with + zeros. If n is not given, the length of the input along the dimension specified + by `dim` is used. + dim: {str: str}, optional + A mapping indicating as a key the dimension along which to compute the IFFT, and + as value the new name of the dimension. Default to {"last": "time"}. + norm: {“backward”, “ortho”, “forward”}, optional + Normalization mode (see `numpy.fft`). Default is "backward". Indicates which + direction of the forward/backward pair of transforms is scaled and with what + normalization factor. + + Returns + ------- + DataArray: + The transformed input with an updated dimension name and values. + + Notes + ----- + - To perform a multidimensional inverse fourrier transform, repeat this function on + the desired dimensions. + + """ + ((olddim, newdim),) = dim.items() + olddim = da.dims[da.get_axis_num(olddim)] + if n is None: + n = da.sizes[olddim] + axis = da.get_axis_num(olddim) + d = get_sampling_interval(da, olddim) + f = np.fft.ifftshift(np.fft.fftfreq(n, d)) + func = lambda x: np.fft.ifft(np.fft.ifftshift(x, axis), n, axis, norm) + across = int(axis == 0) + func = parallelize(across, across, parallel)(func) + data = func(da.values) + coords = { + newdim if name == olddim else name: f if name == olddim else da.coords[name] + for name in da.coords + if (da[name].dim != olddim or name == olddim) + } + dims = tuple(newdim if dim == olddim else dim for dim in da.dims) + return DataArray(data, coords, dims, da.name, da.attrs) + + +@atomized +def irfft(da, n=None, dim={"last": "time"}, norm=None, parallel=None): + """ + Compute the discrete Fourier Transform for real inputs along a given dimension. + + This function computes the one-dimensional n-point discrete Fourier Transform (DFT) + or real-valued inputs with the efficient Fast Fourier Transform (FFT) algorithm. + + Parameters + ---------- + da: DataArray + The data array to process, can be complex. + n: int, optional + Length of transformed dimension of the output. If n is smaller than the length + of the input, the input is cropped. If it is larger, the input is padded with + zeros. If n is not given, the length of the input along the dimension specified + by `dim` is used. + dim: {str: str}, optional + A mapping indicating as a key the dimension along which to compute the FFT, and + as value the new name of the dimension. Default to {"last": "time"}. + norm: {“backward”, “ortho”, “forward”}, optional + Normalization mode (see `numpy.fft`). Default is "backward". Indicates which + direction of the forward/backward pair of transforms is scaled and with what + normalization factor. + + Returns + ------- + DataArray: + The transformed input with an updated dimension name and values. The length of + the transformed dimension is (n/2)+1 if n is even or (n+1)/2 if n is odd. + + Notes + ----- + To perform a multidimensional fourrier transform, repeat this function on the + desired dimensions. + + """ + ((olddim, newdim),) = dim.items() + olddim = da.dims[da.get_axis_num(olddim)] + if n is None: + n = da.sizes[olddim] + axis = da.get_axis_num(olddim) + d = get_sampling_interval(da, olddim) + across = int(axis == 0) + func = parallelize(across, across, parallel)(np.fft.irfft) + f = np.fft.fftshift(np.fft.fftfreq(n, d)) + data = func(da.values, n, axis, norm) + coords = { + newdim if name == olddim else name: f if name == olddim else da.coords[name] + for name in da.coords + if (da[name].dim != olddim or name == olddim) } dims = tuple(newdim if dim == olddim else dim for dim in da.dims) return DataArray(data, coords, dims, da.name, da.attrs) diff --git a/xdas/signal.py b/xdas/signal.py index 1fac292..5d15552 100644 --- a/xdas/signal.py +++ b/xdas/signal.py @@ -5,6 +5,7 @@ from .core.coordinates import Coordinate, get_sampling_interval from .core.dataarray import DataArray from .parallel import parallelize +from .spectral import stft @atomized @@ -141,11 +142,11 @@ def hilbert(da, N=None, dim="last", parallel=None): -------- In this example we use the Hilbert transform to determine the analytic signal. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.hilbert(da, dim="time") + >>> xs.hilbert(da, dim="time") [[ 0.0497+0.1632j -0.0635+0.0125j ... 0.1352-0.3107j -0.2832-0.0126j] [-0.1096-0.0335j 0.124 +0.0257j ... -0.0444+0.2409j 0.1378-0.2702j] @@ -202,11 +203,11 @@ def resample(da, num, dim="last", window=None, domain="time", parallel=None): A synthetic dataarray is resample from 300 to 100 samples along the time dimension. The 'hamming' window is used. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.resample(da, 100, dim='time', window='hamming', domain='time') + >>> xs.resample(da, 100, dim='time', window='hamming', domain='time') [[ 0.039988 0.04855 -0.08251 ... 0.02539 -0.055219 -0.006693] [-0.032913 -0.016732 0.033743 ... 0.028534 -0.037685 0.032918] @@ -293,11 +294,11 @@ def resample_poly( with an original shape of 300 in time. The choosed window is a 'hamming' window. The dataarray is synthetic data. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.resample_poly(da, 2, 5, dim='time') + >>> xs.resample_poly(da, 2, 5, dim='time') [[-0.006378 0.012767 -0.002068 ... -0.033461 0.002603 -0.027478] [ 0.008851 -0.037799 0.009595 ... 0.053291 -0.0396 0.026909] @@ -376,12 +377,12 @@ def lfilter(b, a, da, dim="last", zi=None, parallel=None): Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> b, a = sp.iirfilter(4, 0.5, btype="low") - >>> xp.lfilter(b, a, da, dim='time') + >>> xs.lfilter(b, a, da, dim='time') [[ 0.004668 -0.005968 0.007386 ... -0.0138 0.01271 -0.026618] [ 0.008372 -0.01222 0.022552 ... -0.041387 0.046667 -0.093521] @@ -484,12 +485,12 @@ def filtfilt( Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> b, a = sp.iirfilter(4, 0.5, btype="low") - >>> xp.lfilter(b, a, da, dim='time') + >>> xs.lfilter(b, a, da, dim='time') [[ 0.004668 -0.005968 0.007386 ... -0.0138 0.01271 -0.026618] [ 0.008372 -0.01222 0.022552 ... -0.041387 0.046667 -0.093521] @@ -553,12 +554,12 @@ def sosfilt(sos, da, dim="last", zi=None, parallel=None): Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> sos = sp.iirfilter(4, 0.5, btype="low", output="sos") - >>> xp.sosfilt(sos, da, dim='time') + >>> xs.sosfilt(sos, da, dim='time') [[ 0.004668 -0.005968 0.007386 ... -0.0138 0.01271 -0.026618] [ 0.008372 -0.01222 0.022552 ... -0.041387 0.046667 -0.093521] @@ -641,12 +642,12 @@ def sosfiltfilt(sos, da, dim="last", padtype="odd", padlen=None, parallel=None): Examples -------- >>> import scipy.signal as sp - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() >>> sos = sp.iirfilter(4, 0.5, btype="low", output="sos") - >>> xp.sosfiltfilt(sos, da, dim='time') + >>> xs.sosfiltfilt(sos, da, dim='time') [[ 0.04968 -0.063651 0.078731 ... -0.146869 0.135149 -0.283111] [-0.01724 0.018588 -0.037267 ... 0.025092 -0.107095 0.127912] @@ -708,10 +709,15 @@ def decimate(da, q, n=None, ftype="iir", zero_phase=True, dim="last", parallel=N """ axis = da.get_axis_num(dim) + dim = da.dims[axis] # TODO: this fist last thing is a bad idea... across = int(axis == 0) func = parallelize(across, across, parallel)(sp.decimate) data = func(da.values, q, n, ftype, axis, zero_phase) - return da[{dim: slice(None, None, q)}].copy(data=data) + coords = da.coords.copy() + for name in coords: + if coords[name].dim == dim: + coords[name] = coords[name][::q] + return DataArray(data, coords, da.dims, da.name, da.attrs) @atomized @@ -898,11 +904,11 @@ def medfilt(da, kernel_dim): # TODO: parallelize A median filter is applied to some synthetic dataarray with a median window size of 7 along the time dimension and 5 along the space dimension. - >>> import xdas.signal as xp + >>> import xdas.signal as xs >>> from xdas.synthetics import wavelet_wavefronts >>> da = wavelet_wavefronts() - >>> xp.medfilt(da, {"time": 7, "distance": 5}) + >>> xs.medfilt(da, {"time": 7, "distance": 5}) [[ 0. 0. 0. ... 0. 0. 0. ] [ 0. 0. 0. ... 0. 0. 0. ] diff --git a/xdas/spectral.py b/xdas/spectral.py new file mode 100644 index 0000000..dc7e6a4 --- /dev/null +++ b/xdas/spectral.py @@ -0,0 +1,137 @@ +import numpy as np +from scipy.fft import fft, fftfreq, fftshift, rfft, rfftfreq +from scipy.signal import get_window + +from .core.coordinates import get_sampling_interval +from .core.dataarray import DataArray +from .parallel import parallelize + + +def stft( + da, + window="hann", + nperseg=256, + noverlap=None, + nfft=None, + return_onesided=True, + dim={"last": "sprectrum"}, + scaling="spectrum", + parallel=None, +): + """ + Compute the Short-Time Fourier Transform (STFT) of a data array. + + Parameters + ---------- + da : DataArray + Input data array. + window : str or tuple or array_like, optional + Desired window to use. If a string or tuple, it is passed to + `scipy.signal.get_window` to generate the window values, which are + DFT-even by default. See `scipy.signal.get_window` for a list of + windows and required parameters. If an array, it will be used + directly as the window and its length must be `nperseg`. + nperseg : int, optional + Length of each segment. Defaults to 256. + noverlap : int, optional + Number of points to overlap between segments. If None, `noverlap` + defaults to `nperseg // 2`. Defaults to None. + nfft : int, optional + Length of the FFT used, if a zero padded FFT is desired. If None, + the FFT length is `nperseg`. Defaults to None. + return_onesided : bool, optional + If True, return a one-sided spectrum for real data. If False return + a two-sided spectrum. Defaults to True. + dim : dict, optional + Dictionary specifying the input and output dimensions. Defaults to + {"last": "spectrum"}. + scaling : {'spectrum', 'psd'}, optional + Selects between computing the power spectral density ('psd') where + `scale` is 1 / (sum of window squared) and computing the spectrum + ('spectrum') where `scale` is 1 / (sum of window). Defaults to + 'spectrum'. + parallel : optional + Parallelization option. Defaults to None. + + Returns + ------- + DataArray + STFT of `da`. + + Notes + ----- + The STFT represents a signal in the time-frequency domain by computing + discrete Fourier transforms (DFT) over short overlapping segments of + the signal. + + See Also + -------- + scipy.signal.stft : Compute the Short-Time Fourier Transform (STFT). + + """ + if noverlap is None: + noverlap = nperseg // 2 + if nfft is None: + nfft = nperseg + win = get_window(window, nperseg) + input_dim, output_dim = next(iter(dim.items())) + axis = da.get_axis_num(input_dim) + dt = get_sampling_interval(da, input_dim) + if scaling == "spectrum": + scale = 1.0 / win.sum() ** 2 + elif scaling == "psd": + scale = 1.0 / ((win * win).sum() / dt) + else: + raise ValueError("Scaling must be 'spectrum' or 'psd'") + scale = np.sqrt(scale) + if return_onesided: + freqs = rfftfreq(nfft, dt) + else: + freqs = fftshift(fftfreq(nfft, dt)) + freqs = {"tie_indices": [0, len(freqs) - 1], "tie_values": [freqs[0], freqs[-1]]} + + def func(x): + if nperseg == 1 and noverlap == 0: + result = x[..., np.newaxis] + else: + step = nperseg - noverlap + result = np.lib.stride_tricks.sliding_window_view( + x, window_shape=nperseg, axis=axis, writeable=True + ) + slc = [slice(None)] * result.ndim + slc[axis] = slice(None, None, step) + result = result[tuple(slc)] + result = win * result + if return_onesided: + result = rfft(result, n=nfft) + else: + result = fftshift(fft(result, n=nfft), axes=-1) + result *= scale + return result + + across = int(axis == 0) + func = parallelize(across, across, parallel)(func) + data = func(da.values) + + dt = get_sampling_interval(da, input_dim, cast=False) + t0 = da.coords[input_dim].values[0] + starttime = t0 + (nperseg / 2) * dt + endtime = starttime + (data.shape[axis] - 1) * (nperseg - noverlap) * dt + time = { + "tie_indices": [0, data.shape[axis] - 1], + "tie_values": [starttime, endtime], + } + + coords = {} + for name in da.coords: + if name == input_dim: + coords[input_dim] = time + elif name == output_dim: + coords[output_dim] = freqs + elif da[name].dim != input_dim: # TODO: keep non-dimensional coordinates + coords[name] = da.coords[name] + coords[output_dim] = freqs + + dims = da.dims + (output_dim,) + + return DataArray(data, coords, dims)