"""Array shape manipulation routines."""
__all__ = ["resize", "resample"]
import numpy as np
import torch
from .filter import fermi
def resize(input, oshape):
    """
    Resize with zero-padding or cropping.

    Adapted from SigPy [1].

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Input tensor of shape ``(..., ishape)``.
    oshape : Iterable
        Output shape. A single integer is treated as a one-element shape.
        If ``oshape`` has more entries than ``input`` has dimensions,
        ``input`` is first given leading singleton dimensions.

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Zero-padded or cropped tensor of shape ``(..., oshape)``.
        The container type (NumPy / PyTorch) matches ``input``.
        When no resizing is needed the input data are returned without copy.

    Examples
    --------
    >>> import torch
    >>> import deepmr

    We can pad tensors to desired shape:

    >>> x = torch.tensor([0, 1, 0])
    >>> y = deepmr.resize(x, [5])
    >>> y
    tensor([0, 0, 1, 0, 0])

    Batch dimensions are automatically expanded (pad will be applied starting from rightmost dimension):

    >>> x = torch.tensor([0, 1, 0])[None, ...]
    >>> x.shape
    torch.Size([1, 3])
    >>> y = deepmr.resize(x, [5]) # len(oshape) == 1
    >>> y.shape
    torch.Size([1, 5])

    Similarly, if oshape is smaller than ishape, the tensor will be cropped:

    >>> x = torch.tensor([0, 0, 1, 0, 0])
    >>> y = deepmr.resize(x, [3])
    >>> y
    tensor([0, 1, 0])

    Again, batch dimensions are automatically expanded:

    >>> x = torch.tensor([0, 0, 1, 0, 0])[None, ...]
    >>> x.shape
    torch.Size([1, 5])
    >>> y = deepmr.resize(x, [3]) # len(oshape) == 1
    >>> y.shape
    torch.Size([1, 3])

    References
    ----------
    [1] https://github.com/mikgroup/sigpy/blob/main/sigpy/util.py

    """
    isnumpy = isinstance(input, np.ndarray)
    if isnumpy:
        input = torch.as_tensor(input)

    # accept a bare (Python or NumPy) integer as a 1-element shape
    if isinstance(oshape, (int, np.integer)):
        oshape = [oshape]
    oshape = [int(o) for o in oshape]

    # if oshape has more axes than input, give input leading singleton axes
    # so that both shapes can be aligned axis-by-axis
    nmissing = len(oshape) - input.ndim
    if nmissing > 0:
        input = input.reshape([1] * nmissing + list(input.shape))

    # leading (batch) axes are left untouched; only the trailing
    # len(oshape) axes are padded / cropped
    ishape1 = list(input.shape)
    nbatch = len(ishape1) - len(oshape)
    oshape1 = ishape1[:nbatch] + oshape

    if ishape1 == oshape1:
        # nothing to do; fall through so NumPy inputs are converted back
        output = input
    else:
        # shift not supported for now: data are kept centered
        ishift = [max(i // 2 - o // 2, 0) for i, o in zip(ishape1, oshape1)]
        oshift = [max(o // 2 - i // 2, 0) for i, o in zip(ishape1, oshape1)]
        copy_shape = [
            min(i - si, o - so)
            for i, si, o, so in zip(ishape1, ishift, oshape1, oshift)
        ]
        islice = tuple(slice(si, si + c) for si, c in zip(ishift, copy_shape))
        oslice = tuple(slice(so, so + c) for so, c in zip(oshift, copy_shape))

        output = torch.zeros(oshape1, dtype=input.dtype, device=input.device)
        output[oslice] = input[islice]

    if isnumpy:
        output = output.numpy(force=True)

    return output
def resample(input, oshape, filt=True, polysmooth=False):
    """
    Resample a n-dimensional signal.

    The signal is Fourier transformed along its last ``len(oshape)`` axes,
    zero-padded or cropped to ``oshape`` in frequency space and
    transformed back.

    Parameters
    ----------
    input : np.ndarray | torch.Tensor
        Input tensor of shape ``(..., ishape)``.
    oshape : Iterable
        Output shape. A single integer is treated as a one-element shape.
    filt : bool, optional
        If True and signal is upsampled (i.e., ``any(oshape > ishape)``),
        apply Fermi filter to limit ringing.
        The default is True.
    polysmooth : bool, optional
        If true, perform polynomial smoothing.
        The default is False. !!! NOT IMPLEMENTED YET !!!
        (a message is printed and the step is skipped).

    Returns
    -------
    output : np.ndarray | torch.Tensor
        Resampled tensor of shape ``(..., oshape)``.
        If every element of ``input`` is real, the magnitude of the
        resampled signal is returned.

    """
    isnumpy = isinstance(input, np.ndarray)
    if isnumpy:
        input = torch.as_tensor(input)

    # accept a bare (Python or NumPy) integer as a 1-element shape
    if isinstance(oshape, (int, np.integer)):
        oshape = [oshape]
    oshape = list(oshape)

    # first, get number of dimensions
    ndim = len(oshape)
    axes = list(range(-ndim, 0))
    isreal = torch.isreal(input).all()

    # take fourier transform along last ndim axes
    freq = _fftc(input, axes)

    # get initial and final shapes (batch axes included)
    ishape1, oshape1 = _expand_shapes(input.shape, oshape)

    # build filter
    window = None
    if filt and np.any(np.asarray(oshape1) > np.asarray(ishape1)):
        # Size the filter from the resampled axes only: using the
        # batch-expanded shape would let the batch size leak into the
        # filter size / width.
        size = np.max(oshape)
        width = np.min(oshape)
        window = fermi(ndim, size, width)
        # crop to the target extent of the resampled axes; batch axes are
        # handled by broadcasting in the multiplication below
        # (assumes fermi returns a window with ``ndim`` axes -- TODO confirm)
        window = resize(window, oshape)

    # resize in frequency space
    freq = resize(freq, oshape1)

    # if required, apply filtering
    if window is not None:
        freq *= window.to(freq.device)

    # transform back
    output = _ifftc(freq, axes)

    # smooth
    if polysmooth:
        print("Polynomial smoothing not implemented yet; skipping")

    # take magnitude if original signal was real
    # NOTE(review): abs() drops the sign, so negative samples of a real
    # input come back positive -- confirm this is intended.
    if isreal:
        output = abs(output)

    if isnumpy:
        output = output.numpy(force=True)

    return output
# %% subroutines
def _expand_shapes(*shapes):
shapes = [list(shape) for shape in shapes]
max_ndim = max(len(shape) for shape in shapes)
shapes_exp = [np.asarray([1] * (max_ndim - len(shape)) + shape) for shape in shapes]
shapes_exp = np.stack(shapes_exp, axis=0) # (nshapes, max_ndim)
shapes_exp = np.max(shapes_exp, axis=0)
# restore original shape in non-padded portions
shapes_exp = [list(shapes_exp[: -len(shape)]) + shape for shape in shapes]
return tuple(shapes_exp)
def _fftc(x, ax):
return torch.fft.fftshift(
torch.fft.fftn(torch.fft.ifftshift(x, dim=ax), dim=ax, norm="ortho"), dim=ax
)
def _ifftc(x, ax):
return torch.fft.fftshift(
torch.fft.ifftn(torch.fft.ifftshift(x, dim=ax), dim=ax, norm="ortho"), dim=ax
)