Source code for deepmr.bloch.ops._epg

__all__ = ["EPGstates"]

import numpy as np
import torch


def _split_complex(tensor):
    """Return a ``{"real", "imag"}`` dict holding the two parts of a complex tensor."""
    return {"real": tensor.real, "imag": tensor.imag}


def _equilibrium_longitudinal(shape, device, scale=None):
    """Build complex longitudinal states with unit value in state 0, optionally scaled."""
    states = torch.zeros(shape, dtype=torch.complex64, device=device)
    states[:, 0, ...] = 1.0  # only the zero-order state is populated
    if scale is not None:
        states = states * scale
    return states


def EPGstates(
    device,
    batch_size,
    nstates,
    nlocs,
    npulses,
    npools=1,
    weight=None,
    model="single",
    moving=False,
):
    """
    EPG states matrix.

    Stores dephasing states for transverse and longitudinal magnetization.

    Parameters
    ----------
    device : str
        Computational device (e.g., ``cpu`` or ``cuda:n``, with ``n=0,1,2...``).
    batch_size : int
        Number of different atoms (e.g., voxels) to be simultaneously simulated.
    nstates : int
        Maximum number of dephasing states.
    nlocs : int
        Number of spatial locations contributing to each atom states
        (e.g., slice points).
    npulses : int or sequence of int
        Number of RF pulses applied during the sequence. A scalar yields a
        signal buffer of shape ``(batch_size, npulses)``; a sequence yields
        ``(batch_size, *npulses)``.
    npools : int, optional
        Number of pools contributing to signal
        (e.g., Free Water / Myelin Water). The default is ``1``.
    weight : torch.Tensor[float], optional
        Relative fraction for each pool. A 1D tensor is treated as a single
        row shared by the whole batch. The first ``npools`` columns scale the
        free (and moving) pools; the last column scales the bound pool when
        the model includes one.
        For single pool, this is the initial magnetization.
        The default is ``None`` (i.e., ``M0 = [0, 0, 1]``).
    model : str, optional
        Type of signal model. If "mt" is in the model name, include pure
        longitudinal states (bound pool).
        The default is ``single`` (single pool model).
    moving : bool, optional
        Flag for moving spins. If ``True`` include a fresh magnetization pool
        to replace states with ``v != 0``.
        The default is ``False``.

    Returns
    -------
    out : dict
        Dictionary with two entries:

        * ``out["states"]``: ``F`` of shape
          ``(batch_size, nstates, nlocs, npools, 2)`` and ``Z`` of shape
          ``(batch_size, nstates, nlocs, npools)``; if the model contains
          "mt", also ``Zbound`` of shape ``(batch_size, nstates, nlocs, 1)``;
          if ``moving`` is ``True``, also a ``moving`` sub-dictionary with
          the same entries.
        * ``out["signal"]``: zero-initialized signal buffer of shape
          ``(batch_size, *npulses)``.

        Every entry is stored as a ``{"real": ..., "imag": ...}`` dictionary
        holding the real and imaginary parts of a ``complex64`` tensor.

    """
    # work on a private, batched copy of the pool fractions
    if weight is not None:
        weight = weight.clone()
        if weight.ndim == 1:
            weight = weight[None, :]

    # normalize npulses to a list so that tuples are accepted as well
    if np.isscalar(npulses):
        npulses = [npulses]
    else:
        npulses = list(npulses)

    has_bound_pool = model is not None and "mt" in model
    transverse_shape = (batch_size, nstates, nlocs, npools, 2)
    longitudinal_shape = (batch_size, nstates, nlocs, npools)
    bound_shape = (batch_size, nstates, nlocs, 1)

    # per-pool scaling, broadcast over the (nstates, nlocs) axes
    pool_scale = None
    bound_scale = None
    if weight is not None:
        pool_scale = weight[:, :npools][:, None, None]
        bound_scale = weight[:, -1][:, None, None, None]

    # free pool: transverse states start at zero, longitudinal at equilibrium
    F = torch.zeros(transverse_shape, dtype=torch.complex64, device=device)
    Z = _equilibrium_longitudinal(longitudinal_shape, device, pool_scale)
    states = {"F": _split_complex(F), "Z": _split_complex(Z)}

    # moving pool: independent fresh magnetization with the same layout
    if moving:
        Fmoving = torch.zeros(transverse_shape, dtype=torch.complex64, device=device)
        Zmoving = _equilibrium_longitudinal(longitudinal_shape, device, pool_scale)
        states["moving"] = {
            "F": _split_complex(Fmoving),
            "Z": _split_complex(Zmoving),
        }

    # bound pool: purely longitudinal; left at unit magnetization when no
    # weight is supplied, mirroring the free pool
    if has_bound_pool:
        Zbound = _equilibrium_longitudinal(bound_shape, device, bound_scale)
        states["Zbound"] = _split_complex(Zbound)
        if moving:
            # clone the tensor (not the dict) so the moving copy is independent
            states["moving"]["Zbound"] = _split_complex(Zbound.clone())

    # output signal buffer
    sig = torch.zeros([batch_size] + npulses, dtype=torch.complex64, device=device)

    return {"states": states, "signal": _split_complex(sig)}