"""NUFFT subroutines."""
__all__ = [
    "plan_nufft",
    "plan_toeplitz_nufft",
    "apply_nufft",
    "apply_nufft_adj",
    "apply_nufft_selfadj",
]
import gc
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.autograd as autograd
from .._signal import resize as _resize
from . import fft as _fft
from . import _interp
from . import toeplitz as _toeplitz
def plan_nufft(coord, shape, width=4, oversamp=1.25, device="cpu"):
    """
    Precompute NUFFT object.
    Parameters
    ----------
    coord : torch.Tensor
        K-space coordinates of shape ``(ncontrasts, nviews, nsamples, ndims)``.
        Coordinates must be normalized between ``(-0.5 * shape, 0.5 * shape)``.
    shape : int | Iterable[int]
        Oversampled grid size of shape ``(ndim,)``.
        If scalar, isotropic matrix is assumed.
    width : int | Iterable[int], optional
        Interpolation kernel full-width of shape ``(ndim,)``.
        If scalar, isotropic kernel is assumed.
        The default is ``3``.
    oversamp : float | Iterable[float], optional
        Grid oversampling factor of shape ``(ndim,)``.
        If scalar, isotropic oversampling is assumed.
        The default is ``1.125``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``cpu``.
    Returns
    -------
    interpolator : NUFFTPlan
        Structure containing sparse interpolator matrix:
        * ndim (``int``): number of spatial dimensions.
        * oversampling (``Iterable[float]``): grid oversampling factor (z, y, x).
        * width (``Iterable[int]``): kernel width (z, y, x).
        * beta (``Iterable[float]``): Kaiser Bessel parameter (z, y, x).
        * os_shape (``Iterable[int]``): oversampled grid shape (z, y, x).
        * shape (``Iterable[int]``): grid shape (z, y, x).
        * interpolator (``Interpolator``): precomputed interpolator object.
        * device (``str``): computational device.
    Notes
    -----
    Non-uniform coordinates axes ordering is assumed to be ``(x, y)`` for 2D signals
    and ``(x, y, z)`` for 3D. Conversely, axes ordering for grid shape, kernel width
    and oversampling factors are assumed to be ``(y, x)`` and ``(z, y, x)``.
    Coordinates tensor shape is ``(ncontrasts, nviews, nsamples, ndim)``. If there are less dimensions
    (e.g., single-shot or single contrast trajectory), assume singleton for the missing ones:
    * ``coord.shape = (nsamples, ndim) -> (1, 1, nsamples, ndim)``
    * ``coord.shape = (nviews, nsamples, ndim) -> (1, nviews, nsamples, ndim)``
    """
    # make sure this is a tensor
    coord = torch.as_tensor(coord)
    # copy coord and switch to cpu
    coord = coord.clone().cpu().to(torch.float32)
    # get parameters
    ndim = coord.shape[-1]
    if np.isscalar(width):
        width = np.asarray([width] * ndim, dtype=np.int16)
    else:
        width = np.asarray(width, dtype=np.int16)
    if np.isscalar(oversamp):
        oversamp = np.asarray([oversamp] * ndim, dtype=np.float32)
    else:
        oversamp = np.asarray(oversamp, dtype=np.float32)
    # calculate Kaiser-Bessel beta parameter
    beta = math.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5
    if np.isscalar(shape):
        shape = np.asarray([shape] * ndim, dtype=np.int16)
    else:
        shape = np.asarray(shape, dtype=np.int16)[-ndim:]
    # check for Cartesian axes
    is_cart = [
        np.allclose(shape[ax] * coord[..., ax], np.round(shape[ax] * coord[..., ax]))
        for ax in range(ndim)
    ]
    is_cart = np.asarray(is_cart[::-1])  # (z, y, x)
    # Cartesian axes have osf = 1.0 and kernel width = 1 (no interpolation)
    oversamp[is_cart] = 1.0
    width[is_cart] = 1
    # get oversampled grid shape
    os_shape = _get_oversamp_shape(shape, oversamp, ndim)
    # rescale trajectory
    coord = _scale_coord(coord, shape[::-1], oversamp[::-1])
    # compute interpolator
    interpolator = _interp.plan_interpolator(coord, os_shape, width, beta, device)
    # transform to tuples
    ndim: int
    oversamp = tuple(oversamp)
    width = tuple(width)
    beta = tuple(beta)
    os_shape = tuple(os_shape)
    shape = tuple(shape)
    return NUFFTPlan(ndim, oversamp, width, beta, os_shape, shape, interpolator, device)
[docs]def plan_toeplitz_nufft(coord, shape, basis=None, dcf=None, width=4, device="cpu"):
    """
    Compute spatio-temporal kernel for fast self-adjoint operation.
    Parameters
    ----------
    coord : torch.Tensor
        K-space coordinates of shape ``(ncontrasts, nviews, nsamples, ndims)``.
        Coordinates must be normalized between ``(-0.5 * shape[i], 0.5 * shape[i])``,
        with ``i = (z, y, x)``.
    shape : int | Iterable[int]
        Oversampled grid size of shape ``(ndim,)``.
        If scalar, isotropic matrix is assumed.
    basis : torch.Tensor, optional
        Low rank subspace projection operator
        of shape ``(ncontrasts, ncoeffs)``; can be ``None``. The default is ``None``.
    dcf : torch.Tensor, optional
        Density compensation function of shape ``(ncontrasts, nviews, nsamples)``.
        The default is a tensor of ``1.0``.
    width : int | Iterable[int], optional
        Interpolation kernel full-width of shape ``(ndim,)``.
        If scalar, isotropic kernel is assumed.
        The default is ``3``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``cpu``.
    Returns
    -------
    toeplitz_kernel : GramMatrix
        Structure containing Toeplitz kernel (i.e., Fourier transform of system tPSF).
    """
    return _toeplitz.plan_toeplitz(coord, shape, basis, dcf, width, device) 
class ApplyNUFFT(autograd.Function):
    @staticmethod
    def forward(image, nufft_plan, basis_adjoint, weight, device, threadsperblock):
        return _apply_nufft(
            image, nufft_plan, basis_adjoint, weight, device, threadsperblock
        )
    @staticmethod
    def setup_context(ctx, inputs, output):
        _, nufft_plan, basis_adjoint, weight, device, threadsperblock = inputs
        ctx.set_materialize_grads(False)
        ctx.nufft_plan = nufft_plan
        ctx.basis_adjoint = basis_adjoint
        ctx.weight = weight
        ctx.device = device
        ctx.threadsperblock = threadsperblock
    @staticmethod
    def backward(ctx, kspace):
        nufft_plan = ctx.nufft_plan
        basis_adjoint = ctx.basis_adjoint
        if basis_adjoint is not None:
            basis = basis_adjoint.conj().t()
        else:
            basis = None
        weight = ctx.weight
        device = ctx.device
        threadsperblock = ctx.threadsperblock
        return (
            _apply_nufft_adj(
                kspace, nufft_plan, basis, weight, device, threadsperblock
            ),
            None,
            None,
            None,
            None,
            None,
        )
def apply_nufft(
    image, nufft_plan, basis_adjoint=None, weight=None, device=None, threadsperblock=128
):
    """
    Apply Non-Uniform Fast Fourier Transform.
    Parameters
    ----------
    image : np.ndarray | torch.Tensor
        Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).
    nufft_plan : NUFFTPlan
        Pre-calculated NUFFT plan coefficients in sparse COO format.
    basis_adjoint : torch.Tensor, optional
        Adjoint low rank subspace projection operator
        of shape ``(ncontrasts, ncoeffs)``; can be ``None``. The default is ``None``.
    weight : np.ndarray | torch.Tensor, optional
        Optional weight for output data samples. Useful to force adjointeness.
        The default is ``None``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``None`` (same as interpolator).
    threadsperblock : int
        CUDA blocks size (for GPU only). The default is ``128``.
    Returns
    -------
    kspace : np.ndarray | torch.Tensor
        Output Non-Cartesian kspace of shape ``(..., ncontrasts, nviews, nsamples)``.
    """
    return ApplyNUFFT.apply(
        image, nufft_plan, basis_adjoint, weight, device, threadsperblock
    )
class ApplyNUFFTAdjoint(autograd.Function):
    @staticmethod
    def forward(kspace, nufft_plan, basis, weight, device, threadsperblock):
        return _apply_nufft_adj(
            kspace, nufft_plan, basis, weight, device, threadsperblock
        )
    @staticmethod
    def setup_context(ctx, inputs, output):
        _, nufft_plan, basis, weight, device, threadsperblock = inputs
        ctx.set_materialize_grads(False)
        ctx.nufft_plan = nufft_plan
        ctx.basis = basis
        ctx.weight = weight
        ctx.device = device
        ctx.threadsperblock = threadsperblock
    @staticmethod
    def backward(ctx, image):
        nufft_plan = ctx.nufft_plan
        basis = ctx.basis
        if basis is not None:
            basis_adjoint = basis.conj().t()
        else:
            basis_adjoint = None
        weight = ctx.weight
        device = ctx.device
        threadsperblock = ctx.threadsperblock
        return (
            _apply_nufft(
                image, nufft_plan, basis_adjoint, weight, device, threadsperblock
            ),
            None,
            None,
            None,
            None,
            None,
        )
def apply_nufft_adj(
    kspace, nufft_plan, basis=None, weight=None, device=None, threadsperblock=128
):
    """
    Apply adjoint Non-Uniform Fast Fourier Transform.
    Parameters
    ----------
    kspace : torch.Tensor
        Input kspace of shape ``(..., ncontrasts, nviews, nsamples)``.
    nufft_plan : NUFFTPlan
        Pre-calculated NUFFT plan coefficients in sparse COO format.
    basis : torch.Tensor, optional
        Low rank subspace projection operator
        of shape ``(ncontrasts, ncoeffs)``; can be ``None``. The default is ``None``.
    weight : np.ndarray | torch.Tensor, optional
        Optional weight for output data samples. Useful to force adjointeness.
        The default is ``None``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``None ``(same as interpolator).
    threadsperblock : int
        CUDA blocks size (for GPU only). The default is ``128``.
    Returns
    -------
    image : torch.Tensor
        Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).
    """
    return ApplyNUFFTAdjoint.apply(
        kspace, nufft_plan, basis, weight, device, threadsperblock
    )
class ApplyNUFFTSelfAdjoint(autograd.Function):
    @staticmethod
    def forward(image, toeplitz_kern, device, threadsperblock):
        return _apply_nufft_selfadj(image, toeplitz_kern, device, threadsperblock)
    @staticmethod
    def setup_context(ctx, inputs, output):
        _, toeplitz_kern, device, threadsperblock = inputs
        ctx.set_materialize_grads(False)
        ctx.toeplitz_kern = toeplitz_kern
        ctx.device = device
        ctx.threadsperblock = threadsperblock
    @staticmethod
    def backward(ctx, image):
        toeplitz_kern = ctx.toeplitz_kern
        device = ctx.device
        threadsperblock = ctx.threadsperblock
        return (
            _apply_nufft_selfadj(image, toeplitz_kern, device, threadsperblock),
            None,
            None,
            None,
        )
[docs]def apply_nufft_selfadj(image, toeplitz_kern, device=None, threadsperblock=128):
    """
    Apply self-adjoint Non-Uniform Fast Fourier Transform via Toeplitz Convolution.
    Parameters
    ----------
    image : torch.Tensor
        Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).
    toeplitz_kern : GramMatrix
        Pre-calculated Toeplitz kernel.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``None ``(same as interpolator).
    threadsperblock : int
        CUDA blocks size (for GPU only). The default is ``128``.
    Returns
    -------
    image : torch.Tensor
        Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).
    """
    return ApplyNUFFTSelfAdjoint.apply(image, toeplitz_kern, device, threadsperblock) 
# %% local utils
@dataclass
class NUFFTPlan:
    ndim: int
    oversamp: tuple
    width: tuple
    beta: tuple
    os_shape: tuple
    shape: tuple
    interpolator: object
    device: str
    def to(self, device):
        """
        Dispatch internal attributes to selected device.
        Parameters
        ----------
        device : str
            Computational device ("cpu" or "cuda:n", with n=0, 1,...nGPUs).
        """
        if device != self.device:
            self.interpolator.to(device)
            self.device = device
        return self
def _get_oversamp_shape(shape, oversamp, ndim):
    return np.ceil(oversamp * shape).astype(np.int16)
def _scale_coord(coord, shape, oversamp):
    ndim = coord.shape[-1]
    output = coord.clone()
    for i in range(-ndim, 0):
        scale = np.ceil(oversamp[i] * shape[i]) / shape[i]
        shift = np.ceil(oversamp[i] * shape[i]) // 2
        output[..., i] *= scale
        output[..., i] += shift
    return output
def _apodize(data_in, ndim, oversamp, width, beta):
    data_out = data_in
    for n in range(1, ndim + 1):
        axis = -n
        if width[axis] != 1:
            i = data_out.shape[axis]
            os_i = np.ceil(oversamp[axis] * i)
            idx = torch.arange(i, dtype=torch.float32, device=data_in.device)
            # Calculate apodization
            apod = (
                beta[axis] ** 2 - (math.pi * width[axis] * (idx - i // 2) / os_i) ** 2
            ) ** 0.5
            apod /= torch.sinh(apod)
            # normalize by DC
            apod = apod / apod[int(i // 2)]
            # avoid NaN
            apod = torch.nan_to_num(apod, nan=1.0)
            # apply to axis
            data_out *= apod.reshape([i] + [1] * (-axis - 1))
    return data_out
def _apply_nufft(image, nufft_plan, basis_adjoint, weight, device, threadsperblock):
    # check if it is numpy
    if isinstance(image, np.ndarray):
        isnumpy = True
    else:
        isnumpy = False
    # convert to tensor if nececessary
    image = torch.as_tensor(image)
    # make sure datatype is correct
    if image.dtype in (torch.float16, torch.float32, torch.float64):
        image = image.to(torch.float32)
    else:
        image = image.to(torch.complex64)
    # handle basis
    if basis_adjoint is not None:
        basis_adjoint = torch.as_tensor(basis_adjoint)
        # make sure datatype is correct
        if basis_adjoint.dtype in (torch.float16, torch.float32, torch.float64):
            basis_adjoint = basis_adjoint.to(torch.float32)
        else:
            basis_adjoint = basis_adjoint.to(torch.complex64)
    # cast tp device is necessary
    if device is not None:
        nufft_plan.to(device)
    # unpack plan
    ndim = nufft_plan.ndim
    oversamp = nufft_plan.oversamp
    width = nufft_plan.width
    beta = nufft_plan.beta
    os_shape = nufft_plan.os_shape
    interpolator = nufft_plan.interpolator
    device = nufft_plan.device
    # copy input to avoid original data modification
    image = image.clone()
    # original device
    odevice = image.device
    # offload to computational device
    image = image.to(device)
    # apodize
    _apodize(image, ndim, oversamp, width, beta)
    # zero-pad
    image = _resize(image, list(image.shape[:-ndim]) + list(os_shape))
    # FFT
    kspace = _fft.fft(image, axes=range(-ndim, 0), norm=None)
    # interpolate
    kspace = _interp.apply_interpolation(
        kspace, interpolator, basis_adjoint, device, threadsperblock
    )
    # apply weight
    if weight is not None:
        weight = torch.as_tensor(weight, dtype=torch.float32, device=kspace.device)
        kspace = weight * kspace
    # bring back to original device
    kspace = kspace.to(odevice)
    image = image.to(odevice)
    # transform back to numpy if required
    if isnumpy:
        kspace = kspace.numpy(force=True)
    # collect garbage
    gc.collect()
    return kspace
def _apply_nufft_adj(kspace, nufft_plan, basis, weight, device, threadsperblock):
    # check if it is numpy
    if isinstance(kspace, np.ndarray):
        isnumpy = True
    else:
        isnumpy = False
    # convert to tensor if nececessary
    kspace = torch.as_tensor(kspace)
    # make sure datatype is correct
    if kspace.dtype in (torch.float16, torch.float32, torch.float64):
        kspace = kspace.to(torch.float32)
    else:
        kspace = kspace.to(torch.complex64)
    # handle basis
    if basis is not None:
        basis = torch.as_tensor(basis)
        # make sure datatype is correct
        if basis.dtype in (torch.float16, torch.float32, torch.float64):
            basis = basis.to(torch.float32)
        else:
            basis = basis.to(torch.complex64)
    # cast to device is necessary
    if device is not None:
        nufft_plan.to(device)
    # unpack plan
    ndim = nufft_plan.ndim
    oversamp = nufft_plan.oversamp
    width = nufft_plan.width
    beta = nufft_plan.beta
    shape = nufft_plan.shape
    interpolator = nufft_plan.interpolator
    device = nufft_plan.device
    # original device
    odevice = kspace.device
    # offload to computational device
    kspace = kspace.to(device)
    # apply weight
    if weight is not None:
        weight = torch.as_tensor(weight, dtype=torch.float32, device=kspace.device)
        kspace = weight * kspace
    # gridding
    kspace = _interp.apply_gridding(
        kspace, interpolator, basis, device, threadsperblock
    )
    # IFFT
    image = _fft.ifft(kspace, axes=range(-ndim, 0), norm=None)
    # crop
    image = _resize(image, list(image.shape[:-ndim]) + list(shape))
    # apodize
    _apodize(image, ndim, oversamp, width, beta)
    # bring back to original device
    kspace = kspace.to(odevice)
    image = image.to(odevice)
    # transform back to numpy if required
    if isnumpy:
        image = image.numpy(force=True)
    # collect garbage
    gc.collect()
    return image
def _apply_nufft_selfadj(image, toeplitz_kern, device, threadsperblock):
    # check if it is numpy
    if isinstance(image, np.ndarray):
        isnumpy = True
    else:
        isnumpy = False
    # convert to tensor if nececessary
    image = torch.as_tensor(image)
    # make sure datatype is correct
    if image.dtype in (torch.float16, torch.float32, torch.float64):
        image = image.to(torch.float32)
    else:
        image = image.to(torch.complex64)
    # cast to device is necessary
    if device is not None:
        toeplitz_kern.to(device)
    # unpack plan
    shape = toeplitz_kern.shape
    ndim = toeplitz_kern.ndim
    device = toeplitz_kern.device
    # original shape
    oshape = image.shape[-ndim:]
    # original device
    odevice = image.device
    # offload to computational device
    image = image.to(device)
    # zero-pad
    image = _resize(image, list(image.shape[:-ndim]) + list(shape))
    # FFT
    kspace = _fft.fft(image, axes=range(-ndim, 0), norm="ortho", centered=False)
    # Toeplitz convolution
    tmp = torch.zeros_like(kspace)
    tmp = _interp.apply_toeplitz(tmp, kspace, toeplitz_kern, device, threadsperblock)
    # IFFT
    image = _fft.ifft(tmp, axes=range(-ndim, 0), norm="ortho", centered=False)
    # crop
    image = _resize(image, list(image.shape[:-ndim]) + list(oshape))
    # bring back to original device
    image = image.to(odevice)
    # transform back to numpy if required
    if isnumpy:
        image = image.numpy(force=True)
    # collect garbage
    gc.collect()
    return image