"""Total variation denoising prior."""
__all__ = ["TVDenoiser", "tv_denoise"]
import numpy as np
import torch
import torch.nn as nn


class TVDenoiser(nn.Module):
    r"""
    Proximal operator of the isotropic Total Variation operator.

    This algorithm converges to the unique image :math:`x` that is the solution of

    .. math::

        \underset{x}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \gamma \|Dx\|_{1,2},

    where :math:`D` maps an image to its gradient field.

    The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting
    method for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization
    Theory and Applications, vol. 158, no. 2, pp. 460-479, 2013).

    Code (and description) adapted from ``deepinv``, in turn adapted from
    Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
    Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.

    This algorithm is implemented with warm restart, i.e., the primal and dual variables are kept in memory
    between calls to the forward method. This speeds up the computation when this class is used inside an
    iterative algorithm (see the Examples section below).

    Attributes
    ----------
    ndim : int
        Number of spatial dimensions, can be either ``2`` or ``3``.
    ths : float, optional
        Denoising threshold. The default is ``0.1``.
    trainable : bool, optional
        If ``True``, the threshold value is trainable; otherwise it is not.
        The default is ``False``.
    device : str, optional
        Device on which the computation is performed.
        The default is ``None`` (infer from input).
    verbose : bool, optional
        Whether to print computation details or not. The default is ``False``.
    niter : int, optional
        Maximum number of iterations. The default is ``100``.
    crit : float, optional
        Convergence criterion. The default is ``1e-5``.
    x2 : torch.Tensor, optional
        Primal variable for warm restart. The default is ``None``.
    u2 : torch.Tensor, optional
        Dual variable for warm restart. The default is ``None``.

    Notes
    -----
    The regularization term :math:`\|Dx\|_{1,2}` is implicitly normalized by its Lipschitz constant, i.e.,
    :math:`\sqrt{8}`, see e.g. A. Beck and M. Teboulle, "Fast gradient-based algorithms for constrained total
    variation image denoising and deblurring problems", IEEE Trans. on Image Processing, 18(11), 2419-2434, 2009.
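
    Examples
    --------
    A minimal usage sketch (shapes and values are illustrative only):

    >>> import torch
    >>> denoiser = TVDenoiser(ndim=2, ths=0.1)
    >>> noisy = torch.randn(1, 32, 32)
    >>> denoised = denoiser(noisy)
    >>> denoised.shape
    torch.Size([1, 32, 32])

    Subsequent calls on inputs of the same shape re-use the stored primal and
    dual variables (warm restart).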
"""
    def __init__(
        self,
        ndim,
        ths=0.1,
        trainable=False,
        device=None,
        verbose=False,
        niter=100,
        crit=1e-5,
        x2=None,
        u2=None,
    ):
        super().__init__()
        if trainable:
            # nn.Parameter expects a tensor, not a bare float
            self.ths = nn.Parameter(torch.as_tensor(ths))
        else:
            self.ths = ths

        self.denoiser = _TVDenoiser(
            ndim=ndim,
            device=device,
            verbose=verbose,
            n_it_max=niter,
            crit=crit,
            x2=x2,
            u2=u2,
        )
        self.denoiser.device = device

    def forward(self, input):
        # detect complex input
        iscomplex = torch.is_complex(input)

        # default device
        idevice = input.device
        if self.denoiser.device is None:
            device = idevice
        else:
            device = self.denoiser.device

        # get input shape
        ndim = self.denoiser.ndim
        ishape = input.shape

        # reshape for computation
        input = input.reshape(-1, *ishape[-ndim:])
        if iscomplex:
            # interleave real and imaginary parts along the batch axis
            input = torch.stack((input.real, input.imag), dim=1)
            input = input.reshape(-1, *ishape[-ndim:])

        # perform the denoising on the real-valued tensor
        output = self.denoiser(input[:, None, ...].to(device), self.ths).to(idevice)

        # reshape back
        if iscomplex:
            # rebuild the denoised complex data
            output = output[::2, ...] + 1j * output[1::2, ...]
        output = output.reshape(ishape)

        return output.to(idevice)


def tv_denoise(
    input,
    ndim,
    ths=0.1,
    device=None,
    verbose=False,
    niter=100,
    crit=1e-5,
    x2=None,
    u2=None,
):
r"""
Apply isotropic Total Variation denoising.
This algorithm converges to the unique image :math:`x` that is the solution of
.. math::
\underset{x}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \gamma \|Dx\|_{1,2},
where :math:`D` maps an image to its gradient field.
The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method
for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and
Applications, vol. 158, no. 2, pp. 460-479, 2013.
Code (and description) adapted from ``deepinv``, in turn adapted from
Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.
This algorithm is implemented with warm restart, i.e. the primary and dual variables are kept in memory
between calls to the forward method. This speeds up the computation when using this class in an iterative algorithm.
Arguments
---------
input : np.ndarray | torch.Tensor
Input image of shape (..., n_ndim, ..., n_0).
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
ths : float, optional
Denoise threshold. The default is``0.1``.
device : str, optional
Device on which the wavelet transform is computed.
The default is ``None`` (infer from input).
verbose : bool, optional
Whether to print computation details or not. The default is ``False``.
niter : int, optional,
Maximum number of iterations. The default is ``1000``.
crit : float, optional
Convergence criterion. The default is 1e-5.
x2 : torch.Tensor, optional
Primary variable for warm restart. The default is ``None``.
u2 : torch.Tensor, optional
Dual variable for warm restart. The default is ``None``.
Notes
-----
The regularization term :math:`\|Dx\|_{1,2}` is implicitly normalized by its Lipschitz constant, i.e.
:math:`\sqrt{8}`, see e.g. A. Beck and M. Teboulle, "Fast gradient-based algorithms for constrained total
variation image denoising and deblurring problems", IEEE T. on Image Processing. 18(11), 2419-2434, 2009.
Returns
-------
output : np.ndarray | torch.Tensor
Denoised image of shape (..., n_ndim, ..., n_0).
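
    Examples
    --------
    A minimal usage sketch (values are illustrative only):

    >>> import numpy as np
    >>> noisy = np.random.randn(32, 32)
    >>> denoised = tv_denoise(noisy, ndim=2, ths=0.1)
    >>> denoised.shape
    (32, 32)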
"""
    # cast to torch if required
    if isinstance(input, np.ndarray):
        isnumpy = True
        input = torch.as_tensor(input)
    else:
        isnumpy = False

    # initialize denoiser
    TV = TVDenoiser(ndim, ths, False, device, verbose, niter, crit, x2, u2)
    output = TV(input)

    # cast back to numpy if required
    if isnumpy:
        output = output.numpy(force=True)

    return output


# %% local utils
class _TVDenoiser(nn.Module):
    def __init__(
        self,
        ndim,
        device=None,
        verbose=False,
        n_it_max=1000,
        crit=1e-5,
        x2=None,
        u2=None,
    ):
        super().__init__()
        self.device = device
        self.ndim = ndim
        if ndim == 2:
            self.nabla = self.nabla2
            self.nabla_adjoint = self.nabla2_adjoint
        elif ndim == 3:
            self.nabla = self.nabla3
            self.nabla_adjoint = self.nabla3_adjoint
        else:
            raise ValueError(f"ndim must be either 2 or 3, got {ndim}")
        self.verbose = verbose
        self.n_it_max = n_it_max
        self.crit = crit
        # only restart from scratch if no warm-start variables were provided
        self.restart = x2 is None or u2 is None

        self.tau = 0.01  # > 0
        self.rho = 1.99  # in (1, 2)
        self.sigma = 1 / self.tau / 8

        self.x2 = x2
        self.u2 = u2
        self.has_converged = False

    def prox_tau_fx(self, x, y):
        # proximal step on the data-fidelity term 0.5 * ||x - y||_2^2
        return (x + self.tau * y) / (1 + self.tau)

    def prox_sigma_g_conj(self, u, lambda2):
        # proximal operator of the conjugate of the l_{1,2} norm, i.e. the
        # projection of each gradient vector onto the l2 ball of radius lambda2
        return u / (
            torch.maximum(
                torch.sqrt(torch.sum(u**2, dim=-1)) / lambda2,
                torch.tensor([1], device=u.device).type(u.dtype),
            ).unsqueeze(-1)
        )

    def forward(self, y, ths=None):
        restart = self.restart or self.x2 is None or self.x2.shape != y.shape

        if restart:
            self.x2 = y.clone()
            # one dual component per spatial dimension (2 in 2D, 3 in 3D)
            self.u2 = torch.zeros(
                (*self.x2.shape, self.ndim), device=self.x2.device
            ).type(self.x2.dtype)
            self.restart = False
        # fall back to a unit threshold when none is given (assumed default;
        # the public wrappers always pass one)
        lambd = ths if ths is not None else 1.0

        for it in range(self.n_it_max):
            x_prev = self.x2.clone()

            x = self.prox_tau_fx(self.x2 - self.tau * self.nabla_adjoint(self.u2), y)
            u = self.prox_sigma_g_conj(
                self.u2 + self.sigma * self.nabla(2 * x - self.x2), lambd
            )
            self.x2 = self.x2 + self.rho * (x - self.x2)
            self.u2 = self.u2 + self.rho * (u - self.u2)

            rel_err = torch.linalg.norm(
                x_prev.flatten() - self.x2.flatten()
            ) / (torch.linalg.norm(self.x2.flatten()) + 1e-12)

            if it > 1 and rel_err < self.crit:
                self.has_converged = True
                if self.verbose:
                    print("TV prox reached convergence")
                break

        return self.x2

    @staticmethod
    def nabla2(x):
        r"""
        Apply the 2D finite-difference (forward gradient) operator to x.
        """
        b, c, h, w = x.shape
        u = torch.zeros((b, c, h, w, 2), device=x.device).type(x.dtype)
        u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] - x[:, :, :-1]
        u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] + x[:, :, 1:]
        u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] - x[..., :-1]
        u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] + x[..., 1:]
        return u

    @staticmethod
    def nabla2_adjoint(x):
        r"""
        Apply the adjoint of the 2D finite-difference operator.
        """
        b, c, h, w = x.shape[:-1]
        # left- and right-hand sides of each update in nabla2 are swapped to
        # obtain the transposed operator; the identity <Dx, u> = <x, D^T u>
        # is exercised by the self-check at the bottom of this module
        u = torch.zeros((b, c, h, w), device=x.device).type(x.dtype)
        u[:, :, :-1] = u[:, :, :-1] - x[:, :, :-1, :, 0]
        u[:, :, 1:] = u[:, :, 1:] + x[:, :, :-1, :, 0]
        u[..., :-1] = u[..., :-1] - x[..., :-1, 1]
        u[..., 1:] = u[..., 1:] + x[..., :-1, 1]
        return u

    @staticmethod
    def nabla3(x):
        r"""
        Apply the 3D finite-difference (forward gradient) operator to x.
        """
        b, c, d, h, w = x.shape
        u = torch.zeros((b, c, d, h, w, 3), device=x.device).type(x.dtype)
        u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] - x[:, :, :-1]
        u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] + x[:, :, 1:]
        u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] - x[:, :, :, :-1]
        u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] + x[:, :, :, 1:]
        u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] - x[:, :, :, :, :-1]
        u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] + x[:, :, :, :, 1:]
        return u

    @staticmethod
    def nabla3_adjoint(x):
        r"""
        Apply the adjoint of the 3D finite-difference operator.
        """
        b, c, d, h, w = x.shape[:-1]
        u = torch.zeros((b, c, d, h, w), device=x.device).type(x.dtype)
        u[:, :, :-1] = u[:, :, :-1] - x[:, :, :-1, :, :, 0]
        u[:, :, 1:] = u[:, :, 1:] + x[:, :, :-1, :, :, 0]
        u[:, :, :, :-1] = u[:, :, :, :-1] - x[:, :, :, :-1, :, 1]
        u[:, :, :, 1:] = u[:, :, :, 1:] + x[:, :, :, :-1, :, 1]
        u[:, :, :, :, :-1] = u[:, :, :, :, :-1] - x[:, :, :, :, :-1, 2]
        u[:, :, :, :, 1:] = u[:, :, :, :, 1:] + x[:, :, :, :, :-1, 2]
        return u
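

# %% quick self-check
# A minimal sketch, not part of the public API: the discrete gradient and its
# adjoint must satisfy <Dx, u> = <x, D^T u>. The loop below checks this
# identity on random tensors for both the 2D and 3D operator pairs.
if __name__ == "__main__":
    torch.manual_seed(0)
    for nabla, nabla_adj, shape in [
        (_TVDenoiser.nabla2, _TVDenoiser.nabla2_adjoint, (1, 1, 8, 8)),
        (_TVDenoiser.nabla3, _TVDenoiser.nabla3_adjoint, (1, 1, 4, 8, 8)),
    ]:
        x = torch.randn(shape)
        u = torch.randn(*shape, len(shape) - 2)
        lhs = (nabla(x) * u).sum()  # <Dx, u>
        rhs = (x * nabla_adj(u)).sum()  # <x, D^T u>
        assert torch.allclose(lhs, rhs, atol=1e-4), (lhs, rhs)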