"""Fast Fourier Transform linear operator."""
__all__ = ["FFTOp", "IFFTOp", "FFTGramOp"]
import numpy as np
import torch
from .. import fft as _fft
from . import base
class FFTOp(base.Linop):
"""
Fast Fourier Transform operator.
    K-space sampling mask, if provided, is expected to have the following dimensions:
* 2D MRI: ``(ncontrasts, ny, nx)``
* 3D MRI: ``(ncontrasts, nz, ny)``
Input images are expected to have the following dimensions:
* 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
* 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``
where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRiT), equal to ``1`` for conventional SENSE,
and ``ncoils`` represents the number of receiver channels in the coil array.
Similarly, output k-space data are expected to have the following dimensions:
* 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
* 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``
"""
    def __init__(self, mask=None, basis_adjoint=None, device=None):
super().__init__(ndim=2)
if device is None:
device = "cpu"
self.device = device
if mask is not None:
self.mask = torch.as_tensor(mask, device=device)
else:
self.mask = None
if basis_adjoint is not None:
self.basis_adjoint = torch.as_tensor(basis_adjoint, device=device)
else:
self.basis_adjoint = None
def forward(self, x):
"""
Apply Sparse Fast Fourier Transform.
Parameters
----------
x : np.ndarray | torch.Tensor
Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
or ``(..., ncontrasts, nz, ny)`` (3D).
Returns
-------
y : np.ndarray | torch.Tensor
            Output k-space of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D).
"""
        isnumpy = isinstance(x, np.ndarray)
# convert to tensor
x = torch.as_tensor(x)
if self.device is None:
self.device = x.device
# cast
x = x.to(self.device)
# get adjoint basis
if self.basis_adjoint is not None:
basis_adjoint = self.basis_adjoint
else:
basis_adjoint = None
# apply Fourier transform
y = _fft.fft(x, axes=(-1, -2), norm="ortho")
# project
if basis_adjoint is not None:
y = y[..., None]
y = y.swapaxes(-4, -1)
yshape = list(y.shape)
y = y.reshape(-1, y.shape[-1]) # (prod(y.shape[:-1]), ncoeff)
y = y @ basis_adjoint # (prod(y.shape[:-1]), ncontrasts)
y = y.reshape(*yshape[:-1], y.shape[-1])
y = y.swapaxes(-4, -1)
y = y[..., 0]
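        # Note: the reshape/matmul sequence above is a per-voxel change of basis along
        # dim -3; an equivalent einsum-based sketch (assuming ``y`` of shape
        # ``(..., ncoeff, ny, nx)`` and ``basis_adjoint`` of shape ``(ncoeff, ncontrasts)``):
        #
        #   y = torch.einsum("...kyx,kt->...tyx", y, basis_adjoint)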
# mask if required
if self.mask is not None:
y = self.mask * y
# cast back to numpy if required
if isnumpy:
y = y.numpy(force=True)
return y
def _adjoint_linop(self):
# get adjoint basis
if self.basis_adjoint is not None:
basis = self.basis_adjoint.conj().t()
else:
basis = None
return IFFTOp(self.mask, basis, self.device)
class IFFTOp(base.Linop):
"""
Inverse Fast Fourier Transform operator.
    K-space sampling mask, if provided, is expected to have the following dimensions:
* 2D MRI: ``(ncontrasts, ny, nx)``
* 3D MRI: ``(ncontrasts, nz, ny)``
Input k-space data are expected to have the following dimensions:
* 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
* 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``
where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRiT), equal to ``1`` for conventional SENSE,
and ``ncoils`` represents the number of receiver channels in the coil array.
Similarly, output images are expected to have the following dimensions:
* 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
* 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``
"""
    def __init__(self, mask=None, basis=None, device=None, **kwargs):
super().__init__(ndim=2, **kwargs)
if device is None:
device = "cpu"
self.device = device
if mask is not None:
self.mask = torch.as_tensor(mask, device=device)
else:
self.mask = None
if basis is not None:
self.basis = torch.as_tensor(basis, device=device)
else:
self.basis = None
def forward(self, y):
"""
        Apply inverse Fast Fourier Transform.
Parameters
----------
        y : np.ndarray | torch.Tensor
            Input k-space of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D).
Returns
-------
x : np.ndarray | torch.Tensor
Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
or ``(..., ncontrasts, nz, ny)`` (3D).
"""
        isnumpy = isinstance(y, np.ndarray)
# convert to tensor
y = torch.as_tensor(y)
if self.device is None:
self.device = y.device
# cast
y = y.to(self.device)
if self.mask is not None:
self.mask = self.mask.to(self.device)
if self.basis is not None:
self.basis = self.basis.to(self.device)
# mask if required
if self.mask is not None:
y = self.mask * y
# project
if self.basis is not None:
y = y[..., None]
y = y.swapaxes(-4, -1)
yshape = list(y.shape)
            y = y.reshape(-1, y.shape[-1])  # (prod(y.shape[:-1]), ncontrasts)
            y = y @ self.basis  # (prod(y.shape[:-1]), ncoeff)
y = y.reshape(*yshape[:-1], y.shape[-1])
y = y.swapaxes(-4, -1)
y = y[..., 0]
# apply Fourier transform
x = _fft.ifft(y, axes=(-1, -2), norm="ortho")
# cast back to numpy if required
if isnumpy:
x = x.numpy(force=True)
return x
def _adjoint_linop(self):
# get adjoint basis
if self.basis is not None:
basis_adjoint = self.basis.conj().t()
else:
basis_adjoint = None
return FFTOp(self.mask, basis_adjoint, self.device)
class FFTGramOp(base.Linop):
"""
Self-adjoint Sparse Fast Fourier Transform operator.
    K-space sampling mask, if provided, is expected to have the following dimensions:
* 2D MRI: ``(ncontrasts, ny, nx)``
* 3D MRI: ``(ncontrasts, nz, ny)``
Input and output images are expected to have the following dimensions:
* 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
* 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``
where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRiT), equal to ``1`` for conventional SENSE,
and ``ncoils`` represents the number of receiver channels in the coil array.
"""
    def __init__(self, mask=None, basis=None, device=None, **kwargs):
super().__init__(ndim=2, **kwargs)
self.device = device
if device is None:
device = "cpu"
if basis is not None:
basis = torch.as_tensor(basis, device=device)
else:
basis = None
if mask is not None:
mask = torch.as_tensor(mask, device=device)
else:
mask = None
# calculate space-time kernel
if basis is not None and mask is not None:
T, K = basis.shape
nt, nz, ny = mask.shape # or (nt, ny, nx) for 2D
assert nt == T
tmp = mask.permute(2, 1, 0).reshape((ny, nz, T, 1, 1)) * basis.reshape(
(1, 1, nt, 1, K)
) # (ny, nz, T, 1, K) / (nx, ny, T, 1, K)
            toeplitz_kern = (tmp * basis.reshape(1, 1, T, K, 1)).sum(
                dim=2
            )  # (ny, nz, K, K) / (nx, ny, K, K)
            toeplitz_kern = torch.fft.fftshift(
                torch.fft.fftshift(toeplitz_kern, dim=0), dim=1
            )
self._toeplitz_kern = (
toeplitz_kern.swapaxes(0, 1).reshape(-1, K, K).contiguous()
) # (nz*ny, K, K) / (ny*nx, K, K)
else:
self._toeplitz_kern = None
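        # For every spatial location ``r``, the kernel computed above stores the
        # (ncoeff x ncoeff) matrix
        #
        #   kern(r)[k, k'] = sum_t mask(t, r) * basis(t, k) * basis(t, k'),
        #
        # i.e. the sampling mask collapsed onto the subspace; ``forward`` applies it
        # voxel-wise between an uncentered FFT/IFFT pair.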
def forward(self, x):
"""
Apply Toeplitz convolution (``SparseFFT.H * SparseFFT``).
Parameters
----------
x : np.ndarray | torch.Tensor
Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
or ``(..., ncontrasts, nz, ny)`` (3D).
Returns
-------
y : np.ndarray | torch.Tensor
Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
or ``(..., ncontrasts, nz, ny)`` (3D).
"""
        isnumpy = isinstance(x, np.ndarray)
# convert to tensor
x = torch.as_tensor(x)
if self.device is None:
self.device = x.device
# cast
x = x.to(self.device)
if self._toeplitz_kern is not None:
self._toeplitz_kern = self._toeplitz_kern.to(self.device)
# fourier transform
y = _fft.fft(x, axes=(-1, -2), norm="ortho", centered=False)
# project if required
if self._toeplitz_kern is not None:
y = y[..., None] # (..., ncoeff, nz, ny, 1) / (..., ncoeff, ny, nx, 1)
y = y.swapaxes(
-4, -1
) # (..., 1, nz, ny, ncoeff) / (..., 1, ny, nx, ncoeff)
yshape = list(y.shape)
y = y.reshape(
int(np.prod(yshape[:-4])), -1, y.shape[-1]
) # (prod(y.shape[:-4]), nz*ny, ncoeff) / (prod(y.shape[:-4]), ny*nx, ncoeff)
y = torch.einsum("...bi,bij->...bj", y, self._toeplitz_kern)
y = y.reshape(
*yshape
) # (..., 1, nz, ny, ncoeff) / # (..., 1, ny, nx, ncoeff)
y = y.swapaxes(
-4, -1
) # (..., ncoeff, nz, ny, 1) / # (..., ncoeff, ny, nx, 1)
y = y[..., 0] # (..., ncoeff, nz, ny) / # (..., ncoeff, ny, nx)
# apply Fourier transform
x = _fft.ifft(y, axes=(-1, -2), norm="ortho", centered=False)
# cast back to numpy if required
if isnumpy:
x = x.numpy(force=True)
return x
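    # Consistency sketch (illustrative only; assumes a binary mask and a real-valued
    # basis): the single-pass Toeplitz path above should match the explicit normal
    # operator built from this module's FFT/IFFT pair up to numerical precision:
    #
    #   F = FFTOp(mask=mask, basis_adjoint=basis.t())
    #   FH = IFFTOp(mask=mask, basis=basis)
    #   y_ref = FH.forward(F.forward(x))  # ~ G.forward(x)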
def _adjoint_linop(self):
return self