deepmr.prox.WaveletDenoiser

class deepmr.prox.WaveletDenoiser(*args: Any, **kwargs: Any)

Orthogonal wavelet denoising with an \(\ell_1\) (or \(\ell_0\)) sparsity penalty.

Adapted from deepinv.denoisers.WaveletDenoiser to support complex-valued inputs.

This denoiser is defined as the solution to the optimization problem:

\[\underset{x}{\arg\min} \; \|x-y\|^2 + \gamma \|\Psi x\|_n\]

where \(\Psi\) is an orthonormal wavelet transform, \(\gamma>0\) is a hyperparameter, and \(\|\cdot\|_n\) is either the \(\ell_1\) norm (non_linearity="soft") or the \(\ell_0\) norm (non_linearity="hard"). A variant of the \(\ell_0\) norm is also available (non_linearity="topk"): it keeps the \(k\) largest coefficients in each wavelet subband and sets the others to zero.
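The three thresholding rules can be sketched in NumPy as follows. This is an illustrative sketch, not deepmr's implementation: the function names and the `k` argument are hypothetical, `ths` is treated as a threshold on coefficient magnitudes, and for complex coefficients the soft rule is assumed to shrink the magnitude while preserving the phase.

```python
import numpy as np


def soft_threshold(c, ths):
    """Shrink magnitudes by ths and keep the phase (real or complex c)."""
    mag = np.abs(c)
    # c / |c| is the phase (the sign for real data); define it as 0 where c == 0.
    phase = np.divide(c, mag, out=np.zeros_like(c), where=mag > 0)
    return np.maximum(mag - ths, 0.0) * phase


def hard_threshold(c, ths):
    """Zero every coefficient with magnitude <= ths; keep the others unchanged."""
    return np.where(np.abs(c) > ths, c, 0)


def topk_threshold(c, k):
    """Keep the k largest-magnitude coefficients of one subband; zero the rest."""
    flat = c.ravel()
    out = np.zeros_like(flat)
    if k > 0:
        idx = np.argsort(np.abs(flat))[-k:]  # indices of the k largest magnitudes
        out[idx] = flat[idx]
    return out.reshape(c.shape)


c = np.array([3 + 4j, 0.05, -0.5, 0.0])
print(soft_threshold(c, 1.0))  # |3+4j| = 5 shrinks to 4 with the same phase
print(hard_threshold(c, 1.0))
print(topk_threshold(c, 2))
```

In this sketch, topk would be applied separately to each detail subband, matching the per-subband description above.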

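Because \(\Psi\) is orthonormal, the problem decouples across wavelet coefficients and its solution is available in closed form as a thresholding of the coefficients \(\Psi y\), so the denoiser is cheap to compute.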
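For the objective exactly as written (squared error without a factor of \(1/2\)), the closed form can be worked out coefficient by coefficient. With \(d = \Psi y\) and \(c = \Psi x\), orthonormality gives \(\|x-y\|^2 = \|c-d\|^2\), so each regularized coefficient solves an independent scalar problem with minimizer

\[\hat{c}_i = \max\!\left(1 - \frac{\gamma}{2|d_i|},\, 0\right) d_i \;\; (\ell_1), \qquad \hat{c}_i = d_i \, \mathbb{1}\{|d_i| > \sqrt{\gamma}\} \;\; (\ell_0), \qquad \hat{x} = \Psi^{\top} \hat{c},\]

with \(\hat{c}_i = 0\) when \(d_i = 0\). For complex \(d_i\), the \(\ell_1\) rule shrinks the magnitude by \(\gamma/2\) and keeps the phase. This derivation only restates the math of the objective above; how the ths parameter maps onto \(\gamma\) is not specified here, so \(\gamma/2\) and \(\sqrt{\gamma}\) should not be read as the class's internal scaling.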

Notes

Following common practice in signal processing, only detail coefficients are regularized, and the approximation coefficients are left untouched.
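The effect of leaving the approximation coefficients untouched can be seen in the following NumPy sketch. It swaps in a single-level orthonormal Haar (db1) transform for the configurable wavelet (wv) and level of the class, and all helper names are hypothetical; it is not deepmr's implementation. Only the three detail subbands are soft-thresholded, so a very large threshold reduces each 2x2 block to its mean instead of zeroing the image.

```python
import numpy as np


def haar2_forward(x):
    """Single-level orthonormal 2D Haar transform (height and width must be even)."""
    p00, p01 = x[0::2, 0::2], x[0::2, 1::2]
    p10, p11 = x[1::2, 0::2], x[1::2, 1::2]
    approx = (p00 + p01 + p10 + p11) / 2
    d1 = (p00 - p01 + p10 - p11) / 2  # difference between columns
    d2 = (p00 + p01 - p10 - p11) / 2  # difference between rows
    d3 = (p00 - p01 - p10 + p11) / 2  # diagonal difference
    return approx, (d1, d2, d3)


def haar2_inverse(approx, details):
    """Inverse of haar2_forward (orthonormal, so the inverse is the transpose)."""
    d1, d2, d3 = details
    out = np.empty((2 * approx.shape[0], 2 * approx.shape[1]),
                   dtype=np.result_type(approx, d1, d2, d3))
    out[0::2, 0::2] = (approx + d1 + d2 + d3) / 2
    out[0::2, 1::2] = (approx - d1 + d2 - d3) / 2
    out[1::2, 0::2] = (approx + d1 - d2 - d3) / 2
    out[1::2, 1::2] = (approx - d1 - d2 + d3) / 2
    return out


def soft_threshold(c, ths):
    """Shrink magnitudes by ths and keep the phase (real or complex c)."""
    mag = np.abs(c)
    phase = np.divide(c, mag, out=np.zeros_like(c), where=mag > 0)
    return np.maximum(mag - ths, 0.0) * phase


def haar_denoise(y, ths):
    """Threshold only the detail subbands; pass the approximation through."""
    approx, details = haar2_forward(y)
    details = tuple(soft_threshold(d, ths) for d in details)
    return haar2_inverse(approx, details)


rng = np.random.default_rng(0)
y = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
x_hat = haar_denoise(y, 0.5)  # complex input, same shape as y
```

Keeping the approximation subband means the local mean of the image is never shrunk, which is why the threshold acts only on edges and fine texture.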

Warning

For 3D data, the computational complexity of the wavelet transform scales cubically with the size of the wavelet support. For large 3D data, wavelets with small support (e.g., db1 to db4) are recommended.

ndim

Number of spatial dimensions, can be either 2 or 3.

Type:

int

ths

Denoising threshold. The default is 0.1.

Type:

float, optional

trainable

If True, the threshold is a trainable parameter; otherwise it is kept fixed. The default is False.

Type:

bool, optional

wv

Wavelet name, chosen among those available in pywt (see pywt.wavelist()). The default is "db4".

Type:

str, optional

device

Device on which the wavelet transform is computed. The default is None (inferred from the input).

Type:

str, optional

non_linearity

Thresholding rule: "soft", "hard", or "topk". The default is "soft".

Type:

str, optional

level

Number of decomposition levels of the wavelet transform. The default is None.

Type:

int, optional

__init__(ndim, ths=0.1, trainable=False, wv='db4', device=None, non_linearity='soft', level=None, *args, **kwargs)

Methods

__init__(ndim[, ths, trainable, wv, device, ...])

forward(input)