Source code for deepmr.prox.tgv

"""Total generalized variation denoising prior."""

__all__ = ["TGVDenoiser", "tgv_denoise"]

import numpy as np
import torch
import torch.nn as nn


[docs]class TGVDenoiser(nn.Module): r""" Proximal operator of (2nd order) Total Generalised Variation operator. (see K. Bredies, K. Kunisch, and T. Pock, "Total generalized variation," SIAM J. Imaging Sci., 3(3), 492-526, 2010.) This algorithm converges to the unique image :math:`x` (and the auxiliary vector field :math:`r`) minimizing .. math:: \underset{x, r}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda_1 \|r\|_{1,2} + \lambda_2 \|J(Dx-r)\|_{1,F} where :math:`D` maps an image to its gradient field and :math:`J` maps a vector field to its Jacobian. For a large value of :math:`\lambda_2`, the TGV behaves like the TV. For a small value, it behaves like the :math:`\ell_1`-Frobenius norm of the Hessian. The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and Applications, vol. 158, no. 2, pp. 460-479, 2013. Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_. Attributes ---------- ndim : int Number of spatial dimensions, can be either ``2`` or ``3``. ths : float, optional Denoise threshold. The default is ``0.1``. trainable : bool, optional If ``True``, threshold value is trainable, otherwise it is not. The default is ``False``. device : str, optional Device on which the wavelet transform is computed. The default is ``None`` (infer from input). verbose : bool, optional Whether to print computation details or not. The default is ``False``. niter : int, optional, Maximum number of iterations. The default is ``1000``. crit : float, optional Convergence criterion. The default is 1e-5. x2 : torch.Tensor, optional Primary variable for warm restart. The default is ``None``. u2 : torch.Tensor, optional Dual variable for warm restart. The default is ``None``. r2 : torch.Tensor, optional Auxiliary variable for warm restart. The default is ``None``. Notes ----- The regularization term :math:`\|r\|_{1,2} + \|J(Dx-r)\|_{1,F}` is implicitly normalized by its Lipschitz constant, i.e. :math:`\sqrt{72}`, see e.g. K. Bredies et al., "Total generalized variation," SIAM J. Imaging Sci., 3(3), 492-526, 2010. """
[docs] def __init__( self, ndim, ths=0.1, trainable=False, device=None, verbose=False, niter=100, crit=1e-5, x2=None, u2=None, r2=None, ): super().__init__() if trainable: self.ths = nn.Parameter(ths) else: self.ths = ths self.denoiser = _TGVDenoiser( ndim=ndim, device=device, verbose=verbose, n_it_max=niter, crit=crit, x2=x2, u2=u2, r2=r2, ) self.denoiser.device = device
def forward(self, input): # get complex if torch.is_complex(input): iscomplex = True else: iscomplex = False # default device idevice = input.device if self.denoiser.device is None: device = idevice else: device = self.denoiser.device # get input shape ndim = self.denoiser.ndim ishape = input.shape # reshape for computation input = input.reshape(-1, *ishape[-ndim:]) if iscomplex: input = torch.stack((input.real, input.imag), axis=1) input = input.reshape(-1, *ishape[-ndim:]) # apply denoising output = self.denoiser(input[:, None, ...].to(device), self.ths).to( idevice ) # perform the denoising on the real-valued tensor # reshape back if iscomplex: output = ( output[::2, ...] + 1j * output[1::2, ...] ) # build the denoised complex data output = output.reshape(ishape) return output.to(idevice)
[docs]def tgv_denoise( input, ndim, ths=0.1, device=None, verbose=False, niter=100, crit=1e-5, x2=None, u2=None, ): r""" Apply Total Generalized Variation denoising. (see K. Bredies, K. Kunisch, and T. Pock, "Total generalized variation," SIAM J. Imaging Sci., 3(3), 492-526, 2010.) This algorithm converges to the unique image :math:`x` (and the auxiliary vector field :math:`r`) minimizing .. math:: \underset{x, r}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda_1 \|r\|_{1,2} + \lambda_2 \|J(Dx-r)\|_{1,F} where :math:`D` maps an image to its gradient field and :math:`J` maps a vector field to its Jacobian. For a large value of :math:`\lambda_2`, the TGV behaves like the TV. For a small value, it behaves like the :math:`\ell_1`-Frobenius norm of the Hessian. The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and Applications, vol. 158, no. 2, pp. 460-479, 2013. Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_. Arguments --------- input : np.ndarray | torch.Tensor Input image of shape (..., n_ndim, ..., n_0). ndim : int Number of spatial dimensions, can be either ``2`` or ``3``. ths : float, optional Denoise threshold. Default is ``0.1``. ndim : int Number of spatial dimensions, can be either ``2`` or ``3``. ths : float, optional Denoise threshold. The default is ``0.1``. trainable : bool, optional If ``True``, threshold value is trainable, otherwise it is not. The default is ``False``. device : str, optional Device on which the wavelet transform is computed. The default is ``None`` (infer from input). verbose : bool, optional Whether to print computation details or not. The default is ``False``. niter : int, optional, Maximum number of iterations. The default is ``1000``. crit : float, optional Convergence criterion. The default is 1e-5. x2 : torch.Tensor, optional Primary variable for warm restart. The default is ``None``. u2 : torch.Tensor, optional Dual variable for warm restart. The default is ``None``. r2 : torch.Tensor, optional Auxiliary variable for warm restart. The default is ``None``. Returns ------- output : np.ndarray | torch.Tensor Denoised image of shape (..., n_ndim, ..., n_0). """ # cast to numpy if required if isinstance(input, np.ndarray): isnumpy = True input = torch.as_tensor(input) else: isnumpy = False # initialize denoiser TV = TGVDenoiser(ndim, ths, False, device, verbose, niter, crit, x2, u2) output = TV(input) # cast back to numpy if requried if isnumpy: output = output.numpy(force=True) return output
# %% local utils class _TGVDenoiser(nn.Module): def __init__( self, ndim, device, verbose=False, n_it_max=1000, crit=1e-5, x2=None, u2=None, r2=None, ): super().__init__() self.device = device self.ndim = ndim if ndim == 2: self.nabla = self.nabla2 self.nabla_adjoint = self.nabla2_adjoint self.epsilon = self.epsilon2 self.epsilon_adjoint = self.epsilon2_adjoint elif ndim == 3: self.nabla = self.nabla3 self.nabla_adjoint = self.nabla3_adjoint self.epsilon = self.epsilon3 self.epsilon_adjoint = self.epsilon3_adjoint self.verbose = verbose self.n_it_max = n_it_max self.crit = crit self.restart = True self.tau = 0.01 # > 0 self.rho = 1.99 # in 1,2 self.sigma = 1 / self.tau / 72 self.x2 = x2 self.r2 = r2 self.u2 = u2 self.has_converged = False def prox_tau_fx(self, x, y): return (x + self.tau * y) / (1 + self.tau) def prox_tau_fr(self, r, lambda1): left = torch.sqrt(torch.sum(r**2, axis=-1)) / (self.tau * lambda1) tmp = r - r / ( torch.maximum( left, torch.tensor([1], device=left.device).type(left.dtype) ).unsqueeze(-1) ) return tmp def prox_sigma_g_conj(self, u, lambda2): return u / ( torch.maximum( torch.sqrt(torch.sum(u**2, axis=-1)) / lambda2, torch.tensor([1], device=u.device).type(u.dtype), ).unsqueeze(-1) ) def forward(self, y, ths=None): restart = ( True if (self.restart or self.x2 is None or self.x2.shape != y.shape) else False ) if restart: self.x2 = y.clone() self.r2 = torch.zeros((*self.x2.shape, 2), device=self.x2.device).type( self.x2.dtype ) self.u2 = torch.zeros((*self.x2.shape, 4), device=self.x2.device).type( self.x2.dtype ) self.restart = False if ths is not None: lambda1 = ths * 0.1 lambda2 = ths * 0.15 cy = (y**2).sum() / 2 primalcostlowerbound = 0 for _ in range(self.n_it_max): x_prev = self.x2.clone() tmp = self.tau * self.epsilon_adjoint(self.u2) x = self.prox_tau_fx(self.x2 - self.nabla_adjoint(tmp), y) r = self.prox_tau_fr(self.r2 + tmp, lambda1) u = self.prox_sigma_g_conj( self.u2 + self.sigma * self.epsilon(self.nabla(2 * x - self.x2) - (2 * r - self.r2)), lambda2, ) self.x2 = self.x2 + self.rho * (x - self.x2) self.r2 = self.r2 + self.rho * (r - self.r2) self.u2 = self.u2 + self.rho * (u - self.u2) rel_err = torch.linalg.norm( x_prev.flatten() - self.x2.flatten() ) / torch.linalg.norm(self.x2.flatten() + 1e-12) if _ > 1 and rel_err < self.crit: self.has_converged = True if self.verbose: print("TGV prox reached convergence") break if self.verbose and _ % 100 == 0: primalcost = ( torch.linalg.norm(x.flatten() - y.flatten()) ** 2 + lambda1 * torch.sum(torch.sqrt(torch.sum(r**2, axis=-1))) + lambda2 * torch.sum( torch.sqrt( torch.sum(self.epsilon(self.nabla(x) - r) ** 2, axis=-1) ) ) ) # dualcost = cy - ((y - nablaT(epsilonT(u))) ** 2).sum() / 2.0 tmp = torch.max( torch.sqrt(torch.sum(self.epsilon_adjoint(u) ** 2, axis=-1)) ) # to check feasibility: the value will be <= lambda1 only at convergence. Since u is not feasible, the dual cost is not reliable: the gap=primalcost-dualcost can be <0 and cannot be used as stopping criterion. u3 = u / torch.maximum( tmp / lambda1, torch.tensor([1], device=tmp.device).type(tmp.dtype) ) # u3 is a scaled version of u, which is feasible. so, its dual cost is a valid, but very rough lower bound of the primal cost. dualcost2 = ( cy - torch.sum((y - self.nabla_adjoint(self.epsilon_adjoint(u3))) ** 2) / 2.0 ) # we display the best value of dualcost2 computed so far. primalcostlowerbound = max(primalcostlowerbound, dualcost2.item()) if self.verbose: print( "Iter: ", _, " Primal cost: ", primalcost.item(), " Rel err:", rel_err, ) if _ == self.n_it_max - 1: if self.verbose: print( "The algorithm did not converge, stopped after " + str(_ + 1) + " iterations." ) return self.x2 @staticmethod def nabla2(x): r""" Applies the finite differences operator associated with tensors of the same shape as x. """ b, c, h, w = x.shape u = torch.zeros((b, c, h, w, 2), device=x.device).type(x.dtype) u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] - x[:, :, :-1] u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] + x[:, :, 1:] u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] - x[..., :-1] u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] + x[..., 1:] return u @staticmethod def nabla2_adjoint(x): r""" Applies the adjoint of the finite difference operator. """ b, c, h, w = x.shape[:-1] u = torch.zeros((b, c, h, w), device=x.device).type( x.dtype ) # note that we just reversed left and right sides of each line to obtain the transposed operator u[:, :, :-1] = u[:, :, :-1] - x[:, :, :-1, :, 0] u[:, :, 1:] = u[:, :, 1:] + x[:, :, :-1, :, 0] u[..., :-1] = u[..., :-1] - x[..., :-1, 1] u[..., 1:] = u[..., 1:] + x[..., :-1, 1] return u @staticmethod def nabla3(x): r""" Applies the finite differences operator associated with tensors of the same shape as x. """ b, c, d, h, w = x.shape u = torch.zeros((b, c, d, h, w, 3), device=x.device).type(x.dtype) u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] - x[:, :, :-1] u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] + x[:, :, 1:] u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] - x[:, :, :, :-1] u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] + x[:, :, :, 1:] u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] - x[:, :, :, :, :-1] u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] + x[:, :, :, :, 1:] return u @staticmethod def nabla3_adjoint(x): r""" Applies the adjoint of the finite difference operator. """ b, c, d, h, w = x.shape u = torch.zeros((b, c, d, h, w), device=x.device).type(x.dtype) u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] - x[:, :, :-1] u[:, :, 1:, :, 0] = u[:, :, 1:, :, 0] + x[:, :, :-1] u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] - x[:, :, :, :-1] u[:, :, :, 1:, 1] = u[:, :, :, 1:, 1] + x[:, :, :, :-1] u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] - x[:, :, :, :, :-1] u[:, :, :, :, 1:, 2] = u[:, :, :, :, 1:, 2] + x[:, :, :, :, :-1] return u @staticmethod def epsilon2(I): # Simplified r""" Applies the jacobian of a vector field. """ b, c, h, w, _ = I.shape G = torch.zeros((b, c, h, w, 4), device=I.device).type(I.dtype) G[:, :, 1:, :, 0] = G[:, :, 1:, :, 0] - I[:, :, :-1, :, 0] # xdy G[..., 0] = G[..., 0] + I[..., 0] G[..., 1:, 1] = G[..., 1:, 1] - I[..., :-1, 0] # xdx G[..., 1:, 1] = G[..., 1:, 1] + I[..., 1:, 0] G[..., 1:, 2] = G[..., 1:, 2] - I[..., :-1, 1] # xdx G[..., 2] = G[..., 2] + I[..., 1] G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] - I[:, :, :-1, :, 1] # xdy G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] + I[:, :, 1:, :, 1] return G @staticmethod def epsilon2_adjoint(G): r""" Applies the adjoint of the jacobian of a vector field. """ b, c, h, w, _ = G.shape I = torch.zeros((b, c, h, w, 2), device=G.device).type(G.dtype) I[:, :, :-1, :, 0] = I[:, :, :-1, :, 0] - G[:, :, 1:, :, 0] I[..., 0] = I[..., 0] + G[..., 0] I[..., :-1, 0] = I[..., :-1, 0] - G[..., 1:, 1] I[..., 1:, 0] = I[..., 1:, 0] + G[..., 1:, 1] I[..., :-1, 1] = I[..., :-1, 1] - G[..., 1:, 2] I[..., 1] = I[..., 1] + G[..., 2] I[:, :, :-1, :, 1] = I[:, :, :-1, :, 1] - G[:, :, :-1, :, 3] I[:, :, 1:, :, 1] = I[:, :, 1:, :, 1] + G[:, :, :-1, :, 3] return I @staticmethod def epsilon3(I): # Adapted for 3D matrices r""" Applies the jacobian of a vector field. """ b, c, d, h, w = I.shape G = torch.zeros((b, c, d, h, w, 6), device=I.device).type(I.dtype) G[:, :, :, 1:, :, 0] = G[:, :, :, 1:, :, 0] - I[:, :, :, :-1, :, 0] # xdy G[..., 0] = G[..., 0] + I[..., 0] G[..., 1:, :, 1] = G[..., 1:, :, 1] - I[..., :, :-1, 0] # xdx G[..., 1:, :, 1] = G[..., 1:, :, 1] + I[..., :, 1:, 0] G[..., 1:, :, 2] = G[..., 1:, :, 2] - I[..., :, :-1, 1] # xdz G[..., 2] = G[..., 2] + I[..., :, 1, 0] G[:, :, :, :-1, :, 3] = G[:, :, :, :-1, :, 3] - I[:, :, :, :-1, :, 1] # xdy G[:, :, :, :-1, :, 3] = G[:, :, :, :-1, :, 3] + I[:, :, :, 1:, :, 1] G[..., 3] = G[..., 3] + I[..., 0] G[..., 1:, :, 4] = G[..., 1:, :, 4] - I[..., 1:, :, :-1, 2] # xdz G[..., 4] = G[..., 4] + I[..., 1, :, :, 0] G[:, :, :, :, :-1, 5] = G[:, :, :, :, :-1, 5] - I[:, :, :, :, :-1, 2] # xdy G[:, :, :, 1:, :, 5] = G[:, :, :, 1:, :, 5] + I[:, :, :, :, :-1, 2] return G @staticmethod def epsilon3_adjoint(G): # Adapted for 3D matrices r""" Applies the adjoint of the jacobian of a vector field. """ b, c, d, h, w, _ = G.shape I = torch.zeros((b, c, d, h, w, 3), device=G.device).type(G.dtype) I[:, :, :, :-1, :, 0] = I[:, :, :, :-1, :, 0] - G[:, :, :, 1:, :, 0] I[..., 0] = I[..., 0] + G[..., 0] I[..., :-1, :, 0] = I[..., :-1, :, 0] - G[..., 1:, :, 1] I[..., 1:, :, 0] = I[..., 1:, :, 0] + G[..., 1:, :, 1] I[..., :-1, :, 1] = I[..., :-1, :, 1] - G[..., 1:, :, 2] I[..., 0] = I[..., 0] + G[..., 2] I[:, :, :, :-1, :, 1] = I[:, :, :, :-1, :, 1] - G[:, :, :, :-1, :, 3] I[:, :, :, 1:, :, 1] = I[:, :, :, 1:, :, 1] + G[:, :, :, :-1, :, 3] I[..., 1] = I[..., 1] + G[..., 3] I[..., :-1, :, 2] = I[..., :-1, :, 2] - G[..., 1:, :, 4] I[..., 0] = I[..., 0] + G[..., 4] I[:, :, :, :, :-1, 2] = I[:, :, :, :, :-1, 2] - G[:, :, :, :, :-1, 5] I[:, :, :, 1:, :, 2] = I[:, :, :, 1:, :, 2] + G[:, :, :, :, :-1, 5] return I