"""
Wavelet transform routines; adapted from Sigpy [1].
References
----------
[1] https://github.com/mikgroup/sigpy/tree/main
"""
__all__ = ["fwt", "iwt"]

import numpy as np
import torch

import ptwt
import pywt

from .resize import resize
def fwt(input, ndim=None, device=None, wave_name="db4", level=None):
"""
Forward wavelet transform.
Adapted from Sigpy [1].
Parameters
----------
input : np.ndarray | torch.Tensor
Input signal of shape (..., nz, ny, nx).
ndim : int, optional
Number of spatial dimensions over to which compute
wavelet transform (``1``, ``2``, ``3``).
Assume spatial axis are the rightmost ones.
The default is ``None`` (``ndim = min(3, len(input.shape))``).
device : str, optional
Computational device for Wavelet transform.
If not specified, use ``input.device``.
The default is ``None``.
wave_name : str, optional
Wavelet name. The default is ``"db4"``.
axes : Iterable[int], optional
Axes to perform wavelet transform.
The default is ``None`` (all axes).
level : int, optional
Number of wavelet levels. The default is ``None``.
Returns
-------
output : np.ndarray | torch.Tensor
Output wavelet decomposition.
shape : Iterable[int]
Input signal shape (``input.shape``) for synthesis.
    Examples
    --------
    >>> import torch
    >>> import deepmr

    First, generate a 2D phantom and add some noise:

    >>> img = deepmr.shepp_logan(128) + 0.05 * torch.randn(128, 128)

    Now, run the wavelet decomposition:

    >>> coeff, shape = deepmr.fwt(img)

    The function returns a ``coeff`` list, containing the wavelet coefficients,
    and a ``shape`` tuple, containing the original image shape for image synthesis via
    ``deepmr.iwt``:

    >>> shape
    torch.Size([128, 128])
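
    As an illustrative sketch, the detail coefficients can be soft-thresholded
    before synthesis (e.g., as a simple wavelet denoising step); the threshold
    value ``0.05`` below is an arbitrary choice for this example:

    >>> thresh = 0.05
    >>> for n in range(1, len(coeff)):
    ...     coeff[n] = [torch.sign(c) * torch.clamp(c.abs() - thresh, min=0.0) for c in coeff[n]]
    >>> img_denoised = deepmr.iwt(coeff, shape)
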
References
----------
[1] https://github.com/mikgroup/sigpy/tree/main
"""
    isnumpy = isinstance(input, np.ndarray)
# cast to tensor
input = torch.as_tensor(input)
# get device
idevice = input.device
if device is None:
device = idevice
input = input.to(device)
# get default ndim
if ndim is None:
ndim = min(3, len(input.shape))
# pad to nearest even value
ishape = input.shape
zshape = [((ishape[n] + 1) // 2) * 2 for n in range(-ndim, 0)]
zinput = resize(
input.reshape(-1, *ishape[-ndim:]), [int(np.prod(ishape[:-ndim]))] + zshape
)
# select wavelet
wavelet = pywt.Wavelet(wave_name)
# select transform
if ndim == 1:
_fwt = ptwt.wavedec
elif ndim == 2:
_fwt = ptwt.wavedec2
elif ndim == 3:
_fwt = ptwt.wavedec3
else:
raise ValueError(
f"Number of dimensions (={ndim}) not recognized; we support only 1, 2 and 3."
)
    # compute decomposition
    output = _fwt(zinput, wavelet, mode="zero", level=level)
    # move coefficients back to the original device
    output = list(output)
    output[0] = output[0].to(idevice)
    for n in range(1, len(output)):
        output[n] = [o.to(idevice) for o in output[n]]
    # cast to numpy if required
    if isnumpy:
        output[0] = output[0].numpy(force=True)
        for n in range(1, len(output)):
            output[n] = [o.numpy(force=True) for o in output[n]]
return output, ishape
def iwt(input, shape, device=None, wave_name="db4", level=None):
"""
Inverse wavelet transform.
Adapted from Sigpy [1].
Parameters
----------
input : np.ndarray | torch.Tensor
Input wavelet decomposition.
shape : Iterable[int], optional
Spatial matrix size of output signal ``(nx)`` (1D signals),
``(ny, nx)`` (2D) or ``(nz, ny, nx)`` (3D).
device : str, optional
Computational device for Wavelet transform.
If not specified, use ``input.device``.
The default is ``None``.
wave_name : str, optional
Wavelet name. The default is ``"db4"``.
axes : Iterable[int], optional
Axes to perform wavelet transform.
The default is ``None`` (all axes).
level : int, optional
Number of wavelet levels. The default is ``None``.
Returns
-------
output : np.ndarray | torch.Tensor
Output signal of shape (..., nz, ny, nx).
    Examples
    --------
    >>> import torch
    >>> import deepmr

    First, generate a 2D phantom and add some noise:

    >>> img0 = deepmr.shepp_logan(128) + 0.05 * torch.randn(128, 128)

    Now, run the wavelet decomposition:

    >>> coeff, shape = deepmr.fwt(img0)

    The image can be synthesized from ``coeff`` and ``shape`` as:

    >>> img = deepmr.iwt(coeff, shape)
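
    Since the coefficients are unmodified, the reconstruction should match the
    original up to floating-point error; as a quick (illustrative) check:

    >>> err = (img - img0).abs().max()
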
References
----------
[1] https://github.com/mikgroup/sigpy/tree/main
"""
    isnumpy = isinstance(input, np.ndarray)
# cast to tensor
output = list(input)
output[0] = torch.as_tensor(output[0])
for n in range(1, len(output)):
output[n] = [torch.as_tensor(o) for o in output[n]]
# get device
idevice = output[0].device
if device is None:
device = idevice
    # transfer to computational device
    output[0] = output[0].to(device)
    for n in range(1, len(output)):
        output[n] = [o.to(device) for o in output[n]]
# convert to tuple
for n in range(1, len(output)):
output[n] = tuple(output[n])
output = tuple(output)
# select wavelet
wavelet = pywt.Wavelet(wave_name)
# select transform
ndim = len(shape)
if ndim == 1:
_iwt = ptwt.waverec
elif ndim == 2:
_iwt = ptwt.waverec2
elif ndim == 3:
_iwt = ptwt.waverec3
else:
raise ValueError(
f"Number of dimensions (={ndim}) not recognized; we support only 1, 2 and 3."
)
# compute
zoutput = _iwt(output, wavelet)
zoutput = zoutput.reshape(*shape[:-ndim], *zoutput.shape[-ndim:])
output = resize(zoutput, shape)
output = output.to(idevice)
# cast to numpy if required
if isnumpy:
output = output.numpy(force=True)
# erase singleton dimension
if len(output.shape) == ndim + 1 and output.shape[0] == 1:
output = output[0]
return output
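

if __name__ == "__main__":
    # Minimal round-trip sketch (illustrative only, not part of the public API):
    # decompose a random 2D image and check that synthesis recovers it.
    # Run as a module (e.g. ``python -m <package>.<module>``) so that the
    # relative import of ``resize`` resolves.
    _img = torch.randn(128, 128)
    _coeff, _shape = fwt(_img)
    _rec = iwt(_coeff, _shape)
    print("max round-trip error:", (_rec - _img).abs().max().item())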