"""
EPG Relaxation operators.
Can be used to simulate longitudinal and transverse relaxation either
in absence or presence of exchange (Chemical Exchange or MT), as well as
accounting for chemical shift.
"""
__all__ = ["Relaxation"]
import math
import torch
from ._abstract_op import Operator
from ._utils import matrix_exp
[docs]class Relaxation(Operator):
"""
The "decay operator" applying relaxation and "regrowth" of the magnetization components.
Parameters
----------
device (str): str
Computational device (e.g., ``cpu`` or ``cuda:n``, with ``n=0,1,2...``).
time : torch.Tensor)
Time step in ``[ms]``.
T1 : torch.Tensor
Longitudinal relaxation time in ``[ms]`` of shape ``(npools,)``.
T2 : torch.Tensor
Transverse relaxation time in ``[ms]`` of shape ``(npools,)``.
weight : torch.Tensor, optional
Relative pool fractions of shape ``(npools,)``.
k : torch.Tensor, optional
Exchange matrix of shape ``(npools, npools)`` in ``[s**-1]``.
df : torch.Tensor, optional
Chemical shift in ``[Hz]`` of shape ``(npools,)``.
Other Parameters
----------------
name : str
Name of the operator.
"""
[docs] def __init__(
self, device, time, T1, T2, weight=None, k=None, df=None, **kwargs
): # noqa
super().__init__(**kwargs)
# offload (not sure if this is needed?)
time = torch.as_tensor(time, dtype=torch.float32, device=device)
time = torch.atleast_1d(time)
T1 = torch.as_tensor(T1, dtype=torch.float32, device=device)
T1 = torch.atleast_1d(T1)
T2 = torch.as_tensor(T2, dtype=torch.float32, device=device)
T2 = torch.atleast_1d(T2)
# cast to tensors
if weight is not None:
weight = torch.as_tensor(weight, dtype=torch.float32, device=device)
if k is not None:
k = torch.as_tensor(k, dtype=torch.float32, device=device)
k = _prepare_exchange(weight, k)
if df is not None:
df = torch.as_tensor(df, dtype=torch.float32, device=device)
df = torch.atleast_1d(df)
# prepare operators
if weight is None or k is None:
E2 = _transverse_relax_prep(time, T2)
E1, rE1 = _longitudinal_relax_prep(time, T1)
# assign functions
self._transverse_relax_apply = _transverse_relax_apply
self._longitudinal_relax_apply = _longitudinal_relax_apply
else:
E2, self._transverse_relax_apply = _transverse_relax_exchange_prep(
time, T2, k, df
)
E1, rE1 = _longitudinal_relax_exchange_prep(time, T1, weight, k)
# assign functions
self._longitudinal_relax_apply = _longitudinal_relax_exchange_apply
# assign matrices
self.E1 = E1
self.rE1 = rE1
self.E2 = E2
def apply(self, states):
"""
Apply free precession (relaxation + precession + exchange + recovery).
Parameters
----------
states : dict
Input states matrix for free pools
and, optionally, for bound pools.
Returns
-------
states : dict
Output states matrix for free pools
and, optionally, for bound pools.
"""
states = self._transverse_relax_apply(states, self.E2)
states = self._longitudinal_relax_apply(states, self.E1, self.rE1)
# relaxation for moving spins
if "moving" in states:
states["moving"] = self._transverse_relax_apply(states["moving"], self.E2)
states["moving"] = self._longitudinal_relax_apply(
states["moving"], self.E1, self.rE1
)
return states
# %% local utils
def _prepare_exchange(weight, k):
# prepare
if k.shape[-1] == 1: # BM or MT
k0 = 0 * k
k1 = torch.cat((k0, k * weight[..., [0]]), axis=-1)
k2 = torch.cat((k * weight[..., [1]], k0), axis=-1)
k = torch.stack((k1, k2), axis=-2)
else: # BM-MT
k0 = 0 * k[..., [0]]
k1 = torch.cat((k0, k[..., [0]] * weight[..., [0]], k0), axis=-1)
k2 = torch.cat(
(k[..., [0]] * weight[..., [1]], k0, k[..., [1]] * weight[..., [1]]),
axis=-1,
)
k3 = torch.cat((k0, k[..., [1]] * weight[..., [2]], k0), axis=-1)
k = torch.stack((k1, k2, k3), axis=-2)
# finalize exchange
return _particle_conservation(k)
def _particle_conservation(k):
"""Adjust diagonal of exchange matrix by imposing particle conservation."""
# get shape
npools = k.shape[-1]
for n in range(npools):
k[..., n, n] = 0.0 # ignore existing diagonal
k[..., n, n] = -k[..., n].sum(dim=-1)
return k
def _transverse_relax_apply(states, E2):
# parse
F = states["F"]
# apply
F[..., 0] = F[..., 0].clone() * E2 # F+
F[..., 1] = F[..., 1].clone() * E2.conj() # F-
# prepare for output
states["F"] = F
return states
def _longitudinal_relax_apply(states, E1, rE1):
# parse
Z = states["Z"]
# apply
Z = Z.clone() * E1 # decay
Z[0] = Z[0].clone() + rE1 # regrowth
# prepare for output
states["Z"] = Z
return states
def _transverse_relax_prep(time, T2):
# compute R2
R2 = 1 / T2
# calculate operators
E2 = torch.exp(-R2 * time)
return E2
def _longitudinal_relax_prep(time, T1):
# compute R2
R1 = 1 / T1
# calculate operators
E1 = torch.exp(-R1 * time)
rE1 = 1 - E1
return E1, rE1
def _transverse_relax_exchange_apply(states, E2):
# parse
F = states["F"]
# apply
F[..., 0] = torch.einsum("...ij,...j->...i", E2, F[..., 0].clone())
F[..., 1] = torch.einsum("...ij,...j->...i", E2.conj(), F[..., 1].clone())
# prepare for output
states["F"] = F
return states
def _longitudinal_relax_exchange_apply(states, E1, rE1):
# parse
Z = states["Z"]
# get ztot
if "Zbound" in states:
Zbound = states["Zbound"]
Ztot = torch.cat((Z, Zbound), axis=-1)
else:
Ztot = Z
# apply
Ztot = torch.einsum("...ij,...j->...i", E1, Ztot.clone())
Ztot[0] = Ztot[0].clone() + rE1
# prepare for output
if "Zbound" in states:
states["Z"] = Ztot[..., :-1]
states["Zbound"] = Ztot[..., [-1]]
else:
states["Z"] = Ztot
return states
def _transverse_relax_exchange_prep(time, T2, k, df=None):
# compute R2
R2 = 1 / T2
# add chemical shift
if df is not None:
R2tot = R2 + 1j * 2 * math.pi * df * 1e-3 # (account for time in [ms])
else:
R2tot = R2
# get npools
npools = R2tot.shape[-1]
# case 1: MT
if npools == 1:
return torch.exp(-R2tot * time), _transverse_relax_apply
# case 2: BM or BM-MT
else:
# cast to complex
R2tot = R2tot.to(torch.complex64)
# recovery
Id = torch.eye(npools, dtype=R2tot.dtype, device=R2tot.device)
# coefficients
lambda2 = (
k[..., :npools, :npools] * 1e-3 - R2tot[:, None] * Id
) # assume MT pool is the last
# actual operators
E2 = matrix_exp(lambda2 * time)
return E2, _transverse_relax_exchange_apply
def _longitudinal_relax_exchange_prep(time, T1, weight, k):
# compute R2
R1 = 1 / T1
# get npools
npools = R1.shape[-1]
if weight.shape[-1] == npools + 1: # MT case
R1 = torch.cat((R1, R1[..., [0]]), axis=-1)
npools += 1
# cast to complex
R1 = R1.to(torch.complex64)
# recovery
Id = torch.eye(npools, dtype=R1.dtype, device=R1.device)
C = weight * R1
# coefficients
lambda1 = k * 1e-3 - R1 * Id
# actual operators
E1 = matrix_exp(lambda1 * time)
rE1 = torch.einsum("...ij,...j->...i", (E1 - Id), torch.linalg.solve(lambda1, C))
return E1, rE1