From ece415ed88e058bd8d35432be33c838ee56dfb3e Mon Sep 17 00:00:00 2001 From: Alister Trabattoni Date: Tue, 3 Dec 2024 10:06:10 +0100 Subject: [PATCH] Add parallelization to STFT. --- tests/test_signal.py | 28 ++++++++++++++++++++++++++++ xdas/spectral.py | 5 ++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/test_signal.py b/tests/test_signal.py index be421a0..ca9aad5 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -202,3 +202,31 @@ def test_signal(self): ) idx = int(np.abs(np.square(result)).mean("time").argmax("frequency").values) assert result["frequency"][idx].values == fc + + def test_parrallel(self): + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 9999 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.random.rand(10000, 11), + coords={ + "time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]}, + }, + ) + serial = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + parallel=False, + ) + parallel = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"time": "frequency"}, + parallel=True, + ) + assert serial.equals(parallel) diff --git a/xdas/spectral.py b/xdas/spectral.py index 1aafbcc..3907d5d 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -4,6 +4,7 @@ from .core.coordinates import get_sampling_interval from .core.dataarray import DataArray +from .parallel import parallelize def stft( @@ -15,7 +16,7 @@ def stft( return_onesided=True, dim={"last": "sprectrum"}, scaling="spectrum", - # parallel=None, + parallel=None, ): if noverlap is None: noverlap = nperseg // 2 @@ -57,6 +58,8 @@ def func(x): result *= scale return result + across = int(axis == 0) + func = parallelize(across, across, parallel)(func) data = func(da.values) dt = get_sampling_interval(da, input_dim, cast=False)