# mypy: allow-untyped-defs
# Extra utilities for working with context managers that should have been
# in the standard library but are not

import functools
import inspect
import warnings
import sys
from typing import Any, Callable, TypeVar, cast

# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)


def _wrap_generator(ctx_factory, func):
    """
    Wrap each generator invocation with the context manager factory.

    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.

    Args:
        ctx_factory: zero-argument callable returning a fresh context manager;
            it is called once per resumption of the wrapped generator.
        func: the generator function to wrap.

    Returns:
        A generator function with the metadata of ``func`` (via
        ``functools.wraps``) whose every ``send``/``throw``/``close`` into the
        underlying generator runs inside ``with ctx_factory():``.
    """
    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

        # Generators are suspended and unsuspended at `yield`, hence we
        # make sure the grad mode is properly set every time the execution
        # flow returns into the wrapped generator and restored when it
        # returns through our `yield` to our caller (see PR #49017).
        try:
            # Issuing `None` to a generator fires it up
            with ctx_factory():
                response = gen.send(None)

            while True:
                try:
                    # Forward the response to our caller and get its next request
                    request = yield response

                except GeneratorExit:
                    # Inform the still active generator about its imminent closure
                    with ctx_factory():
                        gen.close()
                    raise

                except BaseException as exc:
                    # Propagate the exception thrown at us by the caller.
                    # Pass the exception instance itself: the three-argument
                    # (type, value, traceback) form of `generator.throw` is
                    # deprecated since Python 3.12 and emits a
                    # DeprecationWarning. The instance already carries its
                    # traceback in `__traceback__`.
                    with ctx_factory():
                        response = gen.throw(exc)

                else:
                    # Pass the last request to the generator and get its response
                    with ctx_factory():
                        response = gen.send(request)

        # We let the exceptions raised above by the generator's `.throw` or
        # `.send` methods bubble up to our caller, except for StopIteration
        except StopIteration as e:
            # The generator informed us that it is done: take whatever its
            # returned value (if any) was and indicate that we're done too
            # by returning it (see docs for python's return-statement).
            return e.value

    return generator_context


def context_decorator(ctx, func):
    """
    Like contextlib.ContextDecorator.

    But with the following differences:
    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented from C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.

    Args:
        ctx: a context manager instance, or a callable returning one. It must
            not be both callable and have ``__enter__``.
        func: the function or generator function to wrap.

    Returns:
        A wrapper around ``func`` (with its metadata preserved) that runs each
        call, or each generator resumption, inside the context manager.

    Raises:
        AssertionError: if ``ctx`` is both callable and has ``__enter__``.
        RuntimeError: if ``func`` is a class.
    """
    assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
        f"Passed in {ctx} is both callable and also a valid context manager "
        "(has __enter__), making it ambiguous which interface to use. If you "
        "intended to pass a context manager factory, rewrite your call as "
        "context_decorator(lambda: ctx()); if you intended to pass a context "
        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
    )

    if not callable(ctx):
        # A plain context manager instance: hand back the same object on
        # every call, which is why it has to be multi-shot.
        def ctx_factory():
            return ctx
    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type. "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    if inspect.isgeneratorfunction(func):
        # Generators need the context re-entered on every resumption, not
        # just around the call that creates the generator object.
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context


class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator."""

    def __call__(self, orig_func: F) -> F:
        """
        Decorate ``orig_func`` so it runs inside a clone of this context manager.

        Decorating a class is deprecated: it emits a ``FutureWarning`` and
        wraps only the call that constructs the class.
        """
        if inspect.isclass(orig_func):
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else.",
                FutureWarning,
                stacklevel=2,
            )
            # Hide the class behind a plain function so that
            # `context_decorator` does not reject it as a class.
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        # `self.clone` is used as the factory, so every invocation of the
        # decorated function gets its own context manager instance.
        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        """Enter the context; subclasses must override."""
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Exit the context; subclasses must override."""
        raise NotImplementedError

    def clone(self):
        """Return a new instance of this context manager's class."""
        # override this method if your children class takes __init__ parameters
        return self.__class__()


class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Allow a context manager to be used as a decorator without parentheses."""

    def __new__(cls, orig_func=None):
        """
        Create an instance, or decorate ``orig_func`` directly.

        With no argument this builds a normal instance. With a function
        argument (the ``@Manager`` usage without parentheses) it returns the
        decorated function instead of an instance.
        """
        if orig_func is None:
            return super().__new__(cls)
        return cls()(orig_func)