Source code for deepmr.optim.admm

"""Alternate Direction of Multipliers Method iteration."""

__all__ = ["admm_solve", "ADMMStep"]

import numpy as np
import torch

import torch.nn as nn

from .cg import cg_solve

from .. import linops as _linops


@torch.no_grad()
def admm_solve(
    input, step, AHA, D, niter=10, device=None, dc_niter=10, dc_tol=1e-4, dc_ndim=None
):
    """
    Solve inverse problem using the Alternating Direction Method of Multipliers.

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Signal to be reconstructed. Assume it is the adjoint AH of measurement
        operator A applied to the measured data y (i.e., input = AHy).
    step : float
        ADMM step size. It weights the current estimate in the inner data
        consistency problem and is also passed as its ``lamda`` shift.
    AHA : Callable | torch.Tensor | np.ndarray
        Normal operator AHA = AH * A.
    D : Callable | Iterable(Callable)
        Signal denoiser(s) for plug-n-play restoration.
    niter : int, optional
        Number of iterations. The default is ``10``.
        With ``niter=0`` the zero initial estimate is returned.
    device : str, optional
        Computational device.
        The default is ``None`` (infer from input).
    dc_niter : int, optional
        Number of iterations of inner data consistency step.
        The default is ``10``.
    dc_tol : float, optional
        Stopping condition for inner data consistency step.
        The default is ``1e-4``.
    dc_ndim : int, optional
        Number of spatial dimensions of the problem for inner data consistency step.
        It is used to infer the batch axes. If ``AHA`` is a ``deepmr.linop.Linop``
        operator, this is inferred from ``AHA.ndim`` and ``dc_ndim`` is ignored.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Reconstructed signal, of the same array type as ``input`` and placed
        back on the device ``input`` originally lived on.

    """
    # cast numpy input to torch if required
    isnumpy = isinstance(input, np.ndarray)
    if isnumpy:
        input = torch.as_tensor(input)

    # keep original device so the result can be moved back at the end
    idevice = input.device
    if device is None:
        device = idevice

    # put on computational device
    input = input.to(device)
    if isinstance(AHA, _linops.Linop):
        AHA = AHA.to(device)
    elif not callable(AHA):
        # dense normal operator: match dtype and device of the input
        AHA = torch.as_tensor(AHA, dtype=input.dtype, device=device)

    # input is AH(y), i.e., adjoint of measurement operator applied on
    # measured data; keep an independent copy as the data term
    AHy = input.clone()

    # initialize algorithm
    ADMM = ADMMStep(step, AHA, AHy, D, niter=dc_niter, tol=dc_tol, ndim=dc_ndim)

    # start from a zero estimate; zeros_like (rather than 0 * input) does not
    # propagate NaN / Inf entries of the input into the initial guess, and
    # binding `output` here keeps niter=0 well defined
    output = torch.zeros_like(input)

    # run algorithm: each step returns a freshly allocated consensus estimate
    # and does not modify its argument, so no defensive copy is needed
    for _ in range(niter):
        output = ADMM(output)

    # back to original device
    output = output.to(idevice)

    # cast back to numpy if required
    if isnumpy:
        output = output.numpy(force=True)

    return output


class ADMMStep(nn.Module):
    """
    Alternating Direction Method of Multipliers (ADMM) step.

    This represents propagation through a single iteration of an ADMM
    algorithm; can be used to build unrolled architectures.

    The problem is handled in consensus form: one local estimate for the
    data consistency term plus one per denoiser. Each call to ``forward``
    updates every local estimate, averages them into the consensus estimate
    and updates the scaled dual variables.

    Attributes
    ----------
    step : float | torch.Tensor
        ADMM step size. It weights the current estimate in the data consistency
        right-hand side and is passed as ``lamda`` to the inner CG solve.
    AHA : Callable | torch.Tensor
        Normal operator AHA = AH * A.
    AHy : torch.Tensor
        Adjoint AH of measurement operator A applied to the measured data y.
    D : Callable | Iterable(Callable)
        Signal denoiser(s) for plug-n-play restoration. A single callable
        (including iterable modules such as ``nn.Sequential``) is treated as
        one denoiser; any other iterable (e.g. a list or ``nn.ModuleList``)
        is treated as a collection of denoisers.
    trainable : bool, optional
        If ``True``, ``step`` is wrapped in an ``nn.Parameter`` and becomes
        trainable, otherwise it is not. The default is ``False``.
    niter : int, optional
        Number of iterations of inner data consistency step.
        The default is ``10``.
    tol : float, optional
        Stopping condition for inner data consistency step.
        The default is ``1e-4``.
    ndim : int, optional
        Number of spatial dimensions of the problem for inner data consistency step.
        It is used to infer the batch axes. If ``AHA`` exposes an ``ndim``
        attribute (e.g. a ``deepmr.linop.Linop`` operator), this is inferred from
        ``AHA.ndim`` and ``ndim`` is ignored. For a dense ``AHA`` (tensor or
        array), whose ``ndim`` is only its array rank, an explicitly provided
        ``ndim`` takes precedence.

    """

    def __init__(
        self, step, AHA, AHy, D, trainable=False, niter=10, tol=1e-4, ndim=None
    ):
        super().__init__()

        if trainable:
            # nn.Parameter only wraps tensors: promote Python scalars first,
            # and integer values to the default float dtype, since only
            # floating point / complex tensors can require gradients
            step = torch.as_tensor(step)
            if not (step.is_floating_point() or step.is_complex()):
                step = step.to(torch.get_default_dtype())
            self.step = nn.Parameter(step)
        else:
            self.step = step

        # set up problem dims: an operator's `ndim` describes the problem,
        # while a dense matrix's `ndim` is only its rank, so an explicit
        # value wins in that case
        if isinstance(AHA, (torch.Tensor, np.ndarray)) and ndim is not None:
            self.ndim = ndim
        else:
            self.ndim = getattr(AHA, "ndim", ndim)

        # assign operators
        self.AHA = AHA
        self.AHy = AHy

        # assign denoisers: callable iterables such as nn.Sequential are one
        # denoiser and must not be split into their layers; nn.ModuleList is
        # callable (it is an nn.Module) but is a container of denoisers
        is_collection = hasattr(D, "__iter__") and (
            not callable(D) or isinstance(D, nn.ModuleList)
        )
        # NOTE(review): denoisers are kept in a plain list, so nn.Module
        # denoisers are not registered as submodules of this step
        self.D = list(D) if is_collection else [D]

        # local estimates: index 0 is the data consistency estimate,
        # index n + 1 the estimate produced by denoiser n
        self.xi = torch.zeros(
            [1 + len(self.D)] + list(AHy.shape),
            dtype=AHy.dtype,
            device=AHy.device,
        )
        # scaled dual variables, one per local estimate
        self.ui = torch.zeros_like(self.xi)

        # dc solver settings
        self.niter = niter
        self.tol = tol

    def forward(self, input):
        """
        Run one consensus ADMM iteration.

        Parameters
        ----------
        input : torch.Tensor
            Current consensus estimate.

        Returns
        -------
        output : torch.Tensor
            Updated consensus estimate (mean of the local estimates).

        """
        # data consistency step: x0 = (AHA + step * I)^-1 (AHy + step * (input - u0))
        self.xi[0] = cg_solve(
            self.AHy + self.step * (input - self.ui[0]),
            self.AHA,
            niter=self.niter,
            tol=self.tol,
            lamda=self.step,
            ndim=self.ndim,
        )

        # denoise using each regularizer
        for n, denoiser in enumerate(self.D):
            self.xi[n + 1] = denoiser(input - self.ui[n + 1])

        # average consensus
        output = self.xi.mean(dim=0)

        # scaled dual update
        self.ui += self.xi - output[None, ...]

        return output