"""Three-dimensional stack-of-spirals sampling."""
__all__ = ["spiral_stack"]
import numpy as np
# Guarded import: keeps this module importable (e.g., during the Sphinx docs build) when `_design` cannot be loaded.
try:
    from ... import _design
except Exception:
    pass
from ..._types import Header
def spiral_stack(shape, accel=None, nintl=1, **kwargs):
    r"""
    Design a constant- or multi-density stack of spirals.

    As in the 2D spiral case, interleaves are rotated by a pseudo golden angle
    with period 377 interleaves. Rotations are performed both along
    ``view`` and ``contrast`` dimensions. Acquisition is assumed to
    traverse the ``contrast`` dimension first and then the ``view``,
    i.e., all the contrasts are acquired before moving to the second view.
    If multiple echoes are specified, final contrast dimensions will have
    length ``ncontrasts * nechoes``. Echoes are assumed to be acquired
    sequentially with the same spiral interleaf.

    Finally, slice dimension is assumed to be the outermost loop.

    Parameters
    ----------
    shape : Iterable[int]
        Matrix shape ``(in-plane, slices, contrasts=1, echoes=1)``.
        At least ``(in-plane, slices)`` must be provided.
    accel : int | Iterable[int], optional
        Acceleration factors ``(in-plane, slices=1)``. A scalar is
        interpreted as the in-plane factor; a missing slice factor
        defaults to ``1``.
        Range from ``1`` (fully sampled) to ``nintl`` / ``nslices``.
        If not provided, the in-plane factor is ``1`` for a single contrast
        and ``nintl`` (i.e., one interleaf per contrast) when
        ``contrasts > 1``; the slice factor is ``1``.
    nintl : int, optional
        Number of interleaves to fully sample a plane.
        The default is ``1``.

    Keyword Arguments
    -----------------
    moco_shape : int
        Matrix size for inner-most (motion navigation) spiral.
        The default is ``None``.
    acs_shape : int | Iterable[int]
        Matrix size ``(in-plane, slices)`` for intermediate inner
        (coil sensitivity estimation) spiral. A scalar or a single-element
        iterable sets the in-plane size only.
        The default is (``None``, ``None``).
    acs_nintl : int
        Number of interleaves to fully sample intermediate inner spiral.
        The default is ``1``.
    variant : str
        Type of spiral. Allowed values are:

        * ``center-out``: starts at the center of k-space and ends at the edge (default).
        * ``reverse``: starts at the edge of k-space and ends at the center.
        * ``in-out``: starts at the edge of k-space and ends on the opposite side (two 180° rotated arms back-to-back).

    Returns
    -------
    head : Header
        Acquisition header corresponding to the generated spiral.

    Raises
    ------
    AssertionError
        If ``shape`` has fewer than two entries.

    Example
    -------
    >>> import deepmr

    We can create a single-shot stack-of-spirals for a ``(128, 128, 120)`` voxels matrix by:

    >>> head = deepmr.spiral_stack((128, 120))

    A multi-shot trajectory can be generated by specifying the ``nintl`` argument:

    >>> head = deepmr.spiral_stack((128, 120), nintl=48)

    Both spirals have constant density. If we want a dual density we can use ``acs_shape`` and ``acs_nintl`` arguments.
    For example, if we want an inner ``(32, 32, 16)`` k-space region sampled with a 4 interleaves spiral, this can be obtained as:

    >>> head = deepmr.spiral_stack((128, 120), nintl=48, acs_shape=(32, 16), acs_nintl=4)

    This inner region can be used e.g., for Parallel Imaging calibration. Similarly, a triple density spiral can
    be obtained by using the ``moco_shape`` argument:

    >>> head = deepmr.spiral_stack((128, 120), nintl=48, acs_shape=(32, 16), acs_nintl=4, moco_shape=8)

    The generated spiral will have an innermost ``(8, 8)`` single-shot k-space region (e.g., for PROPELLER-like motion correction),
    an intermediate ``(32, 32, 16)`` k-space region fully covered by 4 spiral shots and an outer ``(128, 128, 120)`` region fully covered by 48 interleaves.

    In-plane and slice accelerations can be specified using the ``accel`` argument. For example, the following

    >>> head = deepmr.spiral_stack((128, 120), nintl=48, accel=(4, 2))

    will generate the following trajectory:

    >>> head.traj.shape
    torch.Size([1, 720, 538, 3])

    i.e., a 48-interleaves trajectory with an in-plane acceleration factor of 4 (i.e., 12 interleaves)
    and slice acceleration of 2 (i.e., 60 encodings).

    Multiple contrasts with different sampling (e.g., for MR Fingerprinting) can be achieved by providing
    a tuple of ints as the ``shape`` argument:

    >>> head = deepmr.spiral_stack((128, 120, 420), nintl=48)
    >>> head.traj.shape
    torch.Size([420, 120, 538, 3])

    corresponding to 420 different contrasts, each sampled with a single spiral interleaf of 538 points,
    repeated for 120 slice encodings. Similarly, multiple echoes (with fixed sampling) can be specified as:

    >>> head = deepmr.spiral_stack((128, 120, 1, 8), nintl=48)
    >>> head.traj.shape
    torch.Size([8, 5760, 538, 3])

    corresponding to a 8-echoes fully sampled k-spaces, e.g., for QSM and T2* mapping.

    Notes
    -----
    The returned ``head`` (:func:`deepmr.Header`) is a structure with the following fields:

    * shape (torch.Tensor):
        This is the expected image size of shape ``(nz, ny, nx)``.
    * t (torch.Tensor):
        This is the readout sampling time ``(0, t_read)`` in ``ms``
        with shape ``(nsamples,)``.
    * traj (torch.Tensor):
        This is the k-space trajectory normalized as ``(-0.5 * shape, 0.5 * shape)``
        with shape ``(ncontrasts, nviews, nsamples, 3)``.
    * dcf (torch.Tensor):
        This is the k-space sampling density compensation factor
        with shape ``(ncontrasts, nviews, nsamples)``.
    * TE (torch.Tensor):
        This is the Echo Times array. Assumes a k-space raster time of ``1 us``
        and minimal echo spacing.

    """
    # materialize first, so any iterable (tuple, array, generator) is accepted
    shape = list(shape)
    assert len(shape) >= 2, "Please provide at least (in-plane, nslices) as shape."

    # pad missing trailing dimensions (contrasts, echoes) with singletons
    shape += [1] * (4 - len(shape))

    # default accel: fully sampled plane for a single contrast,
    # one interleaf per contrast otherwise
    if accel is None:
        accel = 1 if shape[2] == 1 else nintl

    # expand accel to (in-plane, slices); a missing slice factor means no
    # slice acceleration
    if np.isscalar(accel):
        accel = [accel, 1]
    else:
        accel = list(accel)
        accel += [1] * (2 - len(accel))

    # expand acs to (in-plane, slices); accept None, scalars and short
    # iterables, and work on a copy so the caller's object is never modified
    acs_shape = kwargs.pop("acs_shape", None)
    if acs_shape is None or np.isscalar(acs_shape):
        acs_shape = [acs_shape]
    else:
        acs_shape = list(acs_shape)
    acs_shape += [None] * (2 - len(acs_shape))

    # assume 1mm iso
    fov = shape[0]

    # design single interleaf spiral
    tmp, _ = _design.spiral(fov, shape[0], 1, nintl, acs_shape=acs_shape[0], **kwargs)

    # number of rotated copies of the base interleaf
    ncontrasts = shape[2]
    nviews = max(int(nintl // accel[0]), 1)

    # generate angles (pseudo golden angle with period 377)
    dphi = (1 - 233 / 377) * 360.0
    phi = np.arange(ncontrasts * nviews) * dphi  # angles in degrees
    phi = np.deg2rad(phi)  # angles in radians

    # build rotation matrix
    rot = _design.angleaxis2rotmat(phi, "z")

    # get in-plane trajectory and rotate it
    traj = tmp["kr"] * tmp["mtx"]
    traj = _design.projection(traj[0].T, rot)
    traj = traj.swapaxes(-2, -1).T

    # contrast is the fastest varying index of the rotation sequence:
    # split as (nviews, ncontrasts, ...), then bring contrasts in front
    traj = traj.reshape(nviews, ncontrasts, *traj.shape[-2:])
    traj = traj.swapaxes(0, 1)

    # slice encoding positions
    nz = shape[1]
    az = np.arange(-nz // 2, nz // 2, dtype=np.float32)

    # accelerate
    az = az[:: accel[1]]

    # add back ACS (np.unique also sorts and drops duplicated encodings)
    if acs_shape[1] is not None:
        nz_acs = acs_shape[1]
        az = np.concatenate(
            (az, np.arange(-nz_acs // 2, nz_acs // 2, dtype=np.float32))
        )
        az = np.unique(az)

    # replicate the in-plane views once per slice encoding (views cycle
    # fastest, slices are the outermost loop); a single tile along the view
    # axis avoids the per-1D-slice Python loop of np.apply_along_axis
    traj = np.tile(traj, (1, len(az), 1, 1))

    # matching slice coordinate for every (contrast, view, sample)
    az = np.repeat(az, nviews)
    az = az[None, :, None] * np.ones_like(traj[..., 0])

    # append slice coordinate as third k-space axis
    traj = np.concatenate((traj, az[..., None]), axis=-1)

    # get dcf
    dcf = tmp["dcf"]

    # expand echoes: every echo reuses the interleaf of its contrast
    nechoes = shape[-1]
    traj = np.repeat(traj, nechoes, axis=0)

    # get image shape (slices first, then in-plane matrix)
    shape = [shape[1]] + tmp["mtx"]

    # get time
    t = tmp["t"]

    # calculate TE: echoes are spaced by the readout duration
    min_te = float(tmp["te"][0])
    TE = np.arange(nechoes, dtype=np.float32) * t[-1] + min_te

    # extra args
    user = {}
    user["moco_shape"] = tmp["moco"]["mtx"]
    user["acs_shape"] = tmp["acs"]["mtx"]

    # build header and convert its fields to torch
    head = Header(shape, t=t, traj=traj, dcf=dcf, TE=TE, user=user)
    head.torch()

    return head