# Source code for deepmr.optim.pgd

"""Proximal Gradient Method iteration."""

# Public API of this module: the full solver and the single-iteration step
# (the latter can be stacked to build unrolled architectures).
__all__ = ["pgd_solve", "PGDStep"]

import numpy as np
import torch

import torch.nn as nn

from .. import linops as _linops


@torch.no_grad()
def pgd_solve(input, step, AHA, D, niter=10, accelerate=True, device=None, tol=None):
    """
    Solve inverse problem using Proximal Gradient Method.

    Runs ``niter`` iterations of ISTA (``accelerate=False``) or FISTA
    (``accelerate=True``) starting from a zero initial guess, using the
    plug-n-play denoiser ``D`` as proximal operator.

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Signal to be reconstructed. Assume it is the adjoint AH of measurement
        operator A applied to the measured data y (i.e., input = AHy).
    step : float
        Gradient step size; should be <= 1 / max(eig(AHA)).
    AHA : Callable | torch.Tensor | np.ndarray
        Normal operator AHA = AH * A. A dense (non-callable) operator is
        applied as the matrix product ``AHA @ x``.
    D : Callable
        Signal denoiser for plug-n-play restoration; called as ``D(z)`` on
        the output of each gradient step.
    niter : int, optional
        Number of iterations. The default is ``10``.
    accelerate : bool, optional
        Toggle Nesterov acceleration (``True``, i.e., FISTA) or
        not (``False``, ISTA). The default is ``True``.
    device : str, optional
        Computational device.
        The default is ``None`` (infer from input).
    tol : float, optional
        Stopping condition: iterations stop early once
        ``||x_k - x_{k-1}|| / step < tol``.
        The default is ``None`` (run until niter).

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Reconstructed signal, of the same container type (numpy / torch) and
        on the same device as ``input``. With ``niter == 0`` this is the
        zero initial guess.

    """
    # cast to torch if required (numpy in -> numpy out)
    isnumpy = isinstance(input, np.ndarray)
    if isnumpy:
        input = torch.as_tensor(input)

    # keep original device so the result can be moved back at the end
    idevice = input.device
    if device is None:
        device = idevice

    # put on device
    input = input.to(device)
    if isinstance(AHA, _linops.Linop):
        AHA = AHA.to(device)
    elif not callable(AHA):
        # PGDStep calls AHA(x), and a tensor is not callable: wrap the dense
        # operator as a matrix product.
        AHA_matrix = torch.as_tensor(AHA, dtype=input.dtype, device=device)

        def _apply_dense(x):
            return AHA_matrix @ x

        AHA = _apply_dense

    # assume input is AH(y), i.e., adjoint of measurement operator
    # applied on measured data
    AHy = input.clone()

    # Nesterov momentum weights (FISTA) or all zeros (plain ISTA)
    q = _get_acceleration(niter) if accelerate else [0.0] * niter

    # initialize algorithm; tol must be forwarded, otherwise
    # check_convergence can never trigger
    PGD = PGDStep(step, AHA, AHy, D, tol=tol)

    # zero initial guess; zeros_like (rather than 0 * input) keeps inf/nan
    # values of AHy out of the starting point
    input = torch.zeros_like(input)
    output = input  # returned as-is when niter == 0

    # run algorithm
    for q_n in q:
        output = PGD(input, q_n)

        # if required, compute residual and check if we reached convergence
        if PGD.check_convergence(output, input, step):
            break

        # update variable
        input = output.clone()

    # back to original device
    output = output.to(idevice)

    # cast back to numpy if required
    if isnumpy:
        output = output.numpy(force=True)

    return output


class PGDStep(nn.Module):
    """
    Proximal Gradient Method step.

    This represents propagation through a single iteration of a
    Proximal Gradient Descent algorithm; can be used to build
    unrolled architectures.

    Attributes
    ----------
    step : float | torch.Tensor
        Gradient step size; should be <= 1 / max(eig(AHA)).
    AHA : Callable
        Normal operator AHA = AH * A; called as ``AHA(x)``.
    AHy : torch.Tensor
        Adjoint AH of measurement operator A applied to the measured data y.
    D : Callable
        Signal denoiser for plug-n-play restoration.
    trainable : bool, optional
        If ``True``, gradient update step is trainable, otherwise it is not.
        The default is ``False``.
    tol : float, optional
        Stopping condition.
        The default is ``None`` (run until niter).

    """

    def __init__(self, step, AHA, AHy, D, trainable=False, tol=None):
        super().__init__()
        if trainable:
            # nn.Parameter needs a floating-point (or complex) tensor: a plain
            # Python float/int as documented would otherwise raise.
            step = torch.as_tensor(step)
            if not (step.is_floating_point() or step.is_complex()):
                step = step.to(torch.get_default_dtype())
            self.step = nn.Parameter(step)
        else:
            self.step = step

        # assign
        self.AHA = AHA
        self.AHy = AHy
        self.D = D
        # previous denoised iterate s_{k-1} for the momentum update; seeded
        # with AHy and overwritten on every forward call
        self.s = AHy.clone()
        self.tol = tol

    def forward(self, input, q=0.0):
        """
        Run a single ISTA / FISTA iteration.

        Parameters
        ----------
        input : torch.Tensor
            Current iterate x_{k-1}.
        q : float, optional
            Momentum weight (t_{k-1} - 1) / t_k. ``0.0`` disables the
            extrapolation (plain ISTA step). The default is ``0.0``.

        Returns
        -------
        output : torch.Tensor
            Updated iterate x_k.

        """
        # gradient step: z_k = x_{k-1} - gamma * AH(A(x_{k-1}) - y)
        z = input - self.step * (self.AHA(input) - self.AHy)

        # denoise (proximal step): s_k = D(z_k)
        s = self.D(z)

        # momentum: x_k = s_k + q_k * (s_k - s_{k-1}); q_k == 0 -> ISTA
        if q != 0.0:
            output = s + q * (s - self.s)
        else:
            output = s

        # always remember s_k: FISTA's first weight is exactly 0, and skipping
        # the update there would make the next step extrapolate from AHy
        # instead of from s_1
        self.s = s.clone()

        return output

    def check_convergence(self, output, input, step):
        """
        Return ``True`` when ``||output - input|| / step < tol``.

        Always ``False`` when no tolerance was set (``tol is None``).
        """
        if self.tol is None:
            return False
        resid = torch.linalg.norm(output - input).item() / step
        return bool(resid < self.tol)
# %% local utils def _get_acceleration(niter): t = [] t_new = 1 for n in range(niter): t_old = t_new t_new = (1 + (1 + 4 * t_old**2) ** 0.5) / 2 t.append((t_old - 1) / t_new) return t