"""Pytorch ESPIRIT implementation. Adapted for convenience from https://github.com/mikgroup/espirit-python/tree/master"""
__all__ = ["espirit_cal"]
import numpy as np
import torch
from ... import fft as _fft
from ... import _signal
from . import acs as _acs
def espirit_cal(
    data, coord=None, dcf=None, shape=None, k=6, r=24, t=0.02, c=0.0, nsets=1
):
    """
    Derive the ESPIRiT [1] operator (coil sensitivity maps).

    Parameters
    ----------
    data : np.ndarray | torch.Tensor
        Multi channel k-space data.
    coord : np.ndarray | torch.Tensor, optional
        K-space trajectory of ``shape = (ncontrasts, nviews, nsamples, ndim)``.
        The default is ``None`` (Cartesian acquisition).
    dcf : np.ndarray | torch.Tensor, optional
        K-space density compensation of ``shape = (ncontrasts, nviews, nsamples)``.
        The default is ``None`` (no compensation).
    shape : int | Iterable[int], optional
        Shape of the k-space after gridding. A scalar is replicated along
        every spatial dimension; for an iterable only the last ``ndim``
        entries are used. Required when ``coord`` is provided. When ``coord``
        is ``None`` this argument is ignored and the shape is taken from the
        last two axes of ``data`` (assumed on a Cartesian grid already).
        The default is ``None``.
    k : int, optional
        k-space kernel size. The default is ``6``.
    r : int, optional
        Calibration region size. The default is ``24``.
    t : float, optional
        Relative singular value threshold used to select the rank of the
        auto-calibration matrix (A): singular values smaller than
        ``t`` times the largest one are discarded.
        The default is ``0.02``.
    c : float, optional
        Crop threshold on the squared point-wise singular values; map
        entries whose value does not exceed ``c`` are zeroed.
        The default is ``0.0`` (no cropping).
    nsets : int, optional
        Number of set of maps to be returned.
        The default is ``1`` (conventional SENSE recon).

    Returns
    -------
    maps : np.ndarray | torch.Tensor
        Output coil sensitivity maps (NumPy array if ``data`` was a NumPy
        array, tensor otherwise).
    cal_data : torch.Tensor
        Calibration k-space resized to ``ndim * [r]`` through
        ``_signal.resize``. It is returned as produced by that helper, i.e.
        it is *not* cast back to NumPy by this function.

    Raises
    ------
    ValueError
        If ``coord`` is provided without ``shape``.

    Notes
    -----
    The input k-space ``data`` tensor is assumed to have the following shape:

    * **2Dcart:** ``(nslices, ncoils, ..., ny, nx)``.
    * **2Dnoncart:** ``(nslices, ncoils, ..., nviews, nsamples)``.
    * **3Dcart:** ``(nx, ncoils, ..., nz, ny)``.
    * **3Dnoncart:** ``(ncoils, ..., nviews, nsamples)``.

    For multi-contrast acquisitions, calibration is obtained by averaging over
    contrast dimensions.

    The output sensitivity maps are assumed to have the following shape:

    * **2Dcart:** ``(nsets, nslices, ncoils, ny, nx)``.
    * **2Dnoncart:** ``(nsets, nslices, ncoils, ny, nx)``.
    * **3Dcart:** ``(nsets, nx, ncoils, nz, ny)``.
    * **3Dnoncart:** ``(nsets, ncoils, nz, ny, nx)``.

    References
    ----------
    .. [1] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM, Vasanawala SS, Lustig M.
           ESPIRiT--an eigenvalue approach to autocalibrating parallel MRI: where SENSE meets GRAPPA.
           Magn Reson Med. 2014 Mar;71(3):990-1001. doi: 10.1002/mrm.24751. PMID: 23649942; PMCID: PMC4142121.

    """
    # remember the input container type so the maps can be cast back at the end
    isnumpy = isinstance(data, np.ndarray)

    # prepend singleton axes until the data is 5D
    while len(data.shape) < 5:
        data = data[None, ...]

    # keep shape
    if coord is not None:
        if shape is None:
            # without this guard, list(None) below fails with an opaque TypeError
            raise ValueError(
                "shape must be provided when coord is given (Non-Cartesian data)."
            )
        ndim = coord.shape[-1]
        if np.isscalar(shape):
            shape = ndim * [shape]
        else:
            shape = list(shape)[-ndim:]
        shape = [int(s) for s in shape]
    else:
        ndim = 2
        shape = list(data.shape[-2:])

    # extract calibration region (half of the target matrix along each axis)
    cshape = list(np.asarray(shape, dtype=int) // 2)
    cal_data = _acs.find_acs(data, cshape, coord, dcf)

    # calculate maps (on a copy, so cal_data can be returned untouched)
    maps = _espirit(cal_data.clone(), k, r, t, c)

    # select maps
    if nsets == 1:
        maps = maps[[0]]
    else:
        maps = maps[:nsets]

    # resample
    maps = _signal.resample(maps, shape)  # (nsets, ncoils, nz, ny, nx)

    # normalize every set by the coil-wise root-sum-of-squares of the first set
    maps_rss = _signal.rss(maps, axis=1, keepdim=True)
    maps = maps / maps_rss[[0]]

    # reformat
    if ndim == 2:  # Cartesian or 2D Non-Cartesian
        maps = maps.swapaxes(
            1, 2
        )  # (nsets, nslices, ncoils, ny, nx) / (nsets, nx, ncoils, nz, ny)

    # cast back to numpy if required
    if isnumpy:
        maps = maps.numpy(force=True)

    return maps, _signal.resize(cal_data, ndim * [r])
# %% local utils
def _espirit(X, k, r, t, c):
    """
    Compute ESPIRiT eigenvector maps from a block of calibration k-space.

    Parameters
    ----------
    X : torch.Tensor
        4D calibration k-space. Its axes are reversed on entry and the first
        input axis is treated as the coil axis, i.e. the expected layout is
        presumably ``(ncoils, nz, ny, nx)`` (singleton spatial axes allowed).
    k : int
        Kernel size along every non-singleton spatial axis.
    r : int
        Width of the centered calibration region along every non-singleton
        spatial axis (clipped to the axis length).
    t : float
        Relative threshold on the singular values of the calibration matrix:
        right singular vectors with ``S >= t * S[0]`` are kept as kernels.
    c : float
        Threshold on the squared point-wise singular values; eigenvectors
        whose squared singular value is not larger than ``c`` are zeroed.

    Returns
    -------
    maps : torch.Tensor
        Complex tensor of shape ``(ncoils, ncoils, *X.shape[1:])``: the first
        axis indexes the eigenvector (set of maps), the second the coil.

    Raises
    ------
    ValueError
        If the calibration region is smaller than ``k`` along a
        non-singleton axis.

    """

    def _centered_window(size, width):
        # [start, stop) of a `width`-wide window centered on `size // 2`.
        # Singleton axes are taken whole. Using `start + width` (rather than
        # `center + width // 2`) keeps odd widths at full extent, and the
        # clamping avoids negative-index wrap-around for short axes.
        if size > 1:
            start = max(size // 2 - width // 2, 0)
            return start, min(start + width, size)
        return 0, 1

    # transpose: reverse axes so that coils come last
    X = X.permute(3, 2, 1, 0)

    # get shape
    sx, sy, sz, nc = X.shape
    dims = (sx, sy, sz)
    active = [s > 1 for s in dims]  # non-singleton spatial axes
    p = sum(active)  # number of effective spatial dimensions

    # Extract calibration region.
    (x0, x1), (y0, y1), (z0, z1) = [_centered_window(s, r) for s in dims]
    C = X[x0:x1, y0:y1, z0:z1, :].to(dtype=torch.complex64)
    if any(on and ext < k for on, ext in zip(active, C.shape[:3])):
        raise ValueError(
            f"Calibration region {tuple(C.shape[:3])} is smaller than the "
            f"kernel size k={k} along a non-singleton axis."
        )

    # Construct Hankel matrix: one row per kernel-sized sliding block.
    nwin = [max(1, ext - k + 1) for ext in C.shape[:3]]
    A = torch.zeros(
        [nwin[0] * nwin[1] * nwin[2], k**p * nc],
        dtype=torch.complex64,
        device=X.device,
    )
    row = 0
    for xdx in range(nwin[0]):
        for ydx in range(nwin[1]):
            for zdx in range(nwin[2]):
                A[row, :] = C[xdx : xdx + k, ydx : ydx + k, zdx : zdx + k, :].flatten()
                row += 1

    # Take the Singular Value Decomposition. Only the right singular vectors
    # tied to the leading singular values are used, so the reduced form is
    # sufficient and avoids building the full square factors.
    _, S, VH = torch.linalg.svd(A, full_matrices=False)
    V = VH.conj().t()

    # Select kernels
    n = int(torch.sum(S >= t * S[0]))
    V = V[:, :n]

    # Reshape into k-space kernels, zero-padded and centered on the full grid.
    (kx0, kx1), (ky0, ky1), (kz0, kz1) = [_centered_window(s, k) for s in dims]
    kerdims = [k if on else 1 for on in active] + [nc]
    kernels = torch.zeros((sx, sy, sz, nc, n), dtype=torch.complex64, device=X.device)
    kernels[kx0:kx1, ky0:ky1, kz0:kz1] = V.reshape(kerdims + [n])

    # Flip, conjugate and transform the kernels to image space.
    # NOTE(review): the scaling assumes _fft.fft is a centered unitary FFT - confirm.
    axes = (0, 1, 2)
    kerimgs = (
        _fft.fft(torch.flip(kernels, dims=axes).conj(), axes)
        * (sx * sy * sz) ** 0.5
        / (k**p) ** 0.5
    )

    # Take the point-wise eigenvalue decomposition and keep eigenvalues greater than c.
    # Reduced SVD: u is (nvoxels, nc, q) and s is (nvoxels, q), q = min(nc, n),
    # so the mask always lines up with the columns of u.
    u, s, _ = torch.linalg.svd(kerimgs.reshape(-1, nc, n), full_matrices=False)
    u = (s**2 > c)[:, None, :] * u

    # With fewer kernels than coils the missing eigenvectors have zero
    # eigenvalue: represent them as all-zero maps.
    q = u.shape[-1]
    if q < nc:
        pad = torch.zeros(u.shape[0], nc, nc - q, dtype=u.dtype, device=u.device)
        u = torch.cat((u, pad), dim=-1)

    # Reshape back to the original shape and assign to maps
    maps = u.reshape(sx, sy, sz, nc, nc)

    # transpose: (nsets, ncoils, sz, sy, sx)
    maps = maps.permute(4, 3, 2, 1, 0)

    return maps