# mypy: allow-untyped-defs """ APIs related to torch.compile which lazily import torch._dynamo to avoid circular dependencies. """ import functools def _disable_dynamo(fn=None, recursive=True): """ This API should be only used inside torch, external users should still use torch._dynamo.disable. The main goal of this API is to avoid circular imports issues that is common while using _dynamo.disable inside torch itself. This API avoids it by lazily importing torch._dynamo from the import time to the invocation of the decorated function. """ if fn is not None: @functools.wraps(fn) def inner(*args, **kwargs): # cache this on the first invocation to avoid adding too much overhead. disable_fn = getattr(fn, "__dynamo_disable", None) if disable_fn is None: import torch._dynamo disable_fn = torch._dynamo.disable(fn, recursive) fn.__dynamo_disable = disable_fn return disable_fn(*args, **kwargs) return inner else: # decorator usage like @_disable_dynamo(recursive=False). The resulting # object expects the original decorated function as the arg. return functools.partial(_disable_dynamo, recursive=recursive)