# Source code for torchsim.epg._diffusion

"""Diffusion damping operator."""

# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["diffusion_op", "diffusion"]

from types import SimpleNamespace

import torch


[docs] def diffusion_op( D: torch.Tensor, time: torch.Tensor, nstates: int, total_dephasing: float, voxelsize: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Prepare longitudinal and transverse diffusion damping operators. Parameters ---------- D : torch.Tensor Apparent diffusion coefficient in ``[m**2 / s]``. time : torch.Tensor Time interval in ``[s]``. nstates : int Number of EPG states total_dephasing : float Total dephasing induced by gradient in ``[rad]``. voxelsize : float, optional Voxel thickness along unbalanced direction in ``[m]``. The default is 1.0. Returns ------- D1 : torch.Tensor Diffusion damping operator for longitudinal states. D2 : torch.Tensor Diffusion damping operator for transverse states. """ k0_2 = (total_dephasing / voxelsize) ** 2 # actual operator calculation b_prime = k0_2 * time * 1e-3 # calculate dephasing order l = torch.arange(nstates, dtype=torch.float32, device=D.device)[:, None, None] lsq = l**2 # calculate b-factor b1 = b_prime * lsq b2 = b_prime * (lsq + l + 1.0 / 3.0) # actual operator calculation D1 = torch.exp(-b1 * D * 1e-9) D2 = torch.exp(-b2 * D * 1e-9) return D1, D2
def diffusion(
    states: SimpleNamespace,
    D1: torch.Tensor,
    D2: torch.Tensor,
) -> SimpleNamespace:
    """
    Apply diffusion damping.

    The transverse states (``Fplus``, ``Fminus``) are multiplied by ``D2``
    and the longitudinal states (``Z``) by ``D1``. The products are
    out-of-place, so the tensors originally held by ``states`` are not
    modified. The attributes of ``states`` are rebound to the damped tensors,
    and the same namespace object is returned.

    Parameters
    ----------
    states : SimpleNamespace
        Input EPG states with ``Fplus``, ``Fminus`` and ``Z`` attributes.
    D1 : torch.Tensor
        Diffusion damping operator for longitudinal states.
    D2 : torch.Tensor
        Diffusion damping operator for transverse states.

    Returns
    -------
    states : SimpleNamespace
        Output EPG states (the input namespace, updated).
    """
    # The out-of-place multiplication already allocates new tensors, so no
    # explicit ``clone()`` is needed to leave the caller's tensors intact.
    # All products are computed before any attribute is rebound, so a
    # broadcasting error leaves ``states`` unchanged.
    Fplus = states.Fplus * D2  # transverse damping
    # NOTE(review): Fminus reuses D2, which is built from the F+ b-factor
    # (l**2 + l + 1/3). If F- states in this package move toward lower order
    # under the gradient, their exact factor would be l**2 - l + 1/3 -- confirm.
    Fminus = states.Fminus * D2  # transverse damping
    Z = states.Z * D1  # longitudinal damping

    # prepare for output
    states.Fplus = Fplus
    states.Fminus = Fminus
    states.Z = Z
    return states