Skip to content

Commit

Permalink
Add parallelization to STFT.
Browse files Browse the repository at this point in the history
  • Loading branch information
atrabattoni committed Dec 3, 2024
1 parent 8c7ee47 commit ece415e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
28 changes: 28 additions & 0 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,31 @@ def test_signal(self):
)
idx = int(np.abs(np.square(result)).mean("time").argmax("frequency").values)
assert result["frequency"][idx].values == fc

def test_parrallel(self):
starttime = np.datetime64("2023-01-01T00:00:00")
endtime = starttime + 9999 * np.timedelta64(10, "ms")
da = xdas.DataArray(
data=np.random.rand(10000, 11),
coords={
"time": {"tie_indices": [0, 9999], "tie_values": [starttime, endtime]},
"distance": {"tie_indices": [0, 10], "tie_values": [0.0, 1.0]},
},
)
serial = xs.stft(
da,
nperseg=100,
noverlap=50,
window="hamming",
dim={"time": "frequency"},
parallel=False,
)
parallel = xs.stft(
da,
nperseg=100,
noverlap=50,
window="hamming",
dim={"time": "frequency"},
parallel=True,
)
assert serial.equals(parallel)
5 changes: 4 additions & 1 deletion xdas/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .core.coordinates import get_sampling_interval
from .core.dataarray import DataArray
from .parallel import parallelize


def stft(
Expand All @@ -15,7 +16,7 @@ def stft(
return_onesided=True,
dim={"last": "sprectrum"},
scaling="spectrum",
# parallel=None,
parallel=None,
):
if noverlap is None:
noverlap = nperseg // 2
Expand Down Expand Up @@ -57,6 +58,8 @@ def func(x):
result *= scale
return result

across = int(axis == 0)
func = parallelize(across, across, parallel)(func)
data = func(da.values)

dt = get_sampling_interval(da, input_dim, cast=False)
Expand Down

0 comments on commit ece415e

Please sign in to comment.