Source code for deepmr.optim.cg

"""Conjugate Gradient iteration."""

__all__ = ["cg_solve", "CGStep"]

import numpy as np
import torch

import torch.nn as nn

from .. import linops as _linops


@torch.no_grad()
def cg_solve(
    input,
    AHA,
    niter=10,
    device=None,
    tol=1e-4,
    lamda=0.0,
    ndim=None,
):
    """
    Solve inverse problem using Conjugate Gradient method.

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Signal to be reconstructed. Assumed to be the adjoint AH of the
        measurement operator A applied to the measured data y (i.e., input = AHy).
    AHA : Callable | torch.Tensor | np.ndarray
        Normal operator AHA = AH * A.
    niter : int, optional
        Number of iterations. The default is ``10``.
    device : str, optional
        Computational device.
        The default is ``None`` (infer from input).
    tol : float, optional
        Stopping tolerance on the residual norm. The default is ``1e-4``.
    lamda : float, optional
        Tikhonov regularization strength. The default is ``0.0``.
    ndim : int, optional
        Number of spatial dimensions of the problem.
        It is used to infer the batch axes. If ``AHA`` is a ``deepmr.linops.Linop``
        operator, this is inferred from ``AHA.ndim`` and ``ndim`` is ignored.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Reconstructed signal.
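
    Examples
    --------
    A minimal sketch with a toy scaled-identity normal operator (the operator
    below is illustrative only and not part of DeepMR):

    >>> import torch
    >>> AHA = lambda x: 2.0 * x  # toy normal operator AHA = 2 * I
    >>> AHy = torch.ones(4, 4)
    >>> xhat = cg_solve(AHy, AHA, niter=10, ndim=2)
    >>> torch.allclose(xhat, 0.5 * torch.ones(4, 4))
    True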

    """
    # cast NumPy input to a torch tensor if required
    if isinstance(input, np.ndarray):
        isnumpy = True
        input = torch.as_tensor(input)
    else:
        isnumpy = False

    # keep original device
    idevice = input.device
    if device is None:
        device = idevice

    # put on device
    input = input.to(device)
    if isinstance(AHA, _linops.Linop):
        AHA = AHA.to(device)
    elif not callable(AHA):
        AHA = torch.as_tensor(AHA, dtype=input.dtype, device=device)

    # assume input is AH(y), i.e., adjoint of measurement operator
    # applied on measured data
    AHy = input.clone()

    # add Tikhonov regularization
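    # i.e., solve (AHA + lamda * I) x = AHy instead of AHA x = AHy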
    if lamda != 0.0:
        if isinstance(AHA, _linops.Linop):
            _AHA = AHA + lamda * _linops.Identity(AHA.ndim)
        elif callable(AHA):
            _AHA = lambda x: AHA(x) + lamda * x
        else:
            _AHA = lambda x: AHA @ x + lamda * x
    else:
        _AHA = AHA

    # initialize algorithm
    CG = CGStep(_AHA, AHy, ndim, tol)

    # initialize solution estimate to zero
    input = 0 * input

    # run algorithm
    for n in range(niter):
        output = CG(input)
        if CG.check_convergence():
            break
        input = output.clone()

    # back to original device
    output = output.to(idevice)

    # cast back to numpy if required
    if isnumpy:
        output = output.numpy(force=True)

    return output


class CGStep(nn.Module):
    """
    Conjugate Gradient method step.

    This represents propagation through a single iteration of a CG
    algorithm; it can be used to build unrolled architectures.

    Attributes
    ----------
    AHA : Callable | torch.Tensor
        Normal operator AHA = AH * A.
    AHy : torch.Tensor
        Adjoint AH of measurement operator A applied to the measured data y.
    ndim : int
        Number of spatial dimensions of the problem.
        It is used to infer the batch axes. If ``AHA`` is a ``deepmr.linops.Linop``
        operator, this is inferred from ``AHA.ndim`` and ``ndim`` is ignored.
    tol : float, optional
        Stopping condition. The default is ``None`` (run until ``niter``).
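
    Examples
    --------
    A minimal sketch of a hand-rolled CG loop with a toy scaled-identity
    normal operator (the operator below is illustrative only and not part
    of DeepMR):

    >>> import torch
    >>> AHA = lambda x: 2.0 * x  # toy normal operator AHA = 2 * I
    >>> AHy = torch.ones(4, 4)
    >>> step = CGStep(AHA, AHy, ndim=2, tol=1e-4)
    >>> x = torch.zeros(4, 4)
    >>> for _ in range(5):
    ...     x = step(x)
    ...     if step.check_convergence():
    ...         break
    >>> torch.allclose(x, 0.5 * torch.ones(4, 4))
    True

    """
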
    def __init__(self, AHA, AHy, ndim=None, tol=None):
        super().__init__()

        # set up problem dims
        try:
            self.ndim = AHA.ndim
        except Exception:
            self.ndim = ndim

        # assign operators
        self.AHA = AHA
        self.AHy = AHy

        # preallocate
        self.r = self.AHy.clone()
        self.p = self.r
        self.rsold = self.dot(self.r, self.r)
        self.rsnew = None
        self.tol = tol
    def dot(self, s1, s2):
        # inner product reduced over the spatial axes, keeping batch axes
        dot = s1.conj() * s2
        dot = dot.reshape(*s1.shape[: -self.ndim], -1).sum(axis=-1)
        return dot

    def forward(self, input):
        # compute step size along the current search direction
        AHAp = self.AHA(self.p)
        alpha = self.rsold / self.dot(self.p, AHAp)

        # update solution estimate and residual
        output = input + self.p * alpha
        self.r = self.r + AHAp * (-alpha)

        # update search direction
        self.rsnew = torch.real(self.dot(self.r, self.r))
        self.p = self.r + self.p * (self.rsnew / self.rsold)
        self.rsold = self.rsnew

        return output

    def check_convergence(self):
        if self.tol is not None:
            if self.rsnew.sqrt() < self.tol:
                return True
            else:
                return False
        else:
            return False