Skip to content

Commit

Permalink
Drop non-dimensional coordinates for now in stft.
Browse files Browse the repository at this point in the history
  • Loading branch information
atrabattoni committed Dec 3, 2024
1 parent 71c2020 commit 146a54d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
38 changes: 37 additions & 1 deletion tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def test_compare_with_scipy(self):
nfft=nfft,
return_onesided=return_onesided,
boundary=None,
padded=False,
axis=0,
scaling=scaling,
)
Expand All @@ -231,8 +232,9 @@ def test_compare_with_scipy(self):
/ np.timedelta64(1, "s"),
t,
)
assert result["distance"].equals(da["distance"])

def test_signal(self):
def test_retrieve_frequency_peak(self):
fs = 10e3
N = 1e5
fc = 3e3
Expand Down Expand Up @@ -276,3 +278,37 @@ def test_parrallel(self):
parallel=True,
)
assert serial.equals(parallel)

def test_last_dimension_with_non_dimensional_coordinates(self):
starttime = np.datetime64("2023-01-01T00:00:00")
endtime = starttime + 99 * np.timedelta64(10, "ms")
da = xdas.DataArray(
data=np.random.rand(100, 1001),
coords={
"time": {"tie_indices": [0, 99], "tie_values": [starttime, endtime]},
"distance": {"tie_indices": [0, 1000], "tie_values": [0.0, 10_000.0]},
"channel": ("distance", np.arange(1001)),
},
)
result = xs.stft(
da,
nperseg=100,
noverlap=50,
window="hamming",
dim={"distance": "wavenumber"},
)
f, t, Zxx = sp.stft(
da.values,
fs=1 / xs.get_sampling_interval(da, "distance"),
window="hamming",
nperseg=100,
noverlap=50,
boundary=None,
padded=False,
axis=1,
)
assert np.allclose(result.values, np.transpose(Zxx, (0, 2, 1)))
assert result["time"].equals(da["time"])
assert np.allclose(result["distance"].values, t)
assert np.allclose(result["wavenumber"].values, np.sort(f))
assert "channel" not in result.coords # TODO: keep non-dimensional coordinates
10 changes: 8 additions & 2 deletions xdas/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,14 @@ def func(x):
"tie_values": [starttime, endtime],
}

coords = da.coords.copy()
coords[input_dim] = time
coords = {}
for name in da.coords:
if name == input_dim:
coords[input_dim] = time
elif name == output_dim:
coords[output_dim] = freqs
elif da[name].dim != input_dim: # TODO: keep non-dimensional coordinates
coords[name] = da.coords[name]
coords[output_dim] = freqs

dims = da.dims + (output_dim,)
Expand Down

0 comments on commit 146a54d

Please sign in to comment.