"""Sensitivity coil linear operator."""
__all__ = ["SenseOp", "SenseAdjointOp"]
import numpy as np
import torch
from . import base
class SenseOp(base.Linop):
"""
Multiply input image by coil sensitivity profile.
Coil sensitivity profiles are expected to have the following dimensions:
* 2D MRI: ``(nsets, nslices, ncoils, ny, nx)``
* 3D Cartesian MRI: ``(nsets, nx, ncoils, nz, ny)``
* 3D NonCartesian MRI: ``(nsets, ncoils, nz, ny, nx)``
where ``nsets`` represents multiple sets of coil sensitivity estimation
for soft-SENSE implementations (e.g., ESPIRIT), equal to ``1`` for conventional SENSE
and ``ncoils`` represents the number of receiver channels in the coil array.
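
    Examples
    --------
    A minimal usage sketch with a synthetic 2D sensitivity map (shapes follow
    the convention above; the random map is purely illustrative and
    ``multicontrast=False`` is assumed):

    >>> import torch
    >>> smaps = torch.rand(1, 2, 8, 32, 32)  # (nsets, nslices, ncoils, ny, nx)
    >>> op = SenseOp(ndim=2, sensmap=smaps, multicontrast=False)
    >>> x = torch.rand(2, 32, 32)  # (nslices, ny, nx)
    >>> op.forward(x).shape  # (nsets, nslices, ncoils, ny, nx)
    torch.Size([1, 2, 8, 32, 32])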
"""
    def __init__(self, ndim, sensmap, device=None, multicontrast=True):
super().__init__(ndim)
# cast map to tensor
self.sensmap = torch.as_tensor(sensmap)
# assign device
if device is None:
self.device = self.sensmap.device
else:
self.device = device
        # offload to device
self.sensmap = self.sensmap.to(self.device)
        # multicontrast: if inputs carry an extra contrast axis, add a
        # singleton axis to the maps so they broadcast against it
        self.multicontrast = multicontrast
        if self.multicontrast and self.ndim == 2:
            self.sensmap = self.sensmap.unsqueeze(-3)
        if self.multicontrast and self.ndim == 3:
            self.sensmap = self.sensmap.unsqueeze(-4)
def forward(self, x):
"""
Forward coil operator.
Parameters
----------
x : np.ndarray | torch.Tensor
Input combined images of shape ``(nslices, ..., ny, nx)``.
(2D MRI / 3D Cartesian MRI) or ``(..., nz, ny, nx)`` (3D NonCartesian MRI).
Returns
-------
y : np.ndarray | torch.Tensor
Output images of shape ``(nsets, nslices, ncoils, ..., ny, nx)``.
(2D MRI / 3D Cartesian) or ``(nsets, ncoils, ..., nz, ny, nx)`` (3D NonCartesian MRI)
modulated by coil sensitivity profiles.
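
        Examples
        --------
        A shape-only sketch with a synthetic 2D map and ``multicontrast=True``
        (the default), so the input carries an extra contrast axis right before
        the spatial dimensions (values are purely illustrative):

        >>> import torch
        >>> smaps = torch.rand(1, 2, 8, 32, 32)  # (nsets, nslices, ncoils, ny, nx)
        >>> op = SenseOp(ndim=2, sensmap=smaps)
        >>> x = torch.rand(2, 3, 32, 32)  # (nslices, ncontrasts, ny, nx)
        >>> op.forward(x).shape  # (nsets, nslices, ncoils, ncontrasts, ny, nx)
        torch.Size([1, 2, 8, 3, 32, 32])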
"""
        # keep track of the input type and convert to tensor
        isnumpy = isinstance(x, np.ndarray)
        x = torch.as_tensor(x)
# transfer to device
self.sensmap = self.sensmap.to(x.device)
        # insert a singleton coil axis in the input so it broadcasts against the maps
if self.multicontrast:
if self.ndim == 2:
x = x.unsqueeze(-4)
elif self.ndim == 3:
x = x.unsqueeze(-5)
else:
if self.ndim == 2:
x = x.unsqueeze(-3)
elif self.ndim == 3:
x = x.unsqueeze(-4)
# project
y = self.sensmap * x
# convert back to numpy if required
if isnumpy:
y = y.numpy(force=True)
return y
def _adjoint_linop(self):
        # undo the contrast-broadcasting axis before building the adjoint
        if self.multicontrast and self.ndim == 2:
            sensmap = self.sensmap.squeeze(-3)
        elif self.multicontrast and self.ndim == 3:
            sensmap = self.sensmap.squeeze(-4)
        else:
            sensmap = self.sensmap
return SenseAdjointOp(self.ndim, sensmap, self.device, self.multicontrast)
class SenseAdjointOp(base.Linop):
"""
Perform coil combination.
Coil sensitivity profiles are expected to have the following dimensions:
* 2D MRI: ``(nslices, nsets, ncoils, ny, nx)``
* 3D MRI: ``(nsets, ncoils, nz, ny, nx)``
where ``nsets`` represents multiple sets of coil sensitivity estimation
for soft-SENSE implementations (e.g., ESPIRIT), equal to ``1`` for conventional SENSE
and ``ncoils`` represents the number of receiver channels in the coil array.
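
    Examples
    --------
    A minimal coil-combination sketch with a synthetic 2D map and
    ``multicontrast=False`` (shapes follow the convention above; the random
    values are purely illustrative):

    >>> import torch
    >>> smaps = torch.rand(1, 2, 8, 32, 32)  # (nsets, nslices, ncoils, ny, nx)
    >>> adj = SenseAdjointOp(ndim=2, sensmap=smaps, multicontrast=False)
    >>> y = torch.rand(1, 2, 8, 32, 32)  # (nsets, nslices, ncoils, ny, nx)
    >>> adj.forward(y).shape  # combined over coils and sets
    torch.Size([2, 32, 32])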
"""
    def __init__(self, ndim, sensmap, device=None, multicontrast=True):
super().__init__(ndim)
# cast map to tensor
self.sensmap = torch.as_tensor(sensmap)
# assign device
if device is None:
self.device = self.sensmap.device
else:
self.device = device
        # offload to device
self.sensmap = self.sensmap.to(self.device)
        # multicontrast: if inputs carry an extra contrast axis, add a
        # singleton axis to the maps so they broadcast against it
        self.multicontrast = multicontrast
        if self.multicontrast and self.ndim == 2:
            self.sensmap = self.sensmap.unsqueeze(-3)
        if self.multicontrast and self.ndim == 3:
            self.sensmap = self.sensmap.unsqueeze(-4)
def forward(self, y):
"""
Adjoint coil operator (coil combination).
Parameters
----------
y : np.ndarray | torch.Tensor
Output images of shape ``(nsets, nslices, ncoils, ..., ny, nx)``.
(2D MRI / 3D Cartesian MRI) or ``(nsets, ncoils, ..., nz, ny, nx)``
(3D NonCartesian MRI) modulated by coil sensitivity profiles.
Returns
-------
x : np.ndarray | torch.Tensor
Output combined images of shape ``(nslices, ..., ny, nx)``.
(2D MRI / 3D Cartesian MRI) or ``(..., nz, ny, nx)`` (3D NonCartesian MRI).
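
        Examples
        --------
        A sketch of the adjoint relationship with ``SenseOp``: for real-valued
        test data, ``<A x, y>`` should match ``<x, A^H y>`` (synthetic shapes,
        ``multicontrast=False`` assumed):

        >>> import torch
        >>> smaps = torch.rand(1, 2, 8, 32, 32)
        >>> A = SenseOp(ndim=2, sensmap=smaps, multicontrast=False)
        >>> AH = SenseAdjointOp(ndim=2, sensmap=smaps, multicontrast=False)
        >>> x = torch.rand(2, 32, 32)
        >>> y = torch.rand(1, 2, 8, 32, 32)
        >>> lhs = (A.forward(x) * y).sum()
        >>> rhs = (x * AH.forward(y)).sum()
        >>> torch.allclose(lhs, rhs)
        True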
"""
        # keep track of the input type and convert to tensor
        isnumpy = isinstance(y, np.ndarray)
        y = torch.as_tensor(y)
# transfer to device
self.sensmap = self.sensmap.to(y.device)
# apply sensitivity
tmp = self.sensmap.conj() * y
        # combine over channels (coil axis) and sets (leading axis)
        if self.multicontrast:
            if self.ndim == 2:
                x = tmp.sum(dim=-4).sum(dim=0)
            elif self.ndim == 3:
                x = tmp.sum(dim=-5).sum(dim=0)
        else:
            if self.ndim == 2:
                x = tmp.sum(dim=-3).sum(dim=0)
            elif self.ndim == 3:
                x = tmp.sum(dim=-4).sum(dim=0)
# convert back to numpy if required
if isnumpy:
x = x.numpy(force=True)
return x
def _adjoint_linop(self):
        # undo the contrast-broadcasting axis before building the adjoint
        if self.multicontrast and self.ndim == 2:
            sensmap = self.sensmap.squeeze(-3)
        elif self.multicontrast and self.ndim == 3:
            sensmap = self.sensmap.squeeze(-4)
        else:
            sensmap = self.sensmap
return SenseOp(self.ndim, sensmap, self.device, self.multicontrast)