Source code for torchsim.base.decorators._autocast

"""Automatic torch converter."""

__all__ = ["autocast"]

import inspect

from functools import wraps
from typing import Callable

import torch

from mrinufft._array_compat import _to_torch, _get_leading_argument, _get_device


def autocast(func: Callable) -> Callable:
    """
    Force all inputs to be torch tensors of the same precision on the same device.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        args, kwargs = _fill_kwargs(func, args, kwargs)

        # convert arrays to torch
        args, kwargs = _to_torch(*args, **kwargs)

        # convert remaining objects to torch
        args, kwargs = _to_tensors(*args, **kwargs)

        # enforce float32 for floating point tensors
        args, kwargs = _enforce_precision(*args, **kwargs)

        # get leading (first positional or keyword) argument
        leading_arg = _get_leading_argument(args, kwargs)

        # get device from the leading argument
        device = _get_device(leading_arg)

        # move everything to the leading argument's device
        args, kwargs = _to_device(device, *args, **kwargs)

        # run function
        return func(*args, **kwargs)

    return wrapper
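For orientation, a minimal usage sketch, assuming torchsim and mrinufft are installed and that autocast is importable from this module path; the decorated simulate function and its arguments are hypothetical, introduced only for illustration:

import numpy as np

from torchsim.base.decorators._autocast import autocast


@autocast
def simulate(signal, scale=1.0):  # hypothetical function, for illustration only
    # inside the body, both inputs have been coerced to torch tensors
    # living on the same device as the leading argument ``signal``
    return signal * scale


out = simulate(np.ones(4, dtype=np.float64), scale=2.0)
print(out.dtype)  # torch.float32 (floating point inputs are cast down)
print(type(out))  # <class 'torch.Tensor'>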
# %% subroutines
def _fill_kwargs(func, args, kwargs):
    """Automatically fill missing kwargs with default values."""
    signature = inspect.signature(func)

    # number of arguments passed positionally
    n_args = len(args)

    # create a dictionary of keyword arguments and their default values
    _kwargs = {}
    for k, v in signature.parameters.items():
        if v.default is not inspect.Parameter.empty:
            _kwargs[k] = v.default
        else:
            _kwargs[k] = None

    # merge the default keyword arguments with the provided kwargs
    for k in kwargs.keys():
        _kwargs[k] = kwargs[k]

    # drop entries already consumed by positional arguments
    _keys = list(_kwargs.keys())[n_args:]
    _values = list(_kwargs.values())[n_args:]

    return args, dict(zip(_keys, _values))


def _enforce_precision(*args, **kwargs):
    """Enforce float32 precision for floating point tensors."""
    args = list(args)
    for n in range(len(args)):
        if isinstance(args[n], torch.Tensor) and torch.is_floating_point(args[n]):
            args[n] = args[n].to(torch.float32)

    # convert keyworded arguments
    if kwargs:
        processed_vals, _ = _enforce_precision(*kwargs.values())
        kwargs = dict(zip(kwargs.keys(), processed_vals))

    return args, kwargs


def _to_tensors(*args, **kwargs):
    """Enforce tensors."""
    args = list(args)
    for n in range(len(args)):
        try:
            args[n] = torch.as_tensor(args[n])
        except Exception:
            pass

    # convert keyworded arguments
    if kwargs:
        processed_vals, _ = _to_tensors(*kwargs.values())
        kwargs = dict(zip(kwargs.keys(), processed_vals))

    return args, kwargs


def _to_device(device, *args, **kwargs):
    """Enforce same device."""
    args = list(args)
    for n in range(len(args)):
        try:
            args[n] = args[n].to(device)
        except Exception:
            pass

    # convert keyworded arguments
    if kwargs:
        processed_vals, _ = _to_device(device, *kwargs.values())
        kwargs = dict(zip(kwargs.keys(), processed_vals))

    return args, kwargs
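To make the _fill_kwargs bookkeeping concrete, a small sketch (the example function below is hypothetical): defaults are materialized for every parameter, provided keywords override them, and entries already consumed by positional arguments are dropped from the returned kwargs.

def example(x, y=2, z=3):  # hypothetical function, for illustration only
    return x, y, z


args, kwargs = _fill_kwargs(example, (1,), {"z": 30})
print(args)    # (1,)
print(kwargs)  # {'y': 2, 'z': 30}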