# mypy: allow-untyped-defs from typing import Optional import torch from torch.overrides import TorchFunctionMode from torch.utils._contextlib import context_decorator import functools CURRENT_DEVICE: Optional[torch.device] = None @functools.lru_cache(1) def _device_constructors(): return { # standard ones torch.empty, torch.empty_permuted, torch.empty_strided, torch.empty_quantized, torch.ones, torch.arange, torch.bartlett_window, torch.blackman_window, torch.eye, torch.fft.fftfreq, torch.fft.rfftfreq, torch.full, torch.fill, torch.hamming_window, torch.hann_window, torch.kaiser_window, torch.linspace, torch.logspace, torch.nested.nested_tensor, # This function doesn't actually take a device argument # torch.normal, torch.ones, torch.rand, torch.randn, torch.randint, torch.randperm, torch.range, torch.sparse_coo_tensor, torch.sparse_compressed_tensor, torch.sparse_csr_tensor, torch.sparse_csc_tensor, torch.sparse_bsr_tensor, torch.sparse_bsc_tensor, torch.tril_indices, torch.triu_indices, torch.vander, torch.zeros, torch.asarray, # weird ones torch.tensor, torch.as_tensor, torch.scalar_tensor, torch.asarray, } # NB: This is directly called from C++ in torch/csrc/Device.cpp class DeviceContext(TorchFunctionMode): def __init__(self, device): self.device = torch.device(device) def __enter__(self): global CURRENT_DEVICE self.old_device = CURRENT_DEVICE CURRENT_DEVICE = self.device return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): global CURRENT_DEVICE CURRENT_DEVICE = self.old_device return super().__exit__(exc_type, exc_val, exc_tb) def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} if func in _device_constructors() and kwargs.get('device') is None: kwargs['device'] = self.device return func(*args, **kwargs) # NB: This is directly called from C++ in torch/csrc/Device.cpp def device_decorator(device, func): return context_decorator(lambda: device, func) def set_device(device): """ Set the default device inside of the wrapped function by decorating it with this function. If you would like to use this as a context manager, use device as a context manager directly, e.g., ``with torch.device(device)``. """ return lambda func: device_decorator(torch.device(device), func)