Skip to content

Commit

Permalink
Add irfft with test and doc.
Browse files Browse the repository at this point in the history
  • Loading branch information
atrabattoni committed Dec 2, 2024
1 parent f900984 commit 658d123
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/api/fft.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
fft
ifft
rfft
irfft
```
24 changes: 21 additions & 3 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ def test_with_non_dimensional(self):
xfft.rfft(da)


class TestIFFT:
def test_base(self):
class TestIncerse:
def test_standard(self):
expected = xd.synthetics.wavelet_wavefronts()
result = xfft.ifft(
xfft.fft(expected, dim={"time": "frequency"}), dim={"frequency": "time"}
xfft.fft(expected, dim={"time": "frequency"}),
dim={"frequency": "time"},
)
assert np.allclose(np.real(result).values, expected.values)
assert np.allclose(np.imag(result).values, 0)
Expand All @@ -27,3 +28,20 @@ def test_base(self):
assert np.allclose(result["time"].values, ref)
else:
assert result[name].equals(expected[name])

def test_real(self):
expected = xd.synthetics.wavelet_wavefronts()
result = xfft.irfft(
xfft.rfft(expected, dim={"time": "frequency"}),
expected.sizes["time"],
dim={"frequency": "time"},
)
assert np.allclose(result.values, expected.values)
for name in result.coords:
if name == "time":
ref = expected["time"].values
ref = (ref - ref[0]) / np.timedelta64(1, "s")
ref += result["time"][0].values
assert np.allclose(result["time"].values, ref)
else:
assert result[name].equals(expected[name])
60 changes: 58 additions & 2 deletions xdas/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def fft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None):


@atomized
def rfft(da, n=None, dim={"last": "frequency"}, norm=None, parallel=None):
def rfft(da, n=None, dim={"last": "spectrum"}, norm=None, parallel=None):
"""
Compute the discrete Fourier Transform for real inputs along a given dimension.
Expand Down Expand Up @@ -138,7 +138,7 @@ def ifft(da, n=None, dim={"last": "time"}, norm=None, parallel=None):
by `dim` is used.
dim: {str: str}, optional
A mapping indicating as a key the dimension along which to compute the IFFT, and
as value the new name of the dimension. Default to {"last": "spectrum"}.
as value the new name of the dimension. Default to {"last": "time"}.
norm: {“backward”, “ortho”, “forward”}, optional
Normalization mode (see `numpy.fft`). Default is "backward". Indicates which
direction of the forward/backward pair of transforms is scaled and with what
Expand Down Expand Up @@ -173,3 +173,59 @@ def ifft(da, n=None, dim={"last": "time"}, norm=None, parallel=None):
}
dims = tuple(newdim if dim == olddim else dim for dim in da.dims)
return DataArray(data, coords, dims, da.name, da.attrs)


@atomized
def irfft(da, n=None, dim={"last": "time"}, norm=None, parallel=None):
"""
Compute the discrete Fourier Transform for real inputs along a given dimension.
This function computes the one-dimensional n-point discrete Fourier Transform (DFT)
or real-valued inputs with the efficient Fast Fourier Transform (FFT) algorithm.
Parameters
----------
da: DataArray
The data array to process, can be complex.
n: int, optional
Length of transformed dimension of the output. If n is smaller than the length
of the input, the input is cropped. If it is larger, the input is padded with
zeros. If n is not given, the length of the input along the dimension specified
by `dim` is used.
dim: {str: str}, optional
A mapping indicating as a key the dimension along which to compute the FFT, and
as value the new name of the dimension. Default to {"last": "time"}.
norm: {“backward”, “ortho”, “forward”}, optional
Normalization mode (see `numpy.fft`). Default is "backward". Indicates which
direction of the forward/backward pair of transforms is scaled and with what
normalization factor.
Returns
-------
DataArray:
The transformed input with an updated dimension name and values. The length of
the transformed dimension is (n/2)+1 if n is even or (n+1)/2 if n is odd.
Notes
-----
To perform a multidimensional fourrier transform, repeat this function on the
desired dimensions.
"""
((olddim, newdim),) = dim.items()
olddim = da.dims[da.get_axis_num(olddim)]
if n is None:
n = da.sizes[olddim]
axis = da.get_axis_num(olddim)
d = get_sampling_interval(da, olddim)
across = int(axis == 0)
func = parallelize(across, across, parallel)(np.fft.irfft)
f = np.fft.fftshift(np.fft.fftfreq(n, d))
data = func(da.values, n, axis, norm)
coords = {
newdim if name == olddim else name: f if name == olddim else da.coords[name]
for name in da.coords
if (da[name].dim != olddim or name == olddim)
}
dims = tuple(newdim if dim == olddim else dim for dim in da.dims)
return DataArray(data, coords, dims, da.name, da.attrs)

0 comments on commit 658d123

Please sign in to comment.