"""
Wavelet transform routines; adapted from Sigpy [1].
References
----------
[1] https://github.com/mikgroup/sigpy/tree/main
"""
__all__ = ["fwt", "iwt"]
import torch
import numpy as np
import ptwt
import pywt
from .resize import resize
def _fwt_map_coeff(coeff, func):
    """
    Apply ``func`` to every tensor held by one decomposition entry.

    An entry may be a bare tensor, a sequence of tensors or a dictionary of
    tensors. Sequences are returned as lists and dictionaries keep their keys,
    so the nesting of the decomposition is preserved.
    """
    if isinstance(coeff, dict):
        return {key: func(value) for key, value in coeff.items()}
    if isinstance(coeff, (list, tuple)):
        return [func(value) for value in coeff]
    return func(coeff)


def fwt(input, ndim=None, device=None, wave_name="db4", level=None):
    """
    Forward wavelet transform.

    Adapted from Sigpy [1].

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Input signal of shape (..., nz, ny, nx).
    ndim : int, optional
        Number of spatial dimensions over which the wavelet transform
        is computed (``1``, ``2``, ``3``).
        Spatial axes are assumed to be the rightmost ones.
        The default is ``None`` (``ndim = min(3, len(input.shape))``).
    device : str, optional
        Computational device for Wavelet transform.
        If not specified, use ``input.device``.
        The default is ``None``.
    wave_name : str, optional
        Wavelet name. The default is ``"db4"``.
    level : int, optional
        Number of wavelet levels. The default is ``None``.

    Returns
    -------
    output : list
        Output wavelet decomposition. The first entry holds the approximation
        coefficients; each following entry holds the detail coefficients of
        one level, with the same nesting as returned by the underlying
        ``ptwt`` routine (sequences are converted to lists). Arrays are
        ``np.ndarray`` if ``input`` is a NumPy array, otherwise
        ``torch.Tensor`` on the device of ``input``. All leading (non-spatial)
        axes of ``input`` are flattened into a single batch axis.
    shape : Iterable[int]
        Input signal shape (``input.shape``) for synthesis.

    Raises
    ------
    ValueError
        If ``ndim`` is not ``1``, ``2`` or ``3``.

    Examples
    --------
    >>> import torch
    >>> import deepmr

    First, generate a 2D phantom and add some noise:

    >>> img = deepmr.shepp_logan(128) + 0.05 * torch.randn(128, 128)

    Now, run wavelet decomposition:

    >>> coeff, shape = deepmr.fwt(img)

    The function returns a ``coeff`` list, containing the Wavelet coefficients,
    and a ``shape`` tuple, containing the original image shape for image synthesis via
    ``deepmr.iwt``:

    >>> shape
    torch.Size([128, 128])

    References
    ----------
    [1] https://github.com/mikgroup/sigpy/tree/main

    """
    isnumpy = isinstance(input, np.ndarray)

    # cast to tensor
    input = torch.as_tensor(input)

    # get device
    idevice = input.device
    if device is None:
        device = idevice
    input = input.to(device)

    # get default ndim
    if ndim is None:
        ndim = min(3, len(input.shape))

    # select transform (done first so that an unsupported ndim fails
    # before any padding / computation takes place)
    if ndim == 1:
        _fwt = ptwt.wavedec
    elif ndim == 2:
        _fwt = ptwt.wavedec2
    elif ndim == 3:
        _fwt = ptwt.wavedec3
    else:
        raise ValueError(
            f"Number of dimensions (={ndim}) not recognized; we support only 1, 2 and 3."
        )

    # pad spatial axes to nearest even value, flattening leading axes
    # into a single batch axis
    ishape = input.shape
    zshape = [((ishape[n] + 1) // 2) * 2 for n in range(-ndim, 0)]
    zinput = resize(
        input.reshape(-1, *ishape[-ndim:]), [int(np.prod(ishape[:-ndim]))] + zshape
    )

    # select wavelet
    wavelet = pywt.Wavelet(wave_name)

    def _finalize(tensor):
        # bring result back to the original device and, if the caller
        # provided a NumPy array, to NumPy as well
        tensor = tensor.to(idevice)
        return tensor.numpy(force=True) if isnumpy else tensor

    # compute; detail entries are not always sequences (e.g. bare tensors
    # or dictionaries), hence the layout-aware mapping
    output = [
        _fwt_map_coeff(coeff, _finalize)
        for coeff in _fwt(zinput, wavelet, mode="zero", level=level)
    ]

    return output, ishape
def _iwt_map_coeff(coeff, func):
    """
    Apply ``func`` to every array held by one decomposition entry.

    An entry may be a bare array, a sequence of arrays or a dictionary of
    arrays. Sequences are returned as tuples and dictionaries keep their keys.
    """
    if isinstance(coeff, dict):
        return {key: func(value) for key, value in coeff.items()}
    if isinstance(coeff, (list, tuple)):
        return tuple(func(value) for value in coeff)
    return func(coeff)


def iwt(input, shape, device=None, wave_name="db4", level=None):
    """
    Inverse wavelet transform.

    Adapted from Sigpy [1].

    Parameters
    ----------
    input : Iterable
        Input wavelet decomposition, i.e., the first output of ``fwt``:
        approximation coefficients followed by one entry of detail
        coefficients per level.
    shape : Iterable[int]
        Shape of the output signal. Either the spatial matrix size
        ``(nx)`` (1D signals), ``(ny, nx)`` (2D) or ``(nz, ny, nx)`` (3D),
        or the full signal shape (second output of ``fwt``), whose
        leading axes are then restored as batch axes.
    device : str, optional
        Computational device for Wavelet transform.
        If not specified, use the device of the input coefficients.
        The default is ``None``.
    wave_name : str, optional
        Wavelet name. The default is ``"db4"``.
    level : int, optional
        Unused; kept for signature symmetry with ``fwt``.
        The default is ``None``.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Output signal of shape (..., nz, ny, nx). It is a NumPy array if the
        input coefficients are NumPy arrays, otherwise a tensor on the device
        of the input coefficients.

    Raises
    ------
    ValueError
        If the number of transform dimensions is not ``1``, ``2`` or ``3``,
        or if ``shape`` has fewer entries than the transform dimensions.

    Examples
    --------
    >>> import torch
    >>> import deepmr

    First, generate a 2D phantom and add some noise:

    >>> img0 = deepmr.shepp_logan(128) + 0.05 * torch.randn(128, 128)

    Now, run wavelet decomposition:

    >>> coeff, shape = deepmr.fwt(img0)

    The image can be synthesized from ``coeff`` and ``shape`` as:

    >>> img = deepmr.iwt(coeff, shape)

    References
    ----------
    [1] https://github.com/mikgroup/sigpy/tree/main

    """
    coeffs = list(input)
    shape = tuple(shape)

    # the decomposition is a sequence, so NumPy-ness must be checked
    # on its content rather than on the container
    isnumpy = isinstance(input, np.ndarray) or isinstance(coeffs[0], np.ndarray)

    # the layout of the detail entries identifies the transform:
    # dictionary -> 3D, sequence -> 2D, bare array -> 1D
    # (assumes the ptwt coefficient layout); without any detail entry,
    # fall back on the length of shape
    if len(coeffs) > 1:
        detail = coeffs[1]
        if isinstance(detail, dict):
            ndim = 3
        elif isinstance(detail, (list, tuple)):
            ndim = 2
        else:
            ndim = 1
    else:
        ndim = len(shape)

    # select transform
    if ndim == 1:
        _iwt = ptwt.waverec
    elif ndim == 2:
        _iwt = ptwt.waverec2
    elif ndim == 3:
        _iwt = ptwt.waverec3
    else:
        raise ValueError(
            f"Number of dimensions (={ndim}) not recognized; we support only 1, 2 and 3."
        )
    if len(shape) < ndim:
        raise ValueError(
            f"Output shape {shape} has fewer entries than the number of transform dimensions (={ndim})."
        )

    # get device
    idevice = torch.as_tensor(coeffs[0]).device
    if device is None:
        device = idevice

    # cast to tensor and transfer to computational device
    coeffs = tuple(
        _iwt_map_coeff(coeff, lambda array: torch.as_tensor(array).to(device))
        for coeff in coeffs
    )

    # select wavelet
    wavelet = pywt.Wavelet(wave_name)

    # compute
    zoutput = _iwt(coeffs, wavelet)

    # restore batch axes, when shape provides them
    batch_shape = shape[:-ndim]
    if batch_shape:
        zoutput = zoutput.reshape(*batch_shape, *zoutput.shape[-ndim:])

    # crop spatial axes back to the requested size, leaving leading axes untouched
    leading = list(zoutput.shape[:-ndim])
    output = resize(zoutput, leading + list(shape[-ndim:]))
    output = output.to(idevice)

    # cast to numpy if required
    if isnumpy:
        output = output.numpy(force=True)

    # erase singleton batch dimension when only the spatial shape was requested
    if not batch_shape and len(output.shape) == ndim + 1 and output.shape[0] == 1:
        output = output[0]

    return output