Source code for deepmr.linops.coil

"""Sensitivity coil linear operator."""

__all__ = ["SenseOp", "SenseAdjointOp"]

import numpy as np
import torch

from . import base


class SenseOp(base.Linop):
    """
    Multiply input image by coil sensitivity profile.

    Coil sensitivity profiles are expected to have the following dimensions:

    * 2D MRI: ``(nsets, nslices, ncoils, ny, nx)``
    * 3D Cartesian MRI: ``(nsets, nx, ncoils, nz, ny)``
    * 3D NonCartesian MRI: ``(nsets, ncoils, nz, ny, nx)``

    where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRIT), equal to ``1`` for
    conventional SENSE and ``ncoils`` represents the number of receiver
    channels in the coil array.

    """

    def __init__(self, ndim, sensmap, device=None, multicontrast=True):
        """
        Store the sensitivity maps and prepare them for broadcasting.

        Parameters
        ----------
        ndim : int
            Dimensionality selector forwarded to the base class; this
            operator branches on the values ``2`` and ``3``.
        sensmap : np.ndarray | torch.Tensor
            Coil sensitivity profiles (see class docstring for layout).
        device : str | torch.device, optional
            Device the maps are moved to. If ``None``, the device the maps
            already live on is used.
        multicontrast : bool, optional
            If truthy, a singleton axis is inserted in the maps right after
            the coil axis so they broadcast over an extra (contrast) axis
            of the input. The default is ``True``.

        """
        super().__init__(ndim)

        # cast map to tensor
        self.sensmap = torch.as_tensor(sensmap)

        # assign device
        if device is None:
            self.device = self.sensmap.device
        else:
            self.device = device

        # offload to device
        self.sensmap = self.sensmap.to(self.device)

        # multicontrast: add a singleton axis just before the spatial axes
        self.multicontrast = multicontrast
        if self.multicontrast and self.ndim == 2:
            self.sensmap = self.sensmap.unsqueeze(-3)
        if self.multicontrast and self.ndim == 3:
            self.sensmap = self.sensmap.unsqueeze(-4)

    def forward(self, x):
        """
        Forward coil operator.

        Parameters
        ----------
        x : np.ndarray | torch.Tensor
            Input combined images of shape ``(nslices, ..., ny, nx)``
            (2D MRI / 3D Cartesian MRI) or ``(..., nz, ny, nx)``
            (3D NonCartesian MRI).

        Returns
        -------
        y : np.ndarray | torch.Tensor
            Output images of shape ``(nsets, nslices, ncoils, ..., ny, nx)``
            (2D MRI / 3D Cartesian) or ``(nsets, ncoils, ..., nz, ny, nx)``
            (3D NonCartesian MRI) modulated by coil sensitivity profiles.
            A NumPy array is returned when ``x`` was a NumPy array.

        """
        # remember the input container so the output can mirror it
        isnumpy = isinstance(x, np.ndarray)

        # convert to tensor
        x = torch.as_tensor(x)

        # transfer to device (the maps follow the input)
        self.sensmap = self.sensmap.to(x.device)

        # unsqueeze: insert a singleton coil axis so x broadcasts against
        # the sensitivity maps
        if self.multicontrast:
            if self.ndim == 2:
                x = x.unsqueeze(-4)
            elif self.ndim == 3:
                x = x.unsqueeze(-5)
        else:
            if self.ndim == 2:
                x = x.unsqueeze(-3)
            elif self.ndim == 3:
                x = x.unsqueeze(-4)

        # project
        y = self.sensmap * x

        # convert back to numpy if required
        if isnumpy:
            y = y.numpy(force=True)

        return y

    def _adjoint_linop(self):
        """
        Build the adjoint (coil combination) operator.

        The singleton axis added in ``__init__`` is removed first, because
        the ``SenseAdjointOp`` constructor inserts it again.

        Returns
        -------
        SenseAdjointOp
            Adjoint operator sharing these sensitivity maps.

        """
        # Mirror the truthiness test used in ``__init__`` so that ``sensmap``
        # is always bound (a falsy-but-not-``False`` flag such as ``0`` or
        # ``np.bool_(False)`` used to fall through every branch).
        if self.multicontrast and self.ndim == 2:
            sensmap = self.sensmap.squeeze(-3)
        elif self.multicontrast and self.ndim == 3:
            sensmap = self.sensmap.squeeze(-4)
        else:
            # no axis was inserted in ``__init__``: nothing to undo
            sensmap = self.sensmap

        return SenseAdjointOp(self.ndim, sensmap, self.device, self.multicontrast)
class SenseAdjointOp(base.Linop):
    """
    Perform coil combination.

    Coil sensitivity profiles are expected to have the following dimensions:

    * 2D MRI: ``(nsets, nslices, ncoils, ny, nx)``
    * 3D MRI: ``(nsets, ncoils, nz, ny, nx)``

    where ``nsets`` represents multiple sets of coil sensitivity estimation
    for soft-SENSE implementations (e.g., ESPIRIT), equal to ``1`` for
    conventional SENSE and ``ncoils`` represents the number of receiver
    channels in the coil array.

    """

    def __init__(self, ndim, sensmap, device=None, multicontrast=True):
        """
        Store the sensitivity maps and prepare them for broadcasting.

        Parameters
        ----------
        ndim : int
            Dimensionality selector forwarded to the base class; this
            operator branches on the values ``2`` and ``3``.
        sensmap : np.ndarray | torch.Tensor
            Coil sensitivity profiles (see class docstring for layout).
        device : str | torch.device, optional
            Device the maps are moved to. If ``None``, the device the maps
            already live on is used.
        multicontrast : bool, optional
            If truthy, a singleton axis is inserted in the maps right after
            the coil axis so they broadcast over an extra (contrast) axis
            of the input. The default is ``True``.

        """
        super().__init__(ndim)

        # cast map to tensor
        self.sensmap = torch.as_tensor(sensmap)

        # assign device
        if device is None:
            self.device = self.sensmap.device
        else:
            self.device = device

        # offload to device
        self.sensmap = self.sensmap.to(self.device)

        # multicontrast: add a singleton axis just before the spatial axes
        self.multicontrast = multicontrast
        if self.multicontrast and self.ndim == 2:
            self.sensmap = self.sensmap.unsqueeze(-3)
        if self.multicontrast and self.ndim == 3:
            self.sensmap = self.sensmap.unsqueeze(-4)

    def forward(self, y):
        """
        Adjoint coil operator (coil combination).

        Parameters
        ----------
        y : np.ndarray | torch.Tensor
            Input images of shape ``(nsets, nslices, ncoils, ..., ny, nx)``
            (2D MRI / 3D Cartesian MRI) or ``(nsets, ncoils, ..., nz, ny, nx)``
            (3D NonCartesian MRI) modulated by coil sensitivity profiles.

        Returns
        -------
        x : np.ndarray | torch.Tensor
            Output combined images of shape ``(nslices, ..., ny, nx)``
            (2D MRI / 3D Cartesian MRI) or ``(..., nz, ny, nx)``
            (3D NonCartesian MRI). A NumPy array is returned when ``y``
            was a NumPy array.

        Raises
        ------
        ValueError
            If ``ndim`` is neither ``2`` nor ``3``.

        """
        # remember the input container so the output can mirror it
        isnumpy = isinstance(y, np.ndarray)

        # convert to tensor
        y = torch.as_tensor(y)

        # transfer to device (the maps follow the input)
        self.sensmap = self.sensmap.to(y.device)

        # locate the coil axis; fail loudly instead of leaving ``x`` unbound
        if self.ndim == 2:
            coil_axis = -4 if self.multicontrast else -3
        elif self.ndim == 3:
            coil_axis = -5 if self.multicontrast else -4
        else:
            raise ValueError(
                f"SenseAdjointOp supports ndim 2 or 3, got {self.ndim}"
            )

        # apply sensitivity
        tmp = self.sensmap.conj() * y

        # combine (over channels, then over sets on the leading axis)
        x = tmp.sum(axis=coil_axis).sum(axis=0)

        # convert back to numpy if required
        if isnumpy:
            x = x.numpy(force=True)

        return x

    def _adjoint_linop(self):
        """
        Build the adjoint (coil expansion) operator.

        The singleton axis added in ``__init__`` is removed first, because
        the ``SenseOp`` constructor inserts it again.

        Returns
        -------
        SenseOp
            Forward operator sharing these sensitivity maps.

        """
        # Mirror the truthiness test used in ``__init__`` so that ``sensmap``
        # is always bound (a falsy-but-not-``False`` flag such as ``0`` or
        # ``np.bool_(False)`` used to fall through every branch).
        if self.multicontrast and self.ndim == 2:
            sensmap = self.sensmap.squeeze(-3)
        elif self.multicontrast and self.ndim == 3:
            sensmap = self.sensmap.squeeze(-4)
        else:
            # no axis was inserted in ``__init__``: nothing to undo
            sensmap = self.sensmap

        return SenseOp(self.ndim, sensmap, self.device, self.multicontrast)