"""Fast Fourier Transform linear operators.

Defines the Cartesian forward (``FFTOp``), inverse/adjoint (``IFFTOp``) and
Gram (``FFTGramOp``) operators, each with optional k-space masking and
low-rank subspace projection along the contrast axis.
"""
# public API of this module
__all__ = ["FFTOp", "IFFTOp", "FFTGramOp"]
import numpy as np
import torch
from .. import fft as _fft
from . import base
class FFTOp(base.Linop):
    """
    Fast Fourier Transform operator.

    K-space sampling mask, if provided, is expected to have the following dimensions:

    * 2D MRI: ``(ncontrasts, ny, nx)``
    * 3D MRI: ``(ncontrasts, nz, ny)``

    Input images are expected to have the following dimensions:

    * 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
    * 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``

    where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRIT), equal to ``1`` for conventional SENSE
    and ``ncoils`` represents the number of receiver channels in the coil array.
    When ``basis_adjoint`` is provided, the ``ncontrasts`` axis of the input holds
    ``ncoeff`` subspace coefficients instead, which are expanded to ``ncontrasts``
    after the Fourier transform and before masking.

    Similarly, output k-space data are expected to have the following dimensions:

    * 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
    * 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``

    Parameters
    ----------
    mask : np.ndarray | torch.Tensor, optional
        K-space sampling mask, multiplied element-wise with the k-space data.
    basis_adjoint : np.ndarray | torch.Tensor, optional
        Subspace projection matrix of shape ``(ncoeff, ncontrasts)``, applied
        along the coefficient axis after the Fourier transform.
    device : str | torch.device, optional
        Device holding mask, basis and data. Defaults to ``"cpu"``.
    **kwargs
        Forwarded to ``base.Linop``.
    """

    def __init__(self, mask=None, basis_adjoint=None, device=None, **kwargs):
        super().__init__(ndim=2, **kwargs)
        if device is None:
            device = "cpu"
        self.device = device
        self.mask = None if mask is None else torch.as_tensor(mask, device=device)
        self.basis_adjoint = (
            None
            if basis_adjoint is None
            else torch.as_tensor(basis_adjoint, device=device)
        )

    def forward(self, x):
        """
        Apply the Fast Fourier Transform, then subspace projection and masking.

        Parameters
        ----------
        x : np.ndarray | torch.Tensor
            Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D); the ``ncontrasts`` axis
            holds ``ncoeff`` coefficients when ``basis_adjoint`` is set.

        Returns
        -------
        y : np.ndarray | torch.Tensor
            Output k-space of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D). A NumPy array is returned
            when ``x`` is a NumPy array.
        """
        isnumpy = isinstance(x, np.ndarray)

        # convert to tensor and move to the operator device
        x = torch.as_tensor(x)
        if self.device is None:
            self.device = x.device
        x = x.to(self.device)

        # apply Fourier transform over the last two (spatial) axes
        y = _fft.fft(x, axes=(-1, -2), norm="ortho")

        # project subspace coefficients onto contrasts
        if self.basis_adjoint is not None:
            # torch.matmul does not type-promote: align the basis with the data
            # (e.g. a real float64 NumPy basis against complex64 k-space)
            if y.is_complex():
                dtype = y.dtype
            else:
                dtype = torch.promote_types(y.dtype, self.basis_adjoint.dtype)
            basis_adjoint = self.basis_adjoint.to(device=y.device, dtype=dtype)
            y = y.to(dtype)[..., None].swapaxes(-4, -1)  # (..., 1, ny, nx, ncoeff)
            yshape = y.shape
            y = y.reshape(-1, yshape[-1]) @ basis_adjoint  # (npoints, ncontrasts)
            y = y.reshape(*yshape[:-1], y.shape[-1])  # (..., 1, ny, nx, ncontrasts)
            y = y.swapaxes(-4, -1)[..., 0]  # (..., ncontrasts, ny, nx)

        # mask if required
        if self.mask is not None:
            y = self.mask.to(y.device) * y

        # cast back to numpy if required
        if isnumpy:
            y = y.numpy(force=True)
        return y

    def _adjoint_linop(self):
        """Return the adjoint: an ``IFFTOp`` with the conjugate-transposed basis."""
        if self.basis_adjoint is not None:
            basis = self.basis_adjoint.conj().t()  # (ncontrasts, ncoeff)
        else:
            basis = None
        return IFFTOp(self.mask, basis, self.device)
class IFFTOp(base.Linop):
    """
    Inverse Fast Fourier Transform operator.

    K-space sampling mask, if provided, is expected to have the following dimensions:

    * 2D MRI: ``(ncontrasts, ny, nx)``
    * 3D MRI: ``(ncontrasts, nz, ny)``

    Input k-space data are expected to have the following dimensions:

    * 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
    * 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``

    where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRIT), equal to ``1`` for conventional SENSE
    and ``ncoils`` represents the number of receiver channels in the coil array.

    Similarly, output images are expected to have the following dimensions:

    * 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
    * 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``

    When ``basis`` is provided, the ``ncontrasts`` axis of the output holds
    ``ncoeff`` subspace coefficients instead.

    Parameters
    ----------
    mask : np.ndarray | torch.Tensor, optional
        K-space sampling mask, multiplied element-wise with the input k-space.
    basis : np.ndarray | torch.Tensor, optional
        Subspace projection matrix of shape ``(ncontrasts, ncoeff)``, applied
        along the contrast axis before the inverse Fourier transform.
    device : str | torch.device, optional
        Device holding mask, basis and data. Defaults to ``"cpu"``.
    **kwargs
        Forwarded to ``base.Linop``.
    """

    def __init__(self, mask=None, basis=None, device=None, **kwargs):
        super().__init__(ndim=2, **kwargs)
        if device is None:
            device = "cpu"
        self.device = device
        self.mask = None if mask is None else torch.as_tensor(mask, device=device)
        self.basis = None if basis is None else torch.as_tensor(basis, device=device)

    def forward(self, y):
        """
        Apply masking and subspace projection, then the inverse Fast Fourier Transform.

        Parameters
        ----------
        y : np.ndarray | torch.Tensor
            Input k-space of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D).

        Returns
        -------
        x : np.ndarray | torch.Tensor
            Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D); the ``ncontrasts`` axis
            holds ``ncoeff`` coefficients when ``basis`` is set. A NumPy
            array is returned when ``y`` is a NumPy array.
        """
        isnumpy = isinstance(y, np.ndarray)

        # convert to tensor and move to the operator device
        y = torch.as_tensor(y)
        if self.device is None:
            self.device = y.device
        y = y.to(self.device)

        # mask if required
        if self.mask is not None:
            y = self.mask.to(y.device) * y

        # project contrasts onto subspace coefficients
        if self.basis is not None:
            # torch.matmul does not type-promote: align the basis with the data
            if y.is_complex():
                dtype = y.dtype
            else:
                dtype = torch.promote_types(y.dtype, self.basis.dtype)
            basis = self.basis.to(device=y.device, dtype=dtype)
            y = y.to(dtype)[..., None].swapaxes(-4, -1)  # (..., 1, ny, nx, ncontrasts)
            yshape = y.shape
            y = y.reshape(-1, yshape[-1]) @ basis  # (npoints, ncoeff)
            y = y.reshape(*yshape[:-1], y.shape[-1])  # (..., 1, ny, nx, ncoeff)
            y = y.swapaxes(-4, -1)[..., 0]  # (..., ncoeff, ny, nx)

        # apply inverse Fourier transform over the last two (spatial) axes
        x = _fft.ifft(y, axes=(-1, -2), norm="ortho")

        # cast back to numpy if required
        if isnumpy:
            x = x.numpy(force=True)
        return x

    def _adjoint_linop(self):
        """Return the adjoint: an ``FFTOp`` with the conjugate-transposed basis."""
        if self.basis is not None:
            basis_adjoint = self.basis.conj().t()  # (ncoeff, ncontrasts)
        else:
            basis_adjoint = None
        return FFTOp(self.mask, basis_adjoint, self.device)
class FFTGramOp(base.Linop):
    """
    Self-adjoint Fast Fourier Transform Gram operator (``FFTOp.H * FFTOp``).

    K-space sampling mask, if provided, is expected to have the following dimensions:

    * 2D MRI: ``(ncontrasts, ny, nx)``
    * 3D MRI: ``(ncontrasts, nz, ny)``

    Input and output images are expected to have the following dimensions:

    * 2D MRI: ``(nslices, nsets, ncoils, ncontrasts, ny, nx)``
    * 3D MRI: ``(nsets, ncoils, ncontrasts, nz, ny)``

    where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRIT), equal to ``1`` for conventional SENSE
    and ``ncoils`` represents the number of receiver channels in the coil array.
    When ``basis`` is provided, the ``ncontrasts`` axis of input and output
    holds ``ncoeff`` subspace coefficients instead.

    Parameters
    ----------
    mask : np.ndarray | torch.Tensor, optional
        K-space sampling mask on the centered k-space grid. It is applied once
        (not squared), so this equals ``FFTOp.H * FFTOp`` exactly only for a
        0/1 mask — NOTE(review): confirm masks are binary.
    basis : np.ndarray | torch.Tensor, optional
        Subspace basis of shape ``(ncontrasts, ncoeff)``.
    device : str | torch.device, optional
        Device for the data. If ``None``, precomputed tensors are built on CPU
        and moved to the device of the first input passed to ``forward``.
    **kwargs
        Forwarded to ``base.Linop``.

    Raises
    ------
    ValueError
        If the number of contrasts in ``mask`` and ``basis`` differ.
    """

    def __init__(self, mask=None, basis=None, device=None, **kwargs):
        super().__init__(ndim=2, **kwargs)
        # keep ``None`` so forward() adopts the input device; build on CPU meanwhile
        self.device = device
        if device is None:
            device = "cpu"
        basis = None if basis is None else torch.as_tensor(basis, device=device)
        mask = None if mask is None else torch.as_tensor(mask, device=device)

        self._toeplitz_kern = None  # (nz*ny, ncoeff, ncoeff) space-time kernel
        self._mask = None  # uncentered mask, used when no basis is given

        if mask is None:
            # NOTE(review): without a mask the operator is the identity, which
            # equals B^H B only when the basis has orthonormal columns.
            return

        # The mask lives on the centered k-space grid used by FFTOp, while
        # forward() uses uncentered transforms: ifftshift moves the k-space
        # center to index 0 (fftshift only coincides with it for even sizes).
        if basis is None:
            # without a subspace the Gram operator is F^H M F
            self._mask = torch.fft.ifftshift(mask, dim=(-2, -1))
            return

        T, K = basis.shape
        nt, nz, ny = mask.shape  # or (nt, ny, nx) for 2D
        if nt != T:
            raise ValueError(
                f"mask has {nt} contrasts but basis has {T} rows; they must match"
            )

        # kern[z, y, i, j] = sum_t mask[t, z, y] * conj(basis[t, i]) * basis[t, j]
        # The conjugate makes each K x K kernel Hermitian, so the operator is
        # self-adjoint also for complex bases (no-op for real bases).
        dtype = torch.promote_types(mask.dtype, basis.dtype)
        basis = basis.to(dtype)
        toeplitz_kern = torch.einsum(
            "tzy,ti,tj->zyij", mask.to(dtype), basis.conj(), basis
        )  # (nz, ny, K, K) / (ny, nx, K, K)
        toeplitz_kern = torch.fft.ifftshift(toeplitz_kern, dim=(0, 1))
        self._toeplitz_kern = (
            toeplitz_kern.reshape(-1, K, K).contiguous()
        )  # (nz*ny, K, K) / (ny*nx, K, K)

    def forward(self, x):
        """
        Apply the Gram operator (``FFTOp.H * FFTOp``).

        Parameters
        ----------
        x : np.ndarray | torch.Tensor
            Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D).

        Returns
        -------
        y : np.ndarray | torch.Tensor
            Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
            or ``(..., ncontrasts, nz, ny)`` (3D). A NumPy array is returned
            when ``x`` is a NumPy array.
        """
        isnumpy = isinstance(x, np.ndarray)

        # convert to tensor and move to the operator device
        x = torch.as_tensor(x)
        if self.device is None:
            self.device = x.device
        x = x.to(self.device)

        # uncentered transform: the precomputed kernel / mask are already shifted
        y = _fft.fft(x, axes=(-1, -2), norm="ortho", centered=False)

        if self._toeplitz_kern is not None:
            # cache the kernel on the operator device
            self._toeplitz_kern = self._toeplitz_kern.to(self.device)
            # einsum does not type-promote: align the kernel with the data
            if y.is_complex():
                dtype = y.dtype
            else:
                dtype = torch.promote_types(y.dtype, self._toeplitz_kern.dtype)
            kern = self._toeplitz_kern.to(dtype)

            y = y.to(dtype)[..., None].swapaxes(-4, -1)  # (..., 1, nz, ny, ncoeff)
            yshape = y.shape
            nbatch = int(np.prod(yshape[:-4]))
            y = y.reshape(nbatch, -1, yshape[-1])  # (nbatch, nz*ny, ncoeff)
            # per k-space location b: out[b, j] = sum_i y[b, i] * kern[b, i, j]
            y = torch.einsum("...bi,bij->...bj", y, kern)
            y = y.reshape(yshape)  # (..., 1, nz, ny, ncoeff)
            y = y.swapaxes(-4, -1)[..., 0]  # (..., ncoeff, nz, ny)
        elif self._mask is not None:
            # cache the mask on the operator device
            self._mask = self._mask.to(self.device)
            y = self._mask * y

        # inverse uncentered transform
        x = _fft.ifft(y, axes=(-1, -2), norm="ortho", centered=False)

        # cast back to numpy if required
        if isnumpy:
            x = x.numpy(force=True)
        return x

    def _adjoint_linop(self):
        """The Gram operator is self-adjoint."""
        return self