"""FFT subroutines."""
__all__ = ["fft", "ifft"]
import numpy as np
import torch
def fft(input, axes=None, norm="ortho", centered=True):
    """
    Centered Fast Fourier Transform.

    Adapted from [1].

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Input signal.
    axes : Iterable[int], optional
        Axes over which to compute the FFT.
        If not specified, apply FFT over all the axes.
    norm : str, optional
        FFT normalization, forwarded to ``torch.fft.fftn``
        (e.g. ``"ortho"``, ``"forward"``, ``"backward"``).
        The default is ``"ortho"``.
    centered : bool, optional
        FFT centering. If ``True``, the input is ``ifftshift``-ed before the
        transform and the result is ``fftshift``-ed afterwards, so that the
        zero frequency sits at the center of the transformed axes.
        The default is ``True``.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Output signal. A ``np.ndarray`` if ``input`` was a ``np.ndarray``,
        a ``torch.Tensor`` otherwise.

    Examples
    --------
    >>> import torch
    >>> import deepmr

    First, create test image:

    >>> image = torch.zeros(32, 32, dtype=torch.complex64)
    >>> image[16, 16] = 1.0

    We now perform a 2D FFT:

    >>> kspace = deepmr.fft.fft(image)

    We can visualize the data:

    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots(1, 2)
    >>> im = ax[0].imshow(abs(image))
    >>> ax[0].set_title("Image", color="orangered", fontweight="bold")
    >>> ax[0].axis("off")
    >>> ax[0].set_alpha(0.0)
    >>> fig.colorbar(im, ax=ax[0], shrink=0.5)
    >>> ksp = ax[1].imshow(abs(kspace))
    >>> ax[1].set_title("k-Space", color="orangered", fontweight="bold")
    >>> ax[1].axis("off")
    >>> ax[1].set_alpha(0.0)
    >>> fig.colorbar(ksp, ax=ax[1], shrink=0.5)
    >>> plt.show()

    References
    ----------
    [1] https://github.com/mikgroup/sigpy

    """
    # remember the input container so the result is returned in kind
    isnumpy = isinstance(input, np.ndarray)

    # torch.as_tensor cannot wrap numpy arrays with negative strides
    # (e.g. views from np.flip or x[::-1]); make a positive-stride copy first
    if isnumpy and any(stride < 0 for stride in input.strides):
        input = input.copy()

    # make sure this is a tensor
    input = torch.as_tensor(input)
    ax = _normalize_axes(axes, input.ndim)

    if centered:
        # shift center to index 0, transform, shift zero frequency back to center
        output = torch.fft.fftshift(
            torch.fft.fftn(torch.fft.ifftshift(input, dim=ax), dim=ax, norm=norm),
            dim=ax,
        )
    else:
        output = torch.fft.fftn(input, dim=ax, norm=norm)

    if isnumpy:
        output = np.asarray(output)
    return output
def ifft(input, axes=None, norm="ortho", centered=True):
    """
    Centered inverse Fast Fourier Transform.

    Adapted from [1].

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Input signal.
    axes : Iterable[int], optional
        Axes over which to compute the iFFT.
        If not specified, apply iFFT over all the axes.
    norm : str, optional
        FFT normalization, forwarded to ``torch.fft.ifftn``
        (e.g. ``"ortho"``, ``"forward"``, ``"backward"``).
        The default is ``"ortho"``.
    centered : bool, optional
        FFT centering. If ``True``, the input is ``ifftshift``-ed before the
        transform and the result is ``fftshift``-ed afterwards, so that the
        zero frequency is expected (and returned) at the center of the
        transformed axes. The default is ``True``.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Output signal. A ``np.ndarray`` if ``input`` was a ``np.ndarray``,
        a ``torch.Tensor`` otherwise.

    Examples
    --------
    >>> import torch
    >>> import deepmr

    First, create test k-space:

    >>> kspace = torch.ones(32, 32, dtype=torch.complex64)

    We now perform a 2D iFFT:

    >>> image = deepmr.fft.ifft(kspace)

    We can visualize the data:

    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots(1, 2)
    >>> ksp = ax[0].imshow(abs(kspace))
    >>> ax[0].set_title("k-Space", color="orangered", fontweight="bold")
    >>> ax[0].axis("off")
    >>> ax[0].set_alpha(0.0)
    >>> fig.colorbar(ksp, ax=ax[0], shrink=0.5)
    >>> im = ax[1].imshow(abs(image))
    >>> ax[1].set_title("Image", color="orangered", fontweight="bold")
    >>> ax[1].axis("off")
    >>> ax[1].set_alpha(0.0)
    >>> fig.colorbar(im, ax=ax[1], shrink=0.5)
    >>> plt.show()

    References
    ----------
    [1] https://github.com/mikgroup/sigpy

    """
    # remember the input container so the result is returned in kind
    isnumpy = isinstance(input, np.ndarray)

    # torch.as_tensor cannot wrap numpy arrays with negative strides
    # (e.g. views from np.flip or x[::-1]); make a positive-stride copy first
    if isnumpy and any(stride < 0 for stride in input.strides):
        input = input.copy()

    # make sure this is a tensor
    input = torch.as_tensor(input)
    ax = _normalize_axes(axes, input.ndim)

    if centered:
        # shift center to index 0, transform, shift zero position back to center
        output = torch.fft.fftshift(
            torch.fft.ifftn(torch.fft.ifftshift(input, dim=ax), dim=ax, norm=norm),
            dim=ax,
        )
    else:
        output = torch.fft.ifftn(input, dim=ax, norm=norm)

    if isnumpy:
        output = np.asarray(output)
    return output
# %% local subroutines
def _normalize_axes(axes, ndim):
if axes is None:
return tuple(range(ndim))
else:
return tuple(a % ndim for a in sorted(axes))