"""Wavelet denoising."""
__all__ = [
"WaveletDenoiser",
"wavelet_denoise",
"WaveletDictDenoiser",
"wavelet_dict_denoise",
]
import numpy as np
import torch
import torch.nn as nn
import ptwt
import pywt
from . import threshold
class WaveletDenoiser(nn.Module):
r"""
    Orthogonal Wavelet denoising with the :math:`\ell_1` norm.

    Adapted from :class:`deepinv.models.WaveletDenoiser`
    to support complex-valued inputs.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \; \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an orthonormal wavelet transform, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``) or
    the :math:`\ell_0` norm (``non_linearity="hard"``). A variant of the :math:`\ell_0` norm is also available
    (``non_linearity="topk"``), where the thresholding is done by keeping the :math:`k` largest coefficients
    in each wavelet subband and setting the others to zero.

    The solution is available in closed form, so the denoiser is cheap to compute.

Notes
-----
Following common practice in signal processing, only detail coefficients are regularized, and the approximation
coefficients are left untouched.

    Warning
    -------
    For 3D data, the computational complexity of the wavelet transform grows cubically with the size of the support.
    For large 3D data, it is recommended to use wavelets with a small support (e.g., ``db1`` to ``db4``).

Attributes
----------
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
ths : float, optional
Denoise threshold. The default is ``0.1``.
trainable : bool, optional
If ``True``, threshold value is trainable, otherwise it is not.
The default is ``False``.
wv : str, optional
Wavelet name to choose among those available in `pywt <https://pywavelets.readthedocs.io/en/latest/>`_.
The default is ``"db4"``.
device : str, optional
Device on which the wavelet transform is computed.
The default is ``None`` (infer from input).
non_linearity : str, optional
``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
The default is ``"soft"``.
    level : int, optional
        Level of the wavelet transform. The default is ``None``.
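
    Examples
    --------
    A minimal, illustrative sketch assuming a 2D complex-valued image, the default ``db4``
    wavelet, and soft thresholding:

    >>> import torch
    >>> denoiser = WaveletDenoiser(ndim=2, ths=0.1)
    >>> image = torch.randn(4, 32, 32, dtype=torch.complex64)  # (batch, ny, nx)
    >>> denoised = denoiser(image)  # same shape and dtype as the input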
"""
    def __init__(
self,
ndim,
ths=0.1,
trainable=False,
wv="db4",
device=None,
non_linearity="soft",
level=None,
*args,
**kwargs
):
super().__init__()
        if trainable:
            self.ths = nn.Parameter(torch.as_tensor(ths))
        else:
            self.ths = ths
self.denoiser = _WaveletDenoiser(
level=level,
wv=wv,
device=device,
non_linearity=non_linearity,
wvdim=ndim,
*args,
**kwargs
)
self.denoiser.device = device
def forward(self, input):
        # detect complex-valued input
        iscomplex = torch.is_complex(input)
# default device
idevice = input.device
if self.denoiser.device is None:
device = idevice
else:
device = self.denoiser.device
# get input shape
ndim = self.denoiser.dimension
ishape = input.shape
# reshape for computation
input = input.reshape(-1, *ishape[-ndim:])
if iscomplex:
input = torch.stack((input.real, input.imag), axis=1)
input = input.reshape(-1, *ishape[-ndim:])
# apply denoising
output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
idevice
) # perform the denoising on the real-valued tensor
# reshape back
if iscomplex:
output = (
output[::2, ...] + 1j * output[1::2, ...]
) # build the denoised complex data
output = output.reshape(ishape)
return output.to(idevice)
def wavelet_denoise(
input, ndim, ths, wv="db4", device=None, non_linearity="soft", level=None
):
r"""
    Apply orthogonal Wavelet denoising with the :math:`\ell_1` norm.

    Adapted from :class:`deepinv.models.WaveletDenoiser`
    to support complex-valued inputs.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \; \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an orthonormal wavelet transform, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``) or
    the :math:`\ell_0` norm (``non_linearity="hard"``). A variant of the :math:`\ell_0` norm is also available
    (``non_linearity="topk"``), where the thresholding is done by keeping the :math:`k` largest coefficients
    in each wavelet subband and setting the others to zero.

    The solution is available in closed form, so the denoiser is cheap to compute.

Arguments
---------
input : np.ndarray | torch.Tensor
Input image of shape (..., n_ndim, ..., n_0).
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
ths : float
Denoise threshold.
wv : str, optional
Wavelet name to choose among those available in `pywt <https://pywavelets.readthedocs.io/en/latest/>`_.
The default is ``"db4"``.
device : str, optional
Device on which the wavelet transform is computed.
The default is ``None`` (infer from input).
non_linearity : str, optional
``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
The default is ``"soft"``.
    level : int, optional
        Level of the wavelet transform. The default is ``None``.
Returns
-------
output : np.ndarray | torch.Tensor
Denoised image of shape (..., n_ndim, ..., n_0).
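
    Examples
    --------
    A minimal, illustrative sketch assuming a real-valued 2D NumPy input:

    >>> import numpy as np
    >>> image = np.random.randn(32, 32).astype(np.float32)
    >>> denoised = wavelet_denoise(image, ndim=2, ths=0.1)  # returns a np.ndarray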
"""
# cast to numpy if required
if isinstance(input, np.ndarray):
isnumpy = True
input = torch.as_tensor(input)
else:
isnumpy = False
# initialize denoiser
W = WaveletDenoiser(ndim, ths, False, wv, device, non_linearity, level)
output = W(input)
    # cast back to numpy if required
if isnumpy:
output = output.numpy(force=True)
return output
class WaveletDictDenoiser(nn.Module):
r"""
Overcomplete Wavelet denoising with the :math:`\ell_1` norm.
    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \; \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or more wavelets, i.e.,
    :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``),
    the :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the :math:`\ell_0` norm
    (``non_linearity="topk"``) where only the top-k coefficients are kept; see
    :class:`deepinv.models.WaveletDenoiser` for more details.

    The solution is not available in closed form, so the denoiser runs an optimization algorithm for each test image.

Attributes
----------
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
ths : float, optional
Denoise threshold. The default is ``0.1``.
trainable : bool, optional
If ``True``, threshold value is trainable, otherwise it is not.
The default is ``False``.
wv : Iterable[str], optional
        List of mother wavelets. The names of the wavelets can be found `here <https://wavelets.pybytes.com/>`_.
The default is ``["db8", "db4"]``.
device : str, optional
Device on which the wavelet transform is computed.
The default is ``None`` (infer from input).
non_linearity : str, optional
``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
The default is ``"soft"``.
    level : int, optional
        Level of the wavelet transform. The default is ``None``.
max_iter : int, optional
Number of iterations of the optimization algorithm.
The default is ``10``.
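
    Examples
    --------
    A minimal, illustrative sketch assuming a 2D real-valued input and the default
    ``["db8", "db4"]`` wavelet dictionary:

    >>> import torch
    >>> denoiser = WaveletDictDenoiser(ndim=2, ths=0.1, max_iter=5)
    >>> image = torch.randn(1, 32, 32)
    >>> denoised = denoiser(image)  # same shape as the input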
"""
    def __init__(
self,
ndim,
ths=0.1,
trainable=False,
wv=None,
device=None,
non_linearity="soft",
level=None,
max_iter=10,
*args,
**kwargs
):
super().__init__()
        if trainable:
            self.ths = nn.Parameter(torch.as_tensor(ths))
        else:
            self.ths = ths
        if wv is None:
            wv = ["db8", "db4"]
        self.denoiser = _WaveletDictDenoiser(
            level=level,
            list_wv=wv,
            device=device,
            max_iter=max_iter,
            non_linearity=non_linearity,
            wvdim=ndim,
            *args,
            **kwargs
        )
        self.denoiser.device = device
def forward(self, input):
        # detect complex-valued input
        iscomplex = torch.is_complex(input)
# default device
idevice = input.device
if self.denoiser.device is None:
device = idevice
else:
device = self.denoiser.device
# get input shape
ndim = self.denoiser.dimension
ishape = input.shape
# reshape for computation
input = input.reshape(-1, *ishape[-ndim:])
if iscomplex:
input = torch.stack((input.real, input.imag), axis=1)
input = input.reshape(-1, *ishape[-ndim:])
# apply denoising
output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
idevice
) # perform the denoising on the real-valued tensor
# reshape back
if iscomplex:
output = (
output[::2, ...] + 1j * output[1::2, ...]
) # build the denoised complex data
output = output.reshape(ishape)
return output.to(idevice)
def wavelet_dict_denoise(
input,
ndim,
ths,
wv=None,
device=None,
non_linearity="soft",
level=None,
max_iter=10,
):
r"""
Apply overcomplete Wavelet denoising with the :math:`\ell_1` norm.
    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \; \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or more wavelets, i.e.,
    :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``),
    the :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the :math:`\ell_0` norm
    (``non_linearity="topk"``) where only the top-k coefficients are kept; see
    :class:`deepinv.models.WaveletDenoiser` for more details.

    The solution is not available in closed form, so the denoiser runs an optimization algorithm for each test image.

    Arguments
    ---------
input : np.ndarray | torch.Tensor
Input image of shape (..., n_ndim, ..., n_0).
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
ths : float
Denoise threshold.
wv : Iterable[str], optional
        List of mother wavelets. The names of the wavelets can be found `here <https://wavelets.pybytes.com/>`_.
The default is ``["db8", "db4"]``.
device : str, optional
Device on which the wavelet transform is computed.
The default is ``None`` (infer from input).
non_linearity : str, optional
``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
The default is ``"soft"``.
    level : int, optional
        Level of the wavelet transform. The default is ``None``.
max_iter : int, optional
Number of iterations of the optimization algorithm.
The default is ``10``.
Returns
-------
output : np.ndarray | torch.Tensor
Denoised image of shape (..., n_ndim, ..., n_0).
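
    Examples
    --------
    A minimal, illustrative sketch assuming a real-valued 2D NumPy input and the default
    wavelet dictionary:

    >>> import numpy as np
    >>> image = np.random.randn(32, 32).astype(np.float32)
    >>> denoised = wavelet_dict_denoise(image, ndim=2, ths=0.1, max_iter=5)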
"""
# cast to numpy if required
if isinstance(input, np.ndarray):
isnumpy = True
input = torch.as_tensor(input)
else:
isnumpy = False
# initialize denoiser
WD = WaveletDictDenoiser(
ndim, ths, False, wv, device, non_linearity, level, max_iter
)
output = WD(input)
    # cast back to numpy if required
if isnumpy:
output = output.numpy(force=True)
return output
# %% local utils
class _WaveletDenoiser(nn.Module):
def __init__(
self, level=None, wv="db4", device="cpu", non_linearity="soft", wvdim=2
):
super().__init__()
self.level = level
self.wv = wv
self.device = device
self.non_linearity = non_linearity
self.dimension = wvdim
def dwt(self, x):
r"""
Applies the wavelet analysis.
"""
if self.level is None:
level = pywt.dwtn_max_level(x.shape[-self.dimension :], self.wv)
self.level = level
else:
level = self.level
if self.dimension == 2:
dec = ptwt.wavedec2(x, pywt.Wavelet(self.wv), mode="zero", level=level)
elif self.dimension == 3:
dec = ptwt.wavedec3(x, pywt.Wavelet(self.wv), mode="zero", level=level)
dec = [list(t) if isinstance(t, tuple) else t for t in dec]
return dec
def flatten_coeffs(self, dec):
r"""
Flattens the wavelet coefficients and returns them in a single torch vector of shape (n_coeffs,).
"""
if self.dimension == 2:
flat = torch.hstack(
[dec[0].flatten()]
+ [decl.flatten() for l in range(1, len(dec)) for decl in dec[l]]
)
elif self.dimension == 3:
flat = torch.hstack(
[dec[0].flatten()]
+ [dec[l][key].flatten() for l in range(1, len(dec)) for key in dec[l]]
)
return flat
@staticmethod
def psi(x, wavelet="db4", level=None, dimension=2):
r"""
Returns a flattened list containing the wavelet coefficients.
"""
if level is None:
level = pywt.dwtn_max_level(x.shape[-dimension:], wavelet)
if dimension == 2:
dec = ptwt.wavedec2(x, pywt.Wavelet(wavelet), mode="zero", level=level)
dec = [list(t) if isinstance(t, tuple) else t for t in dec]
vec = [decl.flatten() for l in range(1, len(dec)) for decl in dec[l]]
elif dimension == 3:
dec = ptwt.wavedec3(x, pywt.Wavelet(wavelet), mode="zero", level=level)
dec = [list(t) if isinstance(t, tuple) else t for t in dec]
vec = [dec[l][key].flatten() for l in range(1, len(dec)) for key in dec[l]]
return vec
def iwt(self, coeffs):
r"""
Applies the wavelet synthesis.
"""
coeffs = [tuple(t) if isinstance(t, list) else t for t in coeffs]
if self.dimension == 2:
rec = ptwt.waverec2(coeffs, pywt.Wavelet(self.wv))
elif self.dimension == 3:
rec = ptwt.waverec3(coeffs, pywt.Wavelet(self.wv))
return rec
def prox_l1(self, x, ths):
r"""
Soft thresholding of the wavelet coefficients.
Arguments
---------
x : torch.Tensor
Wavelet coefficients.
ths : float, optional
Threshold. It can be element-wise, in which case
it is assumed to be broadcastable with ``input``.
The default is ``0.1``.
Returns
-------
torch.Tensor
Thresholded wavelet coefficients.
"""
return threshold.soft_thresh(x, ths)
def prox_l0(self, x, ths):
r"""
Hard thresholding of the wavelet coefficients.
Arguments
---------
x : torch.Tensor
Wavelet coefficients.
ths : float, optional
Threshold. It can be element-wise, in which case
it is assumed to be broadcastable with ``input``.
The default is ``0.1``.
Returns
-------
torch.Tensor
Thresholded wavelet coefficients.
"""
if isinstance(ths, float):
ths_map = ths
else:
ths_map = ths.repeat(
1, 1, 1, x.shape[-2], x.shape[-1]
) # Reshaping to image wavelet shape
out = x.clone()
out[abs(out) < ths_map] = 0
return out
def hard_threshold_topk(self, x, ths):
r"""
Hard thresholding of the wavelet coefficients by keeping only the top-k coefficients and setting the others to 0.
Arguments
---------
x : torch.Tensor
Wavelet coefficients.
ths : float | int, optional
Top k coefficients to keep. If ``float``, it is interpreted as a proportion of the total
number of coefficients. If ``int``, it is interpreted as the number of coefficients to keep.
            The default is ``0.1``.
Returns
-------
torch.Tensor
Thresholded wavelet coefficients.
"""
if isinstance(ths, float):
k = int(ths * x.shape[-3] * x.shape[-2] * x.shape[-1])
else:
k = int(ths)
# Reshape arrays to 2D and initialize output to 0
x_flat = x.reshape(x.shape[0], -1)
out = torch.zeros_like(x_flat)
topk_indices_flat = torch.topk(abs(x_flat), k, dim=-1)[1]
# Convert the flattened indices to the original indices of x
batch_indices = (
torch.arange(x.shape[0], device=x.device).unsqueeze(1).repeat(1, k)
)
topk_indices = torch.stack([batch_indices, topk_indices_flat], dim=-1)
# Set output's top-k elements to values from original x
out[tuple(topk_indices.view(-1, 2).t())] = x_flat[
tuple(topk_indices.view(-1, 2).t())
]
return torch.reshape(out, x.shape)
    def threshold_func(self, x, ths):
        r"""
        Apply thresholding to the wavelet coefficients according to ``self.non_linearity``.
        """
if self.non_linearity == "soft":
y = self.prox_l1(x, ths)
elif self.non_linearity == "hard":
y = self.prox_l0(x, ths)
elif self.non_linearity == "topk":
y = self.hard_threshold_topk(x, ths)
return y
    def threshold_2D(self, coeffs, ths):
r"""
Thresholds coefficients of the 2D wavelet transform.
"""
for level in range(1, self.level + 1):
ths_cur = self.reshape_ths(ths, level)
for c in range(3):
                coeffs[level][c] = self.threshold_func(coeffs[level][c], ths_cur[c])
return coeffs
def threshold_3D(self, coeffs, ths):
r"""
Thresholds coefficients of the 3D wavelet transform.
"""
for level in range(1, self.level + 1):
ths_cur = self.reshape_ths(ths, level)
for c, key in enumerate(["aad", "ada", "daa", "add", "dad", "dda", "ddd"]):
                coeffs[level][key] = self.threshold_func(coeffs[level][key], ths_cur[c])
return coeffs
def threshold_ND(self, coeffs, ths):
r"""
Apply thresholding to the wavelet coefficients of arbitrary dimension.
"""
if self.dimension == 2:
            coeffs = self.threshold_2D(coeffs, ths)
elif self.dimension == 3:
coeffs = self.threshold_3D(coeffs, ths)
else:
raise ValueError("Only 2D and 3D wavelet transforms are supported")
return coeffs
def pad_input(self, x):
r"""
Pad the input to make it compatible with the wavelet transform.
"""
        if self.dimension == 2:
            h, w = x.size()[-2:]
            padding_bottom = h % 2
            padding_right = w % 2
            p = (padding_bottom, padding_right)
            # ReplicationPad2d expects (left, right, top, bottom)
            x = torch.nn.ReplicationPad2d((0, p[1], 0, p[0]))(x)
        elif self.dimension == 3:
            d, h, w = x.size()[-3:]
            padding_depth = d % 2
            padding_bottom = h % 2
            padding_right = w % 2
            p = (padding_depth, padding_bottom, padding_right)
            # ReplicationPad3d expects (left, right, top, bottom, front, back)
            x = torch.nn.ReplicationPad3d((0, p[2], 0, p[1], 0, p[0]))(x)
return x, p
def crop_output(self, x, padding):
r"""
Crop the output to make it compatible with the wavelet transform.
"""
d, h, w = x.size()[-3:]
if len(padding) == 2:
out = x[..., : h - padding[0], : w - padding[1]]
elif len(padding) == 3:
out = x[..., : d - padding[0], : h - padding[1], : w - padding[2]]
return out
def reshape_ths(self, ths, level):
r"""
Reshape the thresholding parameter in the appropriate format, i.e. either:
- a list of 3 elements, or
- a tensor of 3 elements.
Since the approximation coefficients are not thresholded, we do not need to provide a thresholding parameter,
ths has shape (n_levels-1, 3).
"""
numel = 3 if self.dimension == 2 else 7
if not torch.is_tensor(ths):
if isinstance(ths, int) or isinstance(ths, float):
ths_cur = [ths] * numel
elif len(ths) == 1:
ths_cur = [ths[0]] * numel
else:
ths_cur = ths[level]
if len(ths_cur) == 1:
ths_cur = [ths_cur[0]] * numel
else:
if len(ths.shape) == 1: # Needs to reshape to shape (n_levels-1, 3)
ths_cur = ths.squeeze().repeat(numel)
else:
ths_cur = ths[level - 2]
return ths_cur
def forward(self, x, ths):
# Pad data
x, padding = self.pad_input(x)
# Apply wavelet transform
coeffs = self.dwt(x)
# Threshold coefficients (we do not threshold the approximation coefficients)
coeffs = self.threshold_ND(coeffs, ths)
# Inverse wavelet transform
y = self.iwt(coeffs)
# Crop data
y = self.crop_output(y, padding)
return y
class _WaveletDictDenoiser(nn.Module):
r"""
Overcomplete Wavelet denoising with the :math:`\ell_1` norm.
    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \; \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or more wavelets, i.e.,
    :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``),
    the :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the :math:`\ell_0` norm
    (``non_linearity="topk"``) where only the top-k coefficients are kept; see
    :class:`deepinv.models.WaveletDenoiser` for more details.

    The solution is not available in closed form, so the denoiser runs an optimization algorithm for each test image.

    :param int level: decomposition level of the wavelet transform.
    :param list[str] list_wv: list of mother wavelets. The names of the wavelets can be found
        `here <https://wavelets.pybytes.com/>`_ (default: ``["db8", "db4"]``).
    :param str device: cpu or gpu.
    :param int max_iter: number of iterations of the optimization algorithm (default: 10).
    :param str non_linearity: ``"soft"``, ``"hard"`` or ``"topk"`` thresholding (default: ``"soft"``).
    :param int wvdim: number of spatial dimensions, either 2 or 3 (default: 2).
    """
    def __init__(
        self,
        level=None,
        list_wv=["db8", "db4"],
        device=None,
        max_iter=10,
        non_linearity="soft",
        wvdim=2,
    ):
        super().__init__()
        self.level = level
        self.device = device
        self.dimension = wvdim
        self.list_prox = nn.ModuleList(
            [
                _WaveletDenoiser(
                    level=level, wv=wv, device=device, non_linearity=non_linearity, wvdim=wvdim
                )
                for wv in list_wv
            ]
        )
        self.max_iter = max_iter
def forward(self, y, ths):
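        # Averaged proximal iterations: each wavelet transform keeps its own auxiliary variable
        # z_p, the denoised estimate p_p of each auxiliary variable is computed, the estimates
        # are averaged into x, the auxiliary variables are updated with the residual x - p_p,
        # and the loop stops early once the relative change of x falls below 1e-3.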
z_p = y.repeat(len(self.list_prox), *([1] * (len(y.shape))))
p_p = torch.zeros_like(z_p)
x = p_p.clone()
for it in range(self.max_iter):
x_prev = x.clone()
for p in range(len(self.list_prox)):
p_p[p, ...] = self.list_prox[p](z_p[p, ...], ths)
x = torch.mean(p_p.clone(), axis=0)
for p in range(len(self.list_prox)):
z_p[p, ...] = x + z_p[p, ...].clone() - p_p[p, ...]
            rel_crit = torch.linalg.norm((x - x_prev).flatten()) / (
                torch.linalg.norm(x.flatten()) + 1e-6
            )
if rel_crit < 1e-3:
break
return x