# Source code for deepmr.prox.wavelet

"""Wavelet denoising."""

__all__ = [
    "WaveletDenoiser",
    "wavelet_denoise",
    "WaveletDictDenoiser",
    "wavelet_dict_denoise",
]

import numpy as np
import torch
import torch.nn as nn

import ptwt
import pywt

from . import threshold


class WaveletDenoiser(nn.Module):
    r"""
    Orthogonal Wavelet denoising with the :math:`\ell_1` norm.

    Adapted from :func:`deepinv.denoisers.WaveletDenoiser` to support
    complex-valued inputs.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \;  \|x-y\|^2 + \gamma \|\Psi x\|_n

    where :math:`\Psi` is an orthonormal wavelet transform, :math:`\lambda>0` is
    a hyperparameter, and where :math:`\|\cdot\|_n` is either the :math:`\ell_1`
    norm (``non_linearity="soft"``) or the :math:`\ell_0` norm
    (``non_linearity="hard"``). A variant of the :math:`\ell_0` norm is also
    available (``non_linearity="topk"``), where the thresholding is done by
    keeping the :math:`k` largest coefficients in each wavelet subband and
    setting the others to zero.

    The solution is available in closed-form, thus the denoiser is cheap to compute.

    Notes
    -----
    Following common practice in signal processing, only detail coefficients are
    regularized, and the approximation coefficients are left untouched.

    Warning
    -------
    For 3D data, the computational complexity of the wavelet transform scales
    cubically with the size of the support. For large 3D data, it is recommended
    to use wavelets with small support (e.g. db1 to db4).

    Attributes
    ----------
    ndim : int
        Number of spatial dimensions, can be either ``2`` or ``3``.
    ths : float, optional
        Denoise threshold. The default is ``0.1``.
    trainable : bool, optional
        If ``True``, threshold value is trainable, otherwise it is not.
        The default is ``False``.
    wv : str, optional
        Wavelet name to choose among those available in
        `pywt <https://pywavelets.readthedocs.io/en/latest/>`_.
        The default is ``"db4"``.
    device : str, optional
        Device on which the wavelet transform is computed.
        The default is ``None`` (infer from input).
    non_linearity : str, optional
        ``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
        The default is ``"soft"``.
    level: int, optional
        Level of the wavelet transform. The default is ``None``.

    """

    def __init__(
        self,
        ndim,
        ths=0.1,
        trainable=False,
        wv="db4",
        device=None,
        non_linearity="soft",
        level=None,
        *args,
        **kwargs
    ):
        super().__init__()
        if trainable:
            # nn.Parameter requires Tensor data; a plain Python float would
            # raise a TypeError, so coerce the threshold first.
            self.ths = nn.Parameter(torch.as_tensor(ths))
        else:
            self.ths = ths
        self.denoiser = _WaveletDenoiser(
            level=level,
            wv=wv,
            device=device,
            non_linearity=non_linearity,
            wvdim=ndim,
            *args,
            **kwargs
        )
        self.denoiser.device = device

    def forward(self, input):
        """
        Denoise ``input`` of shape ``(..., n_ndim, ..., n_0)``.

        Complex inputs are handled by interleaving real and imaginary parts
        along the batch axis, denoising as real data, then recombining.
        """
        # remember whether a complex tensor must be rebuilt at the end
        iscomplex = torch.is_complex(input)

        # default device: follow the input unless a device was forced
        idevice = input.device
        if self.denoiser.device is None:
            device = idevice
        else:
            device = self.denoiser.device

        # collapse all leading (batch) axes; keep the trailing spatial axes
        ndim = self.denoiser.dimension
        ishape = input.shape
        input = input.reshape(-1, *ishape[-ndim:])
        if iscomplex:
            # interleave real/imag parts along the batch axis
            input = torch.stack((input.real, input.imag), axis=1)
            input = input.reshape(-1, *ishape[-ndim:])

        # perform the denoising on the real-valued tensor
        output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
            idevice
        )

        # reshape back
        if iscomplex:
            # de-interleave to rebuild the denoised complex data
            output = (
                output[::2, ...] + 1j * output[1::2, ...]
            )
        output = output.reshape(ishape)

        return output.to(idevice)
def wavelet_denoise(
    input, ndim, ths, wv="db4", device=None, non_linearity="soft", level=None
):
    r"""
    Apply orthogonal Wavelet denoising with the :math:`\ell_1` norm.

    Adapted from :func:`deepinv.denoisers.WaveletDenoiser` to support
    complex-valued inputs.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \;  \|x-y\|^2 + \gamma \|\Psi x\|_n

    where :math:`\Psi` is an orthonormal wavelet transform, :math:`\lambda>0` is
    a hyperparameter, and where :math:`\|\cdot\|_n` is either the :math:`\ell_1`
    norm (``non_linearity="soft"``), the :math:`\ell_0` norm
    (``non_linearity="hard"``), or a top-k variant of the :math:`\ell_0` norm
    (``non_linearity="topk"``) that keeps the :math:`k` largest coefficients in
    each wavelet subband and zeroes the rest.

    The solution is available in closed-form, thus the denoiser is cheap to compute.

    Arguments
    ---------
    input : np.ndarray | torch.Tensor
        Input image of shape (..., n_ndim, ..., n_0).
    ndim : int
        Number of spatial dimensions, can be either ``2`` or ``3``.
    ths : float
        Denoise threshold.
    wv : str, optional
        Wavelet name to choose among those available in
        `pywt <https://pywavelets.readthedocs.io/en/latest/>`_.
        The default is ``"db4"``.
    device : str, optional
        Device on which the wavelet transform is computed.
        The default is ``None`` (infer from input).
    non_linearity : str, optional
        ``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
        The default is ``"soft"``.
    level: int, optional
        Level of the wavelet transform. The default is ``None``.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Denoised image of shape (..., n_ndim, ..., n_0).

    """
    # remember whether the caller handed us a NumPy array
    came_from_numpy = isinstance(input, np.ndarray)
    if came_from_numpy:
        input = torch.as_tensor(input)

    # delegate to the module-based denoiser with a non-trainable threshold
    denoiser = WaveletDenoiser(ndim, ths, False, wv, device, non_linearity, level)
    output = denoiser(input)

    # hand back the same container type we received
    if came_from_numpy:
        output = output.numpy(force=True)

    return output
class WaveletDictDenoiser(nn.Module):
    r"""
    Overcomplete Wavelet denoising with the :math:`\ell_1` norm.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \;  \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or
    more wavelets, i.e., :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`,
    :math:`\lambda>0` is a hyperparameter, and where :math:`\|\cdot\|_n` is
    either the :math:`\ell_1` norm (``non_linearity="soft"``), the
    :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the
    :math:`\ell_0` norm (``non_linearity="topk"``) where only the top-k
    coefficients are kept; see :meth:`deepinv.models.WaveletDenoiser` for
    more details.

    The solution is not available in closed-form, thus the denoiser runs an
    optimization algorithm for each test image.

    Attributes
    ----------
    ndim : int
        Number of spatial dimensions, can be either ``2`` or ``3``.
    ths : float, optional
        Denoise threshold. The default is ``0.1``.
    trainable : bool, optional
        If ``True``, threshold value is trainable, otherwise it is not.
        The default is ``False``.
    wv : Iterable[str], optional
        List of mother wavelets. The names of the wavelets can be found in
        `here <https://wavelets.pybytes.com/>`_.
        The default is ``["db8", "db4"]``.
    device : str, optional
        Device on which the wavelet transform is computed.
        The default is ``None`` (infer from input).
    non_linearity : str, optional
        ``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
        The default is ``"soft"``.
    level: int, optional
        Level of the wavelet transform. The default is ``None``.
    max_iter : int, optional
        Number of iterations of the optimization algorithm.
        The default is ``10``.

    """

    def __init__(
        self,
        ndim,
        ths=0.1,
        trainable=False,
        wv=None,
        device=None,
        non_linearity="soft",
        level=None,
        max_iter=10,
        *args,
        **kwargs
    ):
        super().__init__()
        # resolve the documented default wavelet dictionary
        if wv is None:
            wv = ["db8", "db4"]
        if trainable:
            # nn.Parameter requires Tensor data; a plain Python float would
            # raise a TypeError, so coerce the threshold first.
            self.ths = nn.Parameter(torch.as_tensor(ths))
        else:
            self.ths = ths
        # _WaveletDictDenoiser's wavelet-list parameter is ``list_wv``
        self.denoiser = _WaveletDictDenoiser(
            level=level,
            list_wv=wv,
            max_iter=max_iter,
            non_linearity=non_linearity,
            wvdim=ndim,
            *args,
            **kwargs
        )
        self.denoiser.device = device
        self.denoiser.dimension = ndim  # forward() below reads this attribute

    def forward(self, input):
        """
        Denoise ``input`` of shape ``(..., n_ndim, ..., n_0)``.

        Complex inputs are handled by interleaving real and imaginary parts
        along the batch axis, denoising as real data, then recombining.
        """
        # remember whether a complex tensor must be rebuilt at the end
        iscomplex = torch.is_complex(input)

        # default device: follow the input unless a device was forced
        idevice = input.device
        if self.denoiser.device is None:
            device = idevice
        else:
            device = self.denoiser.device

        # collapse all leading (batch) axes; keep the trailing spatial axes
        ndim = self.denoiser.dimension
        ishape = input.shape
        input = input.reshape(-1, *ishape[-ndim:])
        if iscomplex:
            # interleave real/imag parts along the batch axis
            input = torch.stack((input.real, input.imag), axis=1)
            input = input.reshape(-1, *ishape[-ndim:])

        # perform the denoising on the real-valued tensor
        output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
            idevice
        )

        # reshape back
        if iscomplex:
            # de-interleave to rebuild the denoised complex data
            output = (
                output[::2, ...] + 1j * output[1::2, ...]
            )
        output = output.reshape(ishape)

        return output.to(idevice)
def wavelet_dict_denoise(
    input,
    ndim,
    ths,
    wv=None,
    device=None,
    non_linearity="soft",
    level=None,
    max_iter=10,
):
    r"""
    Apply overcomplete Wavelet denoising with the :math:`\ell_1` norm.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \;  \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or
    more wavelets, i.e., :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`,
    :math:`\lambda>0` is a hyperparameter, and where :math:`\|\cdot\|_n` is
    either the :math:`\ell_1` norm (``non_linearity="soft"``), the
    :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the
    :math:`\ell_0` norm (``non_linearity="topk"``) where only the top-k
    coefficients are kept; see :meth:`deepinv.models.WaveletDenoiser` for
    more details.

    The solution is not available in closed-form, thus the denoiser runs an
    optimization algorithm for each test image.

    Arguments
    ---------
    input : np.ndarray | torch.Tensor
        Input image of shape (..., n_ndim, ..., n_0).
    ndim : int
        Number of spatial dimensions, can be either ``2`` or ``3``.
    ths : float
        Denoise threshold.
    wv : Iterable[str], optional
        List of mother wavelets. The names of the wavelets can be found in
        `here <https://wavelets.pybytes.com/>`_.
        The default is ``["db8", "db4"]``.
    device : str, optional
        Device on which the wavelet transform is computed.
        The default is ``None`` (infer from input).
    non_linearity : str, optional
        ``"soft"``, ``"hard"`` or ``"topk"`` thresholding.
        The default is ``"soft"``.
    level: int, optional
        Level of the wavelet transform. The default is ``None``.
    max_iter : int, optional
        Number of iterations of the optimization algorithm.
        The default is ``10``.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Denoised image of shape (..., n_ndim, ..., n_0).

    """
    # remember whether the caller handed us a NumPy array
    came_from_numpy = isinstance(input, np.ndarray)
    if came_from_numpy:
        input = torch.as_tensor(input)

    # delegate to the module-based denoiser with a non-trainable threshold
    denoiser = WaveletDictDenoiser(
        ndim, ths, False, wv, device, non_linearity, level, max_iter
    )
    output = denoiser(input)

    # hand back the same container type we received
    if came_from_numpy:
        output = output.numpy(force=True)

    return output
# %% local utils
class _WaveletDenoiser(nn.Module):
    """
    Single-wavelet orthogonal denoiser backing :class:`WaveletDenoiser`.

    Wraps the ``ptwt`` 2D/3D wavelet transforms plus per-subband thresholding.
    Only detail coefficients are thresholded; approximation coefficients are
    left untouched.
    """

    def __init__(
        self, level=None, wv="db4", device="cpu", non_linearity="soft", wvdim=2
    ):
        super().__init__()
        self.level = level  # decomposition level; None = infer max lazily
        self.wv = wv  # mother wavelet name (pywt)
        self.device = device
        self.non_linearity = non_linearity  # "soft" | "hard" | "topk"
        self.dimension = wvdim  # number of spatial dimensions (2 or 3)

    def dwt(self, x):
        r"""
        Applies the wavelet analysis.
        """
        if self.level is None:
            # lazily infer the maximum decomposition level and cache it
            level = pywt.dwtn_max_level(x.shape[-self.dimension :], self.wv)
            self.level = level
        else:
            level = self.level
        if self.dimension == 2:
            dec = ptwt.wavedec2(x, pywt.Wavelet(self.wv), mode="zero", level=level)
        elif self.dimension == 3:
            dec = ptwt.wavedec3(x, pywt.Wavelet(self.wv), mode="zero", level=level)
        # tuples -> lists so subbands can be thresholded in place
        dec = [list(t) if isinstance(t, tuple) else t for t in dec]
        return dec

    def flatten_coeffs(self, dec):
        r"""
        Flattens the wavelet coefficients and returns them in a single
        torch vector of shape (n_coeffs,).
        """
        if self.dimension == 2:
            flat = torch.hstack(
                [dec[0].flatten()]
                + [decl.flatten() for l in range(1, len(dec)) for decl in dec[l]]
            )
        elif self.dimension == 3:
            flat = torch.hstack(
                [dec[0].flatten()]
                + [dec[l][key].flatten() for l in range(1, len(dec)) for key in dec[l]]
            )
        return flat

    @staticmethod
    def psi(x, wavelet="db4", level=None, dimension=2):
        r"""
        Returns a flattened list containing the detail wavelet coefficients.
        """
        if level is None:
            level = pywt.dwtn_max_level(x.shape[-dimension:], wavelet)
        if dimension == 2:
            dec = ptwt.wavedec2(x, pywt.Wavelet(wavelet), mode="zero", level=level)
            dec = [list(t) if isinstance(t, tuple) else t for t in dec]
            vec = [decl.flatten() for l in range(1, len(dec)) for decl in dec[l]]
        elif dimension == 3:
            dec = ptwt.wavedec3(x, pywt.Wavelet(wavelet), mode="zero", level=level)
            dec = [list(t) if isinstance(t, tuple) else t for t in dec]
            vec = [dec[l][key].flatten() for l in range(1, len(dec)) for key in dec[l]]
        return vec

    def iwt(self, coeffs):
        r"""
        Applies the wavelet synthesis.
        """
        # lists -> tuples, as expected by the ptwt reconstruction routines
        coeffs = [tuple(t) if isinstance(t, list) else t for t in coeffs]
        if self.dimension == 2:
            rec = ptwt.waverec2(coeffs, pywt.Wavelet(self.wv))
        elif self.dimension == 3:
            rec = ptwt.waverec3(coeffs, pywt.Wavelet(self.wv))
        return rec

    def prox_l1(self, x, ths):
        r"""
        Soft thresholding of the wavelet coefficients.

        Arguments
        ---------
        x : torch.Tensor
            Wavelet coefficients.
        ths : float, optional
            Threshold. It can be element-wise, in which case it is assumed
            to be broadcastable with ``input``. The default is ``0.1``.

        Returns
        -------
        torch.Tensor
            Thresholded wavelet coefficients.

        """
        return threshold.soft_thresh(x, ths)

    def prox_l0(self, x, ths):
        r"""
        Hard thresholding of the wavelet coefficients.

        Arguments
        ---------
        x : torch.Tensor
            Wavelet coefficients.
        ths : float, optional
            Threshold. It can be element-wise, in which case it is assumed
            to be broadcastable with ``input``. The default is ``0.1``.

        Returns
        -------
        torch.Tensor
            Thresholded wavelet coefficients.

        """
        if isinstance(ths, (int, float)):
            # scalar thresholds (int or float) broadcast directly
            ths_map = ths
        else:
            ths_map = ths.repeat(
                1, 1, 1, x.shape[-2], x.shape[-1]
            )  # Reshaping to image wavelet shape
        out = x.clone()
        out[abs(out) < ths_map] = 0
        return out

    def hard_threshold_topk(self, x, ths):
        r"""
        Hard thresholding of the wavelet coefficients by keeping only the top-k
        coefficients and setting the others to 0.

        Arguments
        ---------
        x : torch.Tensor
            Wavelet coefficients.
        ths : float | int, optional
            Top k coefficients to keep. If ``float``, it is interpreted as
            a proportion of the total number of coefficients. If ``int``,
            it is interpreted as the number of coefficients to keep.
            The default is ``0.1``.

        Returns
        -------
        torch.Tensor
            Thresholded wavelet coefficients.

        """
        if isinstance(ths, float):
            k = int(ths * x.shape[-3] * x.shape[-2] * x.shape[-1])
        else:
            k = int(ths)

        # Reshape arrays to 2D and initialize output to 0
        x_flat = x.reshape(x.shape[0], -1)
        out = torch.zeros_like(x_flat)
        topk_indices_flat = torch.topk(abs(x_flat), k, dim=-1)[1]

        # Convert the flattened indices to the original indices of x
        batch_indices = (
            torch.arange(x.shape[0], device=x.device).unsqueeze(1).repeat(1, k)
        )
        topk_indices = torch.stack([batch_indices, topk_indices_flat], dim=-1)

        # Set output's top-k elements to values from original x
        out[tuple(topk_indices.view(-1, 2).t())] = x_flat[
            tuple(topk_indices.view(-1, 2).t())
        ]
        return torch.reshape(out, x.shape)

    def thresold_func(self, x, ths):
        r"""
        Apply the configured thresholding to the wavelet coefficients.
        """
        if self.non_linearity == "soft":
            y = self.prox_l1(x, ths)
        elif self.non_linearity == "hard":
            y = self.prox_l0(x, ths)
        elif self.non_linearity == "topk":
            y = self.hard_threshold_topk(x, ths)
        return y

    def thresold_2D(self, coeffs, ths):
        r"""
        Thresholds coefficients of the 2D wavelet transform.
        """
        for level in range(1, self.level + 1):
            ths_cur = self.reshape_ths(ths, level)
            for c in range(3):
                coeffs[level][c] = self.thresold_func(coeffs[level][c], ths_cur[c])
        return coeffs

    def threshold_3D(self, coeffs, ths):
        r"""
        Thresholds coefficients of the 3D wavelet transform.
        """
        for level in range(1, self.level + 1):
            ths_cur = self.reshape_ths(ths, level)
            for c, key in enumerate(["aad", "ada", "daa", "add", "dad", "dda", "ddd"]):
                # dispatch through thresold_func so that non_linearity
                # ("hard"/"topk") is honored in 3D, consistently with the
                # 2D path above (previously hard-wired to soft thresholding)
                coeffs[level][key] = self.thresold_func(coeffs[level][key], ths_cur[c])
        return coeffs

    def threshold_ND(self, coeffs, ths):
        r"""
        Apply thresholding to the wavelet coefficients of arbitrary dimension.
        """
        if self.dimension == 2:
            coeffs = self.thresold_2D(coeffs, ths)
        elif self.dimension == 3:
            coeffs = self.threshold_3D(coeffs, ths)
        else:
            raise ValueError("Only 2D and 3D wavelet transforms are supported")
        return coeffs

    def pad_input(self, x):
        r"""
        Pad the input to even spatial sizes, as required by the wavelet
        transform. Returns the padded tensor and the per-axis padding amounts
        (ordered first-to-last spatial axis, matching ``crop_output``).
        """
        if self.dimension == 2:
            h, w = x.size()[-2:]
            padding_bottom = h % 2
            padding_right = w % 2
            p = (padding_bottom, padding_right)
            # ReplicationPad2d takes (left, right, top, bottom): padding
            # starts from the LAST dimension, so the width pad comes first
            x = torch.nn.ReplicationPad2d((0, p[1], 0, p[0]))(x)
        elif self.dimension == 3:
            d, h, w = x.size()[-3:]
            padding_depth = d % 2
            padding_bottom = h % 2
            padding_right = w % 2
            p = (padding_depth, padding_bottom, padding_right)
            # ReplicationPad3d takes (W_left, W_right, H_top, H_bottom,
            # D_front, D_back): width pad first, depth pad last
            x = torch.nn.ReplicationPad3d((0, p[2], 0, p[1], 0, p[0]))(x)
        return x, p

    def crop_output(self, x, padding):
        r"""
        Crop the output back to the pre-padding spatial size.
        """
        d, h, w = x.size()[-3:]
        if len(padding) == 2:
            out = x[..., : h - padding[0], : w - padding[1]]
        elif len(padding) == 3:
            out = x[..., : d - padding[0], : h - padding[1], : w - padding[2]]
        return out

    def reshape_ths(self, ths, level):
        r"""
        Reshape the thresholding parameter in the appropriate format, i.e. either:
        - a list of 3 (2D) or 7 (3D) elements, one per detail subband, or
        - a tensor broadcastable with the subbands.

        Since the approximation coefficients are not thresholded, no threshold
        needs to be provided for them.
        """
        numel = 3 if self.dimension == 2 else 7
        if not torch.is_tensor(ths):
            if isinstance(ths, int) or isinstance(ths, float):
                ths_cur = [ths] * numel
            elif len(ths) == 1:
                ths_cur = [ths[0]] * numel
            else:
                # NOTE(review): per-level thresholds are indexed with `level`
                # here but with `level - 2` in the tensor branch below —
                # confirm the intended indexing convention against upstream.
                ths_cur = ths[level]
                if len(ths_cur) == 1:
                    ths_cur = [ths_cur[0]] * numel
        else:
            if len(ths.shape) == 1:
                # single threshold value, replicated across subbands
                ths_cur = ths.squeeze().repeat(numel)
            else:
                ths_cur = ths[level - 2]
        return ths_cur

    def forward(self, x, ths):
        """Denoise ``x``: pad, analyze, threshold details, synthesize, crop."""
        # Pad data to even spatial sizes
        x, padding = self.pad_input(x)

        # Apply wavelet transform
        coeffs = self.dwt(x)

        # Threshold coefficients (approximation coefficients are untouched)
        coeffs = self.threshold_ND(coeffs, ths)

        # Inverse wavelet transform
        y = self.iwt(coeffs)

        # Crop data back to the original spatial size
        y = self.crop_output(y, padding)
        return y


class _WaveletDictDenoiser(nn.Module):
    r"""
    Overcomplete Wavelet denoising with the :math:`\ell_1` norm.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \;  \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or
    more wavelets, i.e., :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`,
    :math:`\lambda>0` is a hyperparameter, and where :math:`\|\cdot\|_n` is
    either the :math:`\ell_1` norm (``non_linearity="soft"``), the
    :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the
    :math:`\ell_0` norm (``non_linearity="topk"``) where only the top-k
    coefficients are kept; see :meth:`deepinv.models.WaveletDenoiser` for
    more details.

    The solution is not available in closed-form, thus the denoiser runs an
    optimization algorithm for each test image.

    :param int level: decomposition level of the wavelet transform.
    :param list[str] list_wv: list of mother wavelets. The names of the wavelets
        can be found in `here <https://wavelets.pybytes.com/>`_.
        (default: ["db8", "db4"]).
    :param int max_iter: number of iterations of the optimization algorithm
        (default: 10).
    :param str non_linearity: "soft", "hard" or "topk" thresholding
        (default: "soft").
    :param int wvdim: number of spatial dimensions, 2 or 3 (default: 2).
    :param list[str] wv: alias for ``list_wv``, used by the public wrapper.
    :param str device: cpu or gpu.
    """

    def __init__(
        self,
        level=None,
        list_wv=None,
        max_iter=10,
        non_linearity="soft",
        wvdim=2,
        wv=None,
        device=None,
    ):
        super().__init__()
        # accept both ``list_wv`` and the ``wv`` alias passed by the public
        # wrapper; None defaults avoid a shared mutable default argument
        if list_wv is None:
            list_wv = wv if wv is not None else ["db8", "db4"]
        self.level = level
        self.dimension = wvdim  # read by the public wrapper's forward()
        self.device = device
        self.list_prox = nn.ModuleList(
            [
                _WaveletDenoiser(
                    level=level, wv=wavelet, non_linearity=non_linearity, wvdim=wvdim
                )
                for wavelet in list_wv
            ]
        )
        self.max_iter = max_iter

    def forward(self, y, ths):
        """
        Denoise ``y`` by averaged proximal iterations over all wavelet bases.
        """
        # one working estimate per wavelet basis
        z_p = y.repeat(len(self.list_prox), *([1] * (len(y.shape))))
        p_p = torch.zeros_like(z_p)
        x = p_p.clone()
        for it in range(self.max_iter):
            x_prev = x.clone()
            # proximal step for each basis
            for p in range(len(self.list_prox)):
                p_p[p, ...] = self.list_prox[p](z_p[p, ...], ths)
            # consensus: average the per-basis estimates
            x = torch.mean(p_p, axis=0)
            for p in range(len(self.list_prox)):
                z_p[p, ...] = x + z_p[p, ...].clone() - p_p[p, ...]
            # stop early once the relative update is small
            rel_crit = torch.linalg.norm((x - x_prev).flatten()) / torch.linalg.norm(
                x.flatten() + 1e-6
            )
            if rel_crit < 1e-3:
                break
        return x