# Source code for deepmr.fft.nufft

"""NUFFT subroutines."""

__all__ = [
    "plan_nufft",
    "plan_toeplitz_nufft",
    "apply_nufft",
    "apply_nufft_adj",
    "apply_nufft_selfadj",
]

import gc
import math

from dataclasses import dataclass

import numpy as np
import torch
import torch.autograd as autograd

from .._signal import resize as _resize

from . import fft as _fft
from . import _interp
from . import toeplitz as _toeplitz

def plan_nufft(coord, shape, width=4, oversamp=1.25, device="cpu"):
    """
    Precompute NUFFT object.

    Parameters
    ----------
    coord : torch.Tensor
        K-space coordinates of shape ``(ncontrasts, nviews, nsamples, ndims)``.
        Coordinates must be normalized between ``(-0.5 * shape, 0.5 * shape)``,
        i.e., expressed in units of grid points.
    shape : int | Iterable[int]
        Cartesian grid size (before oversampling) of shape ``(ndim,)``.
        If scalar, isotropic matrix is assumed. If more than ``ndim``
        entries are given, only the last ``ndim`` are used.
    width : int | Iterable[int], optional
        Interpolation kernel full-width of shape ``(ndim,)``.
        If scalar, isotropic kernel is assumed.
        The default is ``4``.
    oversamp : float | Iterable[float], optional
        Grid oversampling factor of shape ``(ndim,)``.
        If scalar, isotropic oversampling is assumed.
        The default is ``1.25``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``cpu``.

    Returns
    -------
    interpolator : NUFFTPlan
        Structure containing sparse interpolator matrix:

        * ndim (``int``): number of spatial dimensions.
        * oversampling (``Iterable[float]``): grid oversampling factor (z, y, x).
        * width (``Iterable[int]``): kernel width (z, y, x).
        * beta (``Iterable[float]``): Kaiser Bessel parameter (z, y, x).
        * os_shape (``Iterable[int]``): oversampled grid shape (z, y, x).
        * shape (``Iterable[int]``): grid shape (z, y, x).
        * interpolator (``Interpolator``): precomputed interpolator object.
        * device (``str``): computational device.

    Notes
    -----
    Non-uniform coordinates axes ordering is assumed to be ``(x, y)`` for 2D signals
    and ``(x, y, z)`` for 3D. Conversely, axes ordering for grid shape, kernel width
    and oversampling factors are assumed to be ``(y, x)`` and ``(z, y, x)``.

    Coordinates tensor shape is ``(ncontrasts, nviews, nsamples, ndim)``. If there are less dimensions
    (e.g., single-shot or single contrast trajectory), assume singleton for the missing ones:

    * ``coord.shape = (nsamples, ndim) -> (1, 1, nsamples, ndim)``
    * ``coord.shape = (nviews, nsamples, ndim) -> (1, nviews, nsamples, ndim)``

    Axes whose coordinates all lie on integer grid positions are treated as
    Cartesian: their oversampling factor is forced to ``1.0`` and their kernel
    width to ``1`` (no interpolation).
    """
    # work on a float32 CPU copy so the caller's tensor is never modified
    coord = torch.as_tensor(coord).clone().cpu().to(torch.float32)

    # get parameters
    ndim = int(coord.shape[-1])

    # broadcast scalars to one entry per axis; np.array (not np.asarray)
    # always copies, so the in-place Cartesian overrides below never
    # mutate an array owned by the caller
    if np.ndim(width) == 0:
        width = [width] * ndim
    width = np.array(width, dtype=np.int16)

    if np.ndim(oversamp) == 0:
        oversamp = [oversamp] * ndim
    oversamp = np.array(oversamp, dtype=np.float32)

    # calculate Kaiser-Bessel beta parameter (Beatty et al. formula)
    beta = math.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5

    if np.ndim(shape) == 0:
        shape = [shape] * ndim
    shape = np.array(shape, dtype=np.int16)[-ndim:]

    # check for Cartesian axes: coordinates are in grid units, so an axis is
    # Cartesian when all its samples fall on integer positions
    coord_np = coord.numpy()
    is_cart = [
        np.allclose(coord_np[..., ax], np.round(coord_np[..., ax]))
        for ax in range(ndim)
    ]
    is_cart = np.asarray(is_cart[::-1])  # (x, y, z) -> (z, y, x)

    # Cartesian axes have osf = 1.0 and kernel width = 1 (no interpolation)
    oversamp[is_cart] = 1.0
    width[is_cart] = 1

    # get oversampled grid shape
    os_shape = _get_oversamp_shape(shape, oversamp, ndim)

    # rescale trajectory (coord axes are (x, y, z), hence the reversal)
    coord = _scale_coord(coord, shape[::-1], oversamp[::-1])

    # compute interpolator
    interpolator = _interp.plan_interpolator(coord, os_shape, width, beta, device)

    # transform to tuples of plain Python numbers (avoids silent np.int16
    # overflow when downstream code multiplies shape entries)
    oversamp = tuple(float(osf) for osf in oversamp)
    width = tuple(int(w) for w in width)
    beta = tuple(float(b) for b in beta)
    os_shape = tuple(int(n) for n in os_shape)
    shape = tuple(int(n) for n in shape)

    return NUFFTPlan(ndim, oversamp, width, beta, os_shape, shape, interpolator, device)

def plan_toeplitz_nufft(coord, shape, basis=None, dcf=None, width=4, device="cpu"):
    """
    Compute spatio-temporal kernel for fast self-adjoint operation.

    Parameters
    ----------
    coord : torch.Tensor
        K-space coordinates of shape ``(ncontrasts, nviews, nsamples, ndims)``.
        Coordinates must be normalized between ``(-0.5 * shape[i], 0.5 * shape[i])``,
        with ``i = (z, y, x)``.
    shape : int | Iterable[int]
        Oversampled grid size of shape ``(ndim,)``.
        If scalar, isotropic matrix is assumed.
    basis : torch.Tensor, optional
        Low rank subspace projection operator
        of shape ``(ncontrasts, ncoeffs)``; can be ``None``. The default is ``None``.
    dcf : torch.Tensor, optional
        Density compensation function of shape ``(ncontrasts, nviews, nsamples)``.
        The default is a tensor of ``1.0``.
    width : int | Iterable[int], optional
        Interpolation kernel full-width of shape ``(ndim,)``.
        If scalar, isotropic kernel is assumed.
        The default is ``4``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``cpu``.

    Returns
    -------
    toeplitz_kernel : GramMatrix
        Structure containing Toeplitz kernel (i.e., Fourier transform of system tPSF).

    """
    # thin wrapper: all the work is delegated to the toeplitz submodule
    return _toeplitz.plan_toeplitz(coord, shape, basis, dcf, width, device)
class ApplyNUFFT(autograd.Function):
    """Autograd wrapper of the forward NUFFT; its backward is the adjoint NUFFT."""

    @staticmethod
    def forward(image, nufft_plan, basis_adjoint, weight, device, threadsperblock):
        return _apply_nufft(
            image, nufft_plan, basis_adjoint, weight, device, threadsperblock
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, nufft_plan, basis_adjoint, weight, device, threadsperblock = inputs
        # gradients are not zero-filled: backward must handle ``None``
        ctx.set_materialize_grads(False)
        ctx.nufft_plan = nufft_plan
        ctx.basis_adjoint = basis_adjoint
        ctx.weight = weight
        ctx.device = device
        ctx.threadsperblock = threadsperblock

    @staticmethod
    def backward(ctx, kspace):
        # with materialized grads disabled, an absent gradient arrives as None
        if kspace is None:
            return None, None, None, None, None, None

        # the adjoint of projecting with ``basis_adjoint`` uses its conjugate transpose
        basis_adjoint = ctx.basis_adjoint
        basis = basis_adjoint.conj().t() if basis_adjoint is not None else None

        image = _apply_nufft_adj(
            kspace, ctx.nufft_plan, basis, ctx.weight, ctx.device, ctx.threadsperblock
        )
        return image, None, None, None, None, None


def apply_nufft(
    image, nufft_plan, basis_adjoint=None, weight=None, device=None, threadsperblock=128
):
    """
    Apply Non-Uniform Fast Fourier Transform.

    Parameters
    ----------
    image : np.ndarray | torch.Tensor
        Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).
    nufft_plan : NUFFTPlan
        Pre-calculated NUFFT plan coefficients in sparse COO format.
    basis_adjoint : torch.Tensor, optional
        Adjoint low rank subspace projection operator
        of shape ``(ncontrasts, ncoeffs)``; can be ``None``. The default is ``None``.
    weight : np.ndarray | torch.Tensor, optional
        Optional weight for output data samples. Useful to force adjointness.
        The default is ``None``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``None`` (same as the NUFFT plan).
    threadsperblock : int
        CUDA blocks size (for GPU only). The default is ``128``.

    Returns
    -------
    kspace : np.ndarray | torch.Tensor
        Output Non-Cartesian kspace of shape ``(..., ncontrasts, nviews, nsamples)``.

    """
    return ApplyNUFFT.apply(
        image, nufft_plan, basis_adjoint, weight, device, threadsperblock
    )


class ApplyNUFFTAdjoint(autograd.Function):
    """Autograd wrapper of the adjoint NUFFT; its backward is the forward NUFFT."""

    @staticmethod
    def forward(kspace, nufft_plan, basis, weight, device, threadsperblock):
        return _apply_nufft_adj(
            kspace, nufft_plan, basis, weight, device, threadsperblock
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, nufft_plan, basis, weight, device, threadsperblock = inputs
        # gradients are not zero-filled: backward must handle ``None``
        ctx.set_materialize_grads(False)
        ctx.nufft_plan = nufft_plan
        ctx.basis = basis
        ctx.weight = weight
        ctx.device = device
        ctx.threadsperblock = threadsperblock

    @staticmethod
    def backward(ctx, image):
        # with materialized grads disabled, an absent gradient arrives as None
        if image is None:
            return None, None, None, None, None, None

        basis = ctx.basis
        basis_adjoint = basis.conj().t() if basis is not None else None

        kspace = _apply_nufft(
            image,
            ctx.nufft_plan,
            basis_adjoint,
            ctx.weight,
            ctx.device,
            ctx.threadsperblock,
        )
        return kspace, None, None, None, None, None


def apply_nufft_adj(
    kspace, nufft_plan, basis=None, weight=None, device=None, threadsperblock=128
):
    """
    Apply adjoint Non-Uniform Fast Fourier Transform.

    Parameters
    ----------
    kspace : torch.Tensor
        Input kspace of shape ``(..., ncontrasts, nviews, nsamples)``.
    nufft_plan : NUFFTPlan
        Pre-calculated NUFFT plan coefficients in sparse COO format.
    basis : torch.Tensor, optional
        Low rank subspace projection operator
        of shape ``(ncontrasts, ncoeffs)``; can be ``None``. The default is ``None``.
    weight : np.ndarray | torch.Tensor, optional
        Optional weight for output data samples. Useful to force adjointness.
        The default is ``None``.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``None`` (same as the NUFFT plan).
    threadsperblock : int
        CUDA blocks size (for GPU only). The default is ``128``.

    Returns
    -------
    image : torch.Tensor
        Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).

    """
    return ApplyNUFFTAdjoint.apply(
        kspace, nufft_plan, basis, weight, device, threadsperblock
    )


class ApplyNUFFTSelfAdjoint(autograd.Function):
    """Autograd wrapper of the Toeplitz self-adjoint NUFFT (its own adjoint)."""

    @staticmethod
    def forward(image, toeplitz_kern, device, threadsperblock):
        return _apply_nufft_selfadj(image, toeplitz_kern, device, threadsperblock)

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, toeplitz_kern, device, threadsperblock = inputs
        # gradients are not zero-filled: backward must handle ``None``
        ctx.set_materialize_grads(False)
        ctx.toeplitz_kern = toeplitz_kern
        ctx.device = device
        ctx.threadsperblock = threadsperblock

    @staticmethod
    def backward(ctx, image):
        # with materialized grads disabled, an absent gradient arrives as None
        if image is None:
            return None, None, None, None

        grad = _apply_nufft_selfadj(
            image, ctx.toeplitz_kern, ctx.device, ctx.threadsperblock
        )
        return grad, None, None, None
def apply_nufft_selfadj(image, toeplitz_kern, device=None, threadsperblock=128):
    """
    Apply self-adjoint Non-Uniform Fast Fourier Transform via Toeplitz Convolution.

    Parameters
    ----------
    image : torch.Tensor
        Input image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).
    toeplitz_kern : GramMatrix
        Pre-calculated Toeplitz kernel.
    device : str, optional
        Computational device (``cpu`` or ``cuda:n``, with ``n=0, 1,...nGPUs``).
        The default is ``None`` (same as the Toeplitz kernel).
    threadsperblock : int
        CUDA blocks size (for GPU only). The default is ``128``.

    Returns
    -------
    image : torch.Tensor
        Output image of shape ``(..., ncontrasts, ny, nx)`` (2D)
        or ``(..., ncontrasts, nz, ny, nx)`` (3D).

    """
    return ApplyNUFFTSelfAdjoint.apply(image, toeplitz_kern, device, threadsperblock)
# %% local utils
@dataclass
class NUFFTPlan:
    """
    Precomputed NUFFT parameters, as returned by ``plan_nufft``.

    Per-axis tuples (``oversamp``, ``width``, ``beta``, ``os_shape``, ``shape``)
    are ordered ``(z, y, x)`` (3D) or ``(y, x)`` (2D).
    """

    ndim: int  # number of spatial dimensions
    oversamp: tuple  # grid oversampling factor per axis
    width: tuple  # interpolation kernel width per axis (1 = Cartesian axis)
    beta: tuple  # Kaiser-Bessel beta per axis
    os_shape: tuple  # oversampled grid shape
    shape: tuple  # nominal grid shape
    interpolator: object  # precomputed interpolator (see _interp.plan_interpolator)
    device: str  # computational device

    def to(self, device):
        """
        Dispatch internal attributes to selected device.

        Parameters
        ----------
        device : str
            Computational device ("cpu" or "cuda:n", with n=0, 1,...nGPUs).

        Returns
        -------
        NUFFTPlan
            The plan itself, modified in place.

        """
        if device != self.device:
            # NOTE(review): assumes the interpolator moves itself in place via
            # ``to(device)`` — confirm against _interp
            self.interpolator.to(device)
            self.device = device

        return self


def _get_oversamp_shape(shape, oversamp, ndim):
    """Return ``ceil(oversamp * shape)`` per axis as an int16 array."""
    return np.ceil(oversamp * shape).astype(np.int16)


def _scale_coord(coord, shape, oversamp):
    """
    Map grid-unit coordinates onto the oversampled grid index space.

    ``shape`` and ``oversamp`` must follow the coordinate axis order
    (x, y, z). Each component is scaled by ``ceil(osf * n) / n`` and shifted
    by ``ceil(osf * n) // 2`` so that the k-space center lands at the center
    of the oversampled grid. The input tensor is not modified.
    """
    ndim = coord.shape[-1]
    output = coord.clone()
    for i in range(-ndim, 0):
        os_size = np.ceil(oversamp[i] * shape[i])
        scale = os_size / shape[i]
        shift = os_size // 2
        output[..., i] *= scale
        output[..., i] += shift

    return output


def _apodize(data_in, ndim, oversamp, width, beta):
    """
    Apply Kaiser-Bessel (de)apodization along the last ``ndim`` axes, in place.

    Axes with kernel width 1 (Cartesian) are skipped. The correction is
    normalized to 1 at the DC index ``n // 2``. ``data_in`` itself is
    modified and returned.
    """
    data_out = data_in
    for n in range(1, ndim + 1):
        axis = -n
        if width[axis] != 1:
            i = data_out.shape[axis]
            os_i = np.ceil(oversamp[axis] * i)
            idx = torch.arange(i, dtype=torch.float32, device=data_in.device)

            # Calculate apodization
            apod = (
                beta[axis] ** 2 - (math.pi * width[axis] * (idx - i // 2) / os_i) ** 2
            ) ** 0.5
            apod /= torch.sinh(apod)

            # normalize by DC
            apod = apod / apod[int(i // 2)]

            # avoid NaN (negative radicand far from the center)
            apod = torch.nan_to_num(apod, nan=1.0)

            # apply to axis (trailing singleton dims broadcast over later axes)
            data_out *= apod.reshape([i] + [1] * (-axis - 1))

    return data_out


def _cast_dtype(data):
    """Cast real floating tensors to float32 and every other dtype to complex64."""
    if data.dtype in (torch.float16, torch.float32, torch.float64):
        return data.to(torch.float32)
    return data.to(torch.complex64)


def _apply_nufft(image, nufft_plan, basis_adjoint, weight, device, threadsperblock):
    """Forward NUFFT: apodize, zero-pad, FFT, interpolate, optionally weight."""
    # remember the input container so the output can match it
    isnumpy = isinstance(image, np.ndarray)

    # convert to tensor with a supported datatype
    image = _cast_dtype(torch.as_tensor(image))

    # handle basis
    if basis_adjoint is not None:
        basis_adjoint = _cast_dtype(torch.as_tensor(basis_adjoint))

    # cast plan to device if requested
    if device is not None:
        nufft_plan.to(device)

    # unpack plan
    ndim = nufft_plan.ndim
    oversamp = nufft_plan.oversamp
    width = nufft_plan.width
    beta = nufft_plan.beta
    os_shape = nufft_plan.os_shape
    interpolator = nufft_plan.interpolator
    device = nufft_plan.device

    # original device
    odevice = image.device

    # offload to computational device, always copying (once): _apodize works
    # in place and must not modify the caller's data
    image = image.to(device, copy=True)

    # apodize
    image = _apodize(image, ndim, oversamp, width, beta)

    # zero-pad
    image = _resize(image, list(image.shape[:-ndim]) + list(os_shape))

    # FFT
    kspace = _fft.fft(image, axes=range(-ndim, 0), norm=None)
    del image  # oversampled grid no longer needed

    # interpolate
    kspace = _interp.apply_interpolation(
        kspace, interpolator, basis_adjoint, device, threadsperblock
    )

    # apply weight
    if weight is not None:
        weight = torch.as_tensor(weight, dtype=torch.float32, device=kspace.device)
        kspace = weight * kspace

    # bring back to original device
    kspace = kspace.to(odevice)

    # transform back to numpy if required
    if isnumpy:
        kspace = kspace.numpy(force=True)

    # collect garbage
    gc.collect()

    return kspace


def _apply_nufft_adj(kspace, nufft_plan, basis, weight, device, threadsperblock):
    """Adjoint NUFFT: optionally weight, grid, IFFT, crop, apodize."""
    # remember the input container so the output can match it
    isnumpy = isinstance(kspace, np.ndarray)

    # convert to tensor with a supported datatype
    kspace = _cast_dtype(torch.as_tensor(kspace))

    # handle basis
    if basis is not None:
        basis = _cast_dtype(torch.as_tensor(basis))

    # cast plan to device if requested
    if device is not None:
        nufft_plan.to(device)

    # unpack plan
    ndim = nufft_plan.ndim
    oversamp = nufft_plan.oversamp
    width = nufft_plan.width
    beta = nufft_plan.beta
    shape = nufft_plan.shape
    interpolator = nufft_plan.interpolator
    device = nufft_plan.device

    # original device
    odevice = kspace.device

    # offload to computational device
    kspace = kspace.to(device)

    # apply weight
    if weight is not None:
        weight = torch.as_tensor(weight, dtype=torch.float32, device=kspace.device)
        kspace = weight * kspace

    # gridding
    kspace = _interp.apply_gridding(kspace, interpolator, basis, device, threadsperblock)

    # IFFT
    image = _fft.ifft(kspace, axes=range(-ndim, 0), norm=None)
    del kspace  # gridded k-space no longer needed

    # crop
    image = _resize(image, list(image.shape[:-ndim]) + list(shape))

    # apodize
    image = _apodize(image, ndim, oversamp, width, beta)

    # bring back to original device
    image = image.to(odevice)

    # transform back to numpy if required
    if isnumpy:
        image = image.numpy(force=True)

    # collect garbage
    gc.collect()

    return image


def _apply_nufft_selfadj(image, toeplitz_kern, device, threadsperblock):
    """Self-adjoint NUFFT via Toeplitz convolution: pad, FFT, multiply, IFFT, crop."""
    # remember the input container so the output can match it
    isnumpy = isinstance(image, np.ndarray)

    # convert to tensor with a supported datatype
    image = _cast_dtype(torch.as_tensor(image))

    # cast kernel to device if requested
    if device is not None:
        # NOTE(review): assumes the kernel moves itself in place via ``to`` — confirm
        toeplitz_kern.to(device)

    # unpack kernel
    shape = toeplitz_kern.shape
    ndim = toeplitz_kern.ndim
    device = toeplitz_kern.device

    # original shape and device
    oshape = image.shape[-ndim:]
    odevice = image.device

    # offload to computational device
    image = image.to(device)

    # zero-pad
    image = _resize(image, list(image.shape[:-ndim]) + list(shape))

    # FFT
    kspace = _fft.fft(image, axes=range(-ndim, 0), norm="ortho", centered=False)

    # Toeplitz convolution
    tmp = torch.zeros_like(kspace)
    tmp = _interp.apply_toeplitz(tmp, kspace, toeplitz_kern, device, threadsperblock)

    # IFFT
    image = _fft.ifft(tmp, axes=range(-ndim, 0), norm="ortho", centered=False)

    # crop
    image = _resize(image, list(image.shape[:-ndim]) + list(oshape))

    # bring back to original device
    image = image.to(odevice)

    # transform back to numpy if required
    if isnumpy:
        image = image.numpy(force=True)

    # collect garbage
    gc.collect()

    return image