"""Classical iterative reconstruction wrapper."""
__all__ = ["recon_lstsq"]
import numpy as np
import torch
from ... import optim as _optim
from ... import prox as _prox
from .. import calib as _calib
from ... import linops as _linops
from . import linop as _linop
from numba.core.errors import NumbaPerformanceWarning
import warnings
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
def recon_lstsq(
    data,
    head,
    mask=None,
    niter=1,
    prior=None,
    prior_ths=0.01,
    prior_params=None,
    solver_params=None,
    lamda=0.0,
    stepsize=1.0,
    basis=None,
    nsets=1,
    device=None,
    cal_data=None,
    toeplitz=True,
    use_dcf=True,
):
    """
    Classical MR reconstruction.

    Parameters
    ----------
    data : np.ndarray | torch.Tensor
        Input k-space data of shape ``(nslices, ncoils, ncontrasts, nviews, nsamples)``.
    head : deepmr.Header
        DeepMR acquisition header, containing ``traj``, ``shape`` and ``dcf``.
    mask : np.ndarray | torch.Tensor, optional
        Sampling mask for Cartesian imaging.
        Expected shape is ``(ncontrasts, nviews, nsamples)``.
        The default is ``None``.
    niter : int, optional
        Number of recon iterations. If single iteration,
        perform simple zero-filled recon. The default is ``1``.
    prior : str | list | tuple, optional
        Prior(s) for image regularization. Each string must be one of the following:

        * ``"L1Wave"``: L1 Wavelet regularization.
        * ``"TV"``: Total Variation regularization.
        * ``"LLR"``: LLR regularization (``LLRDenoiser``).

        A single prior is solved with accelerated proximal gradient descent;
        a list/tuple of priors is solved with ADMM. Non-string prior objects
        are not implemented yet and raise ``NotImplementedError``.
        The default is ``None`` (no regularizer, Conjugate Gradient solve).
    prior_ths : float, optional
        Currently unused: the denoiser threshold is taken from ``lamda``.
        Kept for backward compatibility. The default is ``0.01``.
    prior_params : dict | list[dict], optional
        Parameters for Prior initializations (one dict per prior when
        ``prior`` is a list/tuple). See :func:`deepmr.prox`.
        The default is ``None`` (use each regularizer default parameters).
    solver_params : dict, optional
        Parameters for Solver initializations.
        See :func:`deepmr.optim`.
        The default is ``None`` (use each solver default parameters).
    lamda : float, optional
        Regularization strength. It is passed to the CG solver when no prior
        is given; otherwise it is added as ``lamda * I`` to the normal
        operator and used as the denoiser threshold.
        If 0.0, do not apply regularization. The default is ``0.0``.
    stepsize : float, optional
        Step size scaling for the prior-based solvers. The actual step is
        ``stepsize / max_eig``, where ``max_eig`` is the largest eigenvalue of
        the (regularized) normal operator, estimated by power iteration.
        ``None`` is treated as ``1.0``. The default is ``1.0``.
    basis : np.ndarray | torch.Tensor, optional
        Low rank subspace basis of shape ``(ncontrasts, ncoeffs)``. The default is ``None``.
    nsets : int, optional
        Number of coil sensitivity sets of maps. The default is ``1``.
    device : str, optional
        Computational device. The default is ``None`` (same as ``data``).
    cal_data : np.ndarray | torch.Tensor, optional
        Calibration dataset for coil sensitivity estimation.
        The default is ``None`` (use center region of ``data``).
    toeplitz : bool, optional
        Use Toeplitz approach for normal equation. Forced to ``False``
        when ``niter == 1``. The default is ``True``.
    use_dcf : bool, optional
        Use dcf (if ``head.dcf`` is available) to accelerate convergence.
        The default is ``True``.

    Returns
    -------
    img : np.ndarray | torch.Tensor
        Reconstructed image (NumPy array if ``data`` was a NumPy array,
        tensor otherwise), of shape:

        * 2D Cartesian: ``(nslices, ncontrasts, ny, nx)``.
        * 2D Non Cartesian: ``(nslices, ncontrasts, ny, nx)``.
        * 3D Non Cartesian: ``(ncontrasts, nz, ny, nx)``.

    Raises
    ------
    ValueError
        If ``prior`` is a list/tuple and ``prior_params`` is not a list/tuple
        of the same length, or if a prior string is not recognized.
    NotImplementedError
        If a prior is given as a non-string object.
    """
    # remember the input container so the output is returned in the same format
    isnumpy = isinstance(data, np.ndarray)
    if isnumpy:
        data = torch.as_tensor(data)

    def _finalize(output):
        # convert back to NumPy when the caller passed a NumPy array
        if isnumpy:
            return output.numpy(force=True)
        return output

    if device is None:
        device = data.device
    data = data.to(device)

    if use_dcf and head.dcf is not None:
        dcf = head.dcf.to(device)
    else:
        dcf = None

    # toggle off Toeplitz for non-iterative recon
    if niter == 1:
        toeplitz = False

    # get number of spatial dimensions
    if head.traj is not None:
        ndim = head.traj.shape[-1]
    else:
        ndim = 2  # assume 3D data already decoupled along readout

    # build encoding operator and its normal operator
    E, EHE = _linop.EncodingOp(
        data,
        mask,
        head.traj,
        dcf,
        head.shape,
        nsets,
        basis,
        device,
        cal_data,
        toeplitz,
    )
    E = E.to(device)
    EHE = EHE.to(device)

    # zero-filled (adjoint) reconstruction; data[:, None] adds a singleton axis 1
    if dcf is not None:
        img = E.H(dcf**0.5 * data[:, None, ...])
    else:
        img = E.H(data[:, None, ...])

    # if non-iterative, just return the linear recon
    if niter == 1:
        return _finalize(img)

    if solver_params is None:
        solver_params = {}

    # rescale before iterating
    img = _calib.intensity_scaling(img, ndim=ndim)

    # if no prior is specified, use CG recon
    if prior is None:
        output = _optim.cg_solve(
            img, EHE, niter=niter, lamda=lamda, ndim=ndim, **solver_params
        )
        return _finalize(output)

    # add l2 regularization to the normal operator
    if lamda != 0.0:
        _EHE = EHE + lamda * _linops.Identity(ndim)
    else:
        _EHE = EHE

    # normalize step size by the spectral norm of the normal operator
    if stepsize is None:
        stepsize = 1.0
    xhat = torch.rand(img.shape, dtype=img.dtype, device=img.device)
    max_eig = _optim.power_method(None, xhat, AHA=_EHE, device=device, niter=30)
    if max_eig != 0.0:
        stepsize = stepsize / max_eig

    if not isinstance(prior, (list, tuple)):
        # single prior: accelerated proximal gradient descent
        if prior_params is None:
            prior_params = {}
        D = _get_prior(prior, ndim, lamda, device, **prior_params)
        output = _optim.pgd_solve(
            img, stepsize, _EHE, D, niter=niter, accelerate=True, **solver_params
        )
    else:
        # multiple priors: ADMM
        npriors = len(prior)
        if prior_params is None:
            prior_params = [{} for _ in range(npriors)]
        elif not (
            isinstance(prior_params, (list, tuple)) and len(prior_params) == npriors
        ):
            # explicit raise instead of assert: asserts are stripped under -O
            raise ValueError(
                "Please provide parameters for each regularizer (or leave completely empty to use default)"
            )
        D = [
            _get_prior(p, ndim, lamda, device, **params)
            for p, params in zip(prior, prior_params)
        ]
        output = _optim.admm_solve(img, stepsize, _EHE, D, niter=niter, **solver_params)

    return _finalize(output)
# %% local utils
def _get_prior(ptype, ndim, lamda, device, **params):
    """
    Build a denoiser (proximal operator) from its string identifier.

    Parameters
    ----------
    ptype : str
        Prior identifier: ``"L1Wave"``, ``"TV"`` or ``"LLR"``.
    ndim : int
        Number of spatial dimensions, forwarded to the denoiser.
    lamda : float
        Threshold, forwarded to the denoiser as ``ths``.
    device : str | torch.device
        Computational device, forwarded to the denoiser.
    **params
        Extra keyword arguments forwarded to the denoiser constructor.

    Returns
    -------
    The constructed denoiser (``WaveletDenoiser``, ``TVDenoiser`` or
    ``LLRDenoiser`` from the prox module).

    Raises
    ------
    NotImplementedError
        If ``ptype`` is not a string (prior objects are not supported yet).
    ValueError
        If ``ptype`` is an unrecognized string.
    """
    # guard clause: only string identifiers are supported for now
    if not isinstance(ptype, str):
        raise NotImplementedError("Direct prior object not implemented.")
    if ptype == "L1Wave":
        return _prox.WaveletDenoiser(ndim, ths=lamda, device=device, **params)
    if ptype == "TV":
        return _prox.TVDenoiser(ndim, ths=lamda, device=device, **params)
    if ptype == "LLR":
        return _prox.LLRDenoiser(ndim, ths=lamda, device=device, **params)
    # message no longer advertises 'nn.Module' objects, which raise NotImplementedError above
    raise ValueError(
        f"Prior type = {ptype} not recognized; either specify 'L1Wave', 'TV' or 'LLR'."
    )