{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Automatic differentiation\n\nThis example showcase the automatic differentiation capabilities\nof the framework.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
"\n\n# Install libraries needed for Colab\n\nThe following installation commands only need to be run on Google Colab.\n" ] },
\n " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Install libraries\n!pip install torchsim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we will import the required packages:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import warnings\n\nwarnings.filterwarnings(\"ignore\")\n\nfrom functools import partial\n\nimport numpy as np\nimport torch\n\nfrom torch.func import jacrev\n\nimport matplotlib.pyplot as plt\nimport time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will show how to use automatic differentiation\nto automatically compute Cramer Rao Lower Bound.\n\nThis can be used as a cost function to optimize acquisition schedules,\nfor example for quantitative MRI\n\nWe'll focuse on a simple Fast Spin Echo acquisition:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torchsim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cramer Rao Lower Bound is defined as the diagonal of the inverse\nof Fisher information matrix. This can be computed as\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def calculate_crlb(grad, W=None, weight=1.0):\n if len(grad.shape) == 1:\n grad = grad[None, :]\n\n if W is None:\n W = torch.eye(grad.shape[0], dtype=grad.dtype, device=grad.device)\n\n J = torch.stack((grad.real, grad.imag), axis=0) # (nparams, nechoes)\n J = J.permute(2, 1, 0)\n\n # calculate Fischer information matrix\n In = torch.einsum(\"bij,bjk->bik\", J, J.permute(0, 2, 1))\n I = In.sum(axis=0) # (nparams, nparams)\n\n # Invert\n return torch.trace(torch.linalg.inv(I) * W).real * weight" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "notice that we used the trace as a cost function.\nFor optimization, we need the gradient of this cost\nwrt sequence parameters.\n\nThis can be obtained as:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def _crlb_cost(ESP, T1, T2, flip):\n\n # calculate signal and derivative\n _, grad = torchsim.fse_sim(flip=flip, ESP=ESP, T1=T1, T2=T2, diff=\"T2\")\n\n # calculate cost\n return calculate_crlb(grad)\n\n\ndef crlb_cost(flip, ESP, T1, T2):\n flip = torch.as_tensor(flip, dtype=torch.float32)\n flip.requires_grad = True\n\n # get partial function\n _cost = partial(_crlb_cost, ESP, T1, T2)\n _dcost = jacrev(_cost)\n\n return _cost(flip).detach().cpu().numpy(), _dcost(flip).detach().cpu().numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As reference, we compute derivatives via finite differences\napproximation. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "As a reference, we compute derivatives via a finite-difference\napproximation. This is less accurate, but just as easy to implement\nas automatic differentiation:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def fse_finitediff_grad(flip, ESP, T1, T2):\n    sig = torchsim.fse_sim(flip=flip, ESP=ESP, T1=T1, T2=T2)\n\n    # forward-difference numerical derivative with respect to T2\n    dt = 1.0\n    dsig = torchsim.fse_sim(flip=flip, ESP=ESP, T1=T1, T2=T2 + dt)\n\n    return sig, (dsig - sig) / dt\n\n\ndef _crlb_finitediff_cost(ESP, T1, T2, flip):\n    # calculate signal and derivative\n    _, grad = fse_finitediff_grad(flip, ESP, T1, T2)\n\n    # calculate cost\n    return calculate_crlb(grad).cpu().detach().numpy()\n\n\ndef crlb_finitediff_cost(flip, ESP, T1, T2):\n    # initial cost\n    cost0 = _crlb_finitediff_cost(ESP, T1, T2, flip)\n    dcost = []\n\n    # perturb one flip angle at a time by 1 degree\n    for n in range(len(flip)):\n        angles = flip.copy()\n        angles[n] += 1.0\n        dcost.append(_crlb_finitediff_cost(ESP, T1, T2, angles))\n\n    return cost0, (np.asarray(dcost) - cost0)" ] },
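{ "cell_type": "markdown", "metadata": {}, "source": [ "The forward difference above has truncation error of order dt. As a\nminimal sketch (assuming the same torchsim.fse_sim interface used\nthroughout this example), a central difference reduces the error to\norder dt^2 at the cost of one extra simulation:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def fse_centraldiff_grad(flip, ESP, T1, T2, dt=1.0):\n    # central difference: O(dt^2) truncation error instead of O(dt)\n    sig_plus = torchsim.fse_sim(flip=flip, ESP=ESP, T1=T1, T2=T2 + dt)\n    sig_minus = torchsim.fse_sim(flip=flip, ESP=ESP, T1=T1, T2=T2 - dt)\n    return (sig_plus - sig_minus) / (2.0 * dt)" ] },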
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we can evaluate the cost for a specific tissue.\n\nWe assume T1 = 1000.0 ms and T2 = 100.0 ms:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "t1 = 1000.0\nt2 = 100.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's compute the CRLB for a constant 60.0 degree refocusing schedule\nwith a 5.0 ms echo spacing:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "angles = np.ones(96) * 60.0\nesp = 5.0  # ms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run both approaches and plot the results and timings:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# signal derivative: finite difference vs automatic differentiation\ntstart = time.time()\nsig0, grad0 = fse_finitediff_grad(angles, esp, t1, t2)\ntstop = time.time()\ntgrad0 = tstop - tstart\n\ntstart = time.time()\nsig, grad = torchsim.fse_sim(flip=angles, ESP=esp, T1=t1, T2=t2, diff=\"T2\")\ntstop = time.time()\ntgrad = tstop - tstart\n\n# cost and its derivative\ntstart = time.time()\ncost0, dcost0 = crlb_finitediff_cost(angles, esp, t1, t2)\ntstop = time.time()\ntcost0 = tstop - tstart\n\ntstart = time.time()\ncost, dcost = crlb_cost(angles, esp, t1, t2)\ntstop = time.time()\ntcost = tstop - tstart\n\nfsz = 10\nplt.figure()\nplt.rcParams.update({\"font.size\": 0.5 * fsz})\n\nplt.subplot(4, 1, 1)\nplt.plot(angles, \".\")\nplt.xlabel(\"Echo #\", fontsize=fsz)\nplt.xlim([-1, 97])\nplt.ylabel(\"Flip Angle [deg]\", fontsize=fsz)\n\nplt.subplot(4, 1, 2)\nplt.plot(abs(grad), \"-k\")\nplt.plot(abs(grad0), \"*r\")\nplt.xlabel(\"Echo #\", fontsize=fsz)\nplt.xlim([-1, 97])\nplt.ylabel(r\"$\\frac{\\partial signal}{\\partial T2}$ [a.u.]\", fontsize=fsz)\nplt.legend([\"Auto Diff\", \"Finite Diff\"])\n\nplt.subplot(4, 1, 3)\nplt.plot(abs(dcost), \"-k\")\nplt.plot(abs(dcost0), \"*r\")\nplt.xlabel(\"Echo #\", fontsize=fsz)\nplt.xlim([-1, 97])\nplt.ylabel(r\"$\\frac{\\partial CRLB}{\\partial FA}$ [a.u.]\", fontsize=fsz)\nplt.legend([\"Auto Diff\", \"Finite Diff\"])\n\nplt.subplot(4, 1, 4)\nlabels = [\"derivative of signal\", \"CRLB objective gradient\"]\ntime_finite = [round(tgrad0, 2), round(tcost0, 2)]\ntime_auto = [round(tgrad, 2), round(tcost, 2)]\n\nx = np.arange(len(labels))  # the label locations\nwidth = 0.35  # the width of the bars\nrects1 = plt.bar(x + width / 2, time_finite, width, label=\"Finite Diff\")\nrects2 = plt.bar(x - width / 2, time_auto, width, label=\"Auto Diff\")\n\n# axis labels, tick labels and bar annotations\nplt.ylabel(\"Execution Time [s]\", fontsize=fsz)\nplt.xticks(x, labels, fontsize=fsz)\nplt.legend()\n\nplt.bar_label(rects1, padding=3, fontsize=fsz)\nplt.bar_label(rects2, padding=3, fontsize=fsz)\nplt.tight_layout()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 0 }