# Source code for deepmr.recon.calib.espirit

"""PyTorch ESPIRiT implementation. Adapted for convenience from https://github.com/mikgroup/espirit-python/tree/master"""

__all__ = ["espirit_cal"]

import numpy as np
import torch

from ... import fft as _fft
from ... import _signal

from . import acs as _acs


def espirit_cal(
    data, coord=None, dcf=None, shape=None, k=6, r=24, t=0.02, c=0.0, nsets=1
):
    """
    Derives the ESPIRiT [1] operator.

    Parameters
    ----------
    data : np.ndarray | torch.Tensor
        Multi channel k-space data.
    coord : np.ndarray | torch.Tensor, optional
        K-space trajectory of ``shape = (ncontrasts, nviews, nsamples, ndim)``.
        The default is ``None`` (Cartesian acquisition).
    dcf : np.ndarray | torch.Tensor, optional
        K-space density compensation of ``shape = (ncontrasts, nviews, nsamples)``.
        The default is ``None`` (no compensation).
    shape : Iterable[int], optional
        Shape of the k-space after gridding. If not provided, estimate
        from input data (assumed on a Cartesian grid already).
        The default is ``None`` (Cartesian acquisition).
    k : int, optional
        k-space kernel size. The default is ``6``.
    r : int, optional
        Calibration region size. The default is ``24``.
    t : float, optional
        Rank of the auto-calibration matrix (A).
        The default is ``0.02``.
    c : float, optional
        Crop threshold that determines eigenvalues "=1".
        The default is ``0.0`` (i.e., no cropping).
    nsets : int, optional
        Number of set of maps to be returned.
        The default is ``1`` (conventional SENSE recon).

    Returns
    -------
    maps : np.ndarray | torch.Tensor
        Output coil sensitivity maps.
    cal_data : np.ndarray | torch.Tensor
        Calibration region, resized to ``ndim * [r]``.

    Notes
    -----
    The input k-space ``data`` tensor is assumed to have the following shape:

    * **2Dcart:** ``(nslices, ncoils, ..., ny, nx)``.
    * **2Dnoncart:** ``(nslices, ncoils, ..., nviews, nsamples)``.
    * **3Dcart:** ``(nx, ncoils, ..., nz, ny)``.
    * **3Dnoncart:** ``(ncoils, ..., nviews, nsamples)``.

    For multi-contrast acquisitions, calibration is obtained by averaging
    over contrast dimensions.

    The output sensitivity maps are assumed to have the following shape:

    * **2Dcart:** ``(nsets, nslices, ncoils, ny, nx)``.
    * **2Dnoncart:** ``(nsets, nslices, ncoils, ny, nx)``.
    * **3Dcart:** ``(nsets, nx, ncoils, nz, ny)``.
    * **3Dnoncart:** ``(nsets, ncoils, nz, ny, nx)``.

    References
    ----------
    .. [1] Uecker M, Lai P, Murphy MJ, Virtue P, Elad M, Pauly JM,
           Vasanawala SS, Lustig M. ESPIRiT--an eigenvalue approach to
           autocalibrating parallel MRI: where SENSE meets GRAPPA.
           Magn Reson Med. 2014 Mar;71(3):990-1001. doi: 10.1002/mrm.24751.
           PMID: 23649942; PMCID: PMC4142121.

    """
    # remember input array module so output matches it
    isnumpy = isinstance(data, np.ndarray)

    # pad leading singleton axes until data is 5D
    while len(data.shape) < 5:
        data = data[None, ...]

    # keep shape
    if coord is not None:
        # Non-Cartesian: infer dimensionality from trajectory's last axis
        ndim = coord.shape[-1]
        if np.isscalar(shape):
            shape = ndim * [shape]
        else:
            shape = list(shape)[-ndim:]
        shape = [int(s) for s in shape]
    else:
        # Cartesian: grid shape comes straight from the data matrix
        ndim = 2
        shape = list(data.shape[-2:])

    # extract calibration region (central half of the grid)
    cshape = list(np.asarray(shape, dtype=int) // 2)
    cal_data = _acs.find_acs(data, cshape, coord, dcf)

    # calculate maps
    maps = _espirit(cal_data.clone(), k, r, t, c)

    # select maps
    if nsets == 1:
        maps = maps[[0]]
    else:
        maps = maps[:nsets]

    # resample maps to the full k-space grid
    maps = _signal.resample(maps, shape)  # (nsets, ncoils, nz, ny, nx)

    # normalize by the root-sum-of-squares of the first set
    maps_rss = _signal.rss(maps, axis=1, keepdim=True)
    maps = maps / maps_rss[[0]]

    # reformat
    if ndim == 2:  # Cartesian or 2D Non-Cartesian
        maps = maps.swapaxes(
            1, 2
        )  # (nsets, nslices, ncoils, ny, nx) / (nsets, nx, ncoils, nz, ny)

    # cast back to numpy if required
    if isnumpy:
        maps = maps.numpy(force=True)

    return maps, _signal.resize(cal_data, ndim * [r])
# %% local utils def _espirit(X, k, r, t, c): # transpose X = X.permute(3, 2, 1, 0) # get shape sx, sy, sz, nc = X.shape sxt = (sx // 2 - r // 2, sx // 2 + r // 2) if (sx > 1) else (0, 1) syt = (sy // 2 - r // 2, sy // 2 + r // 2) if (sy > 1) else (0, 1) szt = (sz // 2 - r // 2, sz // 2 + r // 2) if (sz > 1) else (0, 1) # Extract calibration region. C = X[sxt[0] : sxt[1], syt[0] : syt[1], szt[0] : szt[1], :].to( dtype=torch.complex64 ) # Construct Hankel matrix. p = (sx > 1) + (sy > 1) + (sz > 1) A = torch.zeros( [(r - k + 1) ** p, k**p * nc], dtype=torch.complex64, device=X.device ) idx = 0 for xdx in range(max(1, C.shape[0] - k + 1)): for ydx in range(max(1, C.shape[1] - k + 1)): for zdx in range(max(1, C.shape[2] - k + 1)): block = C[xdx : xdx + k, ydx : ydx + k, zdx : zdx + k, :].to( dtype=torch.complex64 ) A[idx, :] = block.flatten() idx += 1 # Take the Singular Value Decomposition. U, S, VH = torch.linalg.svd(A, full_matrices=True) V = VH.conj().t() # Select kernels n = torch.sum(S >= t * S[0]) V = V[:, :n] kxt = (sx // 2 - k // 2, sx // 2 + k // 2) if (sx > 1) else (0, 1) kyt = (sy // 2 - k // 2, sy // 2 + k // 2) if (sy > 1) else (0, 1) kzt = (sz // 2 - k // 2, sz // 2 + k // 2) if (sz > 1) else (0, 1) # Reshape into k-space kernel, flips it and takes the conjugate kernels = torch.zeros((sx, sy, sz, nc, n), dtype=torch.complex64, device=X.device) kerdims = [ ((sx > 1) * k + (sx == 1) * 1), ((sy > 1) * k + (sy == 1) * 1), ((sz > 1) * k + (sz == 1) * 1), nc, ] for idx in range(n): kernels[kxt[0] : kxt[1], kyt[0] : kyt[1], kzt[0] : kzt[1], :, idx] = V[ :, idx ].reshape(kerdims) # Take the iucfft axes = (0, 1, 2) kerimgs = ( _fft.fft(kernels.flip(0).flip(1).flip(2).conj(), axes) * (sx * sy * sz) ** 0.5 / (k**p) ** 0.5 ) # Take the point-wise eigenvalue decomposition and keep eigenvalues greater than c u, s, vh = torch.linalg.svd( kerimgs.view(sx, sy, sz, nc, n).reshape(-1, nc, n), full_matrices=True ) mask = s**2 > c # mask u (nvoxels, neigen, neigen) u = 
mask[:, None, :] * u # Reshape back to the original shape and assign to maps maps = u.view(sx, sy, sz, nc, nc) # transpose maps = maps.permute(4, 3, 2, 1, 0) return maps