diff --git a/tests/test_signal.py b/tests/test_signal.py index 4ae1b8a..e14c41a 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -215,6 +215,7 @@ def test_compare_with_scipy(self): nfft=nfft, return_onesided=return_onesided, boundary=None, + padded=False, axis=0, scaling=scaling, ) @@ -231,8 +232,9 @@ def test_compare_with_scipy(self): / np.timedelta64(1, "s"), t, ) + assert result["distance"].equals(da["distance"]) - def test_signal(self): + def test_retrieve_frequency_peak(self): fs = 10e3 N = 1e5 fc = 3e3 @@ -276,3 +278,37 @@ def test_parrallel(self): parallel=True, ) assert serial.equals(parallel) + + def test_last_dimension_with_non_dimensional_coordinates(self): + starttime = np.datetime64("2023-01-01T00:00:00") + endtime = starttime + 99 * np.timedelta64(10, "ms") + da = xdas.DataArray( + data=np.random.rand(100, 1001), + coords={ + "time": {"tie_indices": [0, 99], "tie_values": [starttime, endtime]}, + "distance": {"tie_indices": [0, 1000], "tie_values": [0.0, 10_000.0]}, + "channel": ("distance", np.arange(1001)), + }, + ) + result = xs.stft( + da, + nperseg=100, + noverlap=50, + window="hamming", + dim={"distance": "wavenumber"}, + ) + f, t, Zxx = sp.stft( + da.values, + fs=1 / xs.get_sampling_interval(da, "distance"), + window="hamming", + nperseg=100, + noverlap=50, + boundary=None, + padded=False, + axis=1, + ) + assert np.allclose(result.values, np.transpose(Zxx, (0, 2, 1))) + assert result["time"].equals(da["time"]) + assert np.allclose(result["distance"].values, t) + assert np.allclose(result["wavenumber"].values, np.sort(f)) + assert "channel" not in result.coords # TODO: keep non-dimensional coordinates diff --git a/xdas/spectral.py b/xdas/spectral.py index de00544..a709547 100644 --- a/xdas/spectral.py +++ b/xdas/spectral.py @@ -71,8 +71,14 @@ def func(x): "tie_values": [starttime, endtime], } - coords = da.coords.copy() - coords[input_dim] = time + coords = {} + for name in da.coords: + if name == input_dim: + coords[input_dim] = time + elif name == output_dim: + coords[output_dim] = freqs + elif da[name].dim != input_dim: # TODO: keep non-dimensional coordinates + coords[name] = da.coords[name] coords[output_dim] = freqs dims = da.dims + (output_dim,)