## @package context # Module caffe2.python.context import inspect import threading import functools class _ContextInfo(object): def __init__(self, cls, allow_default): self.cls = cls self.allow_default = allow_default self._local_stack = threading.local() @property def _stack(self): if not hasattr(self._local_stack, 'obj'): self._local_stack.obj = [] return self._local_stack.obj def enter(self, value): self._stack.append(value) def exit(self, value): assert len(self._stack) > 0, 'Context %s is empty.' % self.cls assert self._stack.pop() == value def get_active(self, required=True): if len(self._stack) == 0: if not required: return None assert self.allow_default, ( 'Context %s is required but none is active.' % self.cls) self.enter(self.cls()) return self._stack[-1] class _ContextRegistry(object): def __init__(self): self._ctxs = {} def get(self, cls): if cls not in self._ctxs: assert issubclass(cls, Managed), "must be a context managed class, got {}".format(cls) self._ctxs[cls] = _ContextInfo(cls, allow_default=issubclass(cls, DefaultManaged)) return self._ctxs[cls] _CONTEXT_REGISTRY = _ContextRegistry() def _context_registry(): global _CONTEXT_REGISTRY return _CONTEXT_REGISTRY def _get_managed_classes(obj): return [ cls for cls in inspect.getmro(obj.__class__) if issubclass(cls, Managed) and cls != Managed and cls != DefaultManaged ] class Managed(object): """ Managed makes the inheritted class a context managed class. class Foo(Managed): ... with Foo() as f: assert f == Foo.current() """ @classmethod def current(cls, value=None, required=True): ctx_info = _context_registry().get(cls) if value is not None: assert isinstance(value, cls), ( 'Wrong context type. Expected: %s, got %s.' % (cls, type(value))) return value return ctx_info.get_active(required=required) def __enter__(self): for cls in _get_managed_classes(self): _context_registry().get(cls).enter(self) return self def __exit__(self, *args): for cls in _get_managed_classes(self): _context_registry().get(cls).exit(self) def __call__(self, func): @functools.wraps(func) def wrapper(*args, **kwargs): with self: return func(*args, **kwargs) return wrapper class DefaultManaged(Managed): """ DefaultManaged is similar to Managed but if there is no parent when current() is called it makes a new one. """ pass