## @package scope # Module caffe2.python.scope import contextlib import threading from past.builtins import basestring from caffe2.proto import caffe2_pb2 # The name scope and device scope when creating a new operator. _NAMESCOPE_SEPARATOR = '/' _threadlocal_scope = threading.local() def CurrentNameScope(): global _threadlocal_scope if not hasattr(_threadlocal_scope, "namescope"): _threadlocal_scope.namescope = '' return _threadlocal_scope.namescope def CurrentDeviceScope(): global _threadlocal_scope if not hasattr(_threadlocal_scope, "devicescope"): _threadlocal_scope.devicescope = None return _threadlocal_scope.devicescope @contextlib.contextmanager def NameScope(prefix, reset=False): global _threadlocal_scope assert isinstance(prefix, basestring) or prefix is None, \ "NameScope takes in a string as its argument." old_scope = CurrentNameScope() prefix = prefix + _NAMESCOPE_SEPARATOR if prefix else '' if reset: _threadlocal_scope.namescope = prefix else: _threadlocal_scope.namescope = _threadlocal_scope.namescope + prefix try: yield finally: assert _threadlocal_scope.namescope.endswith(prefix), \ "The namescope variable is changed from outside NameScope() calls." _threadlocal_scope.namescope = old_scope @contextlib.contextmanager def DeviceScope(scope, node_name=None): new_scope = caffe2_pb2.DeviceOption() if scope: assert isinstance(scope, caffe2_pb2.DeviceOption), \ "DeviceScope takes in a caffe2_pb2.DeviceOption as its argument." new_scope.CopyFrom(scope) else: assert node_name, "At least one argument should be non-null in DeviceScope" # rewrite node_name if it is explicitly given if node_name: new_scope.node_name = node_name global _threadlocal_scope old_scope = CurrentDeviceScope() # nested scope should inherit the node_name if it is not explicitly set if old_scope and old_scope.HasField('node_name') and \ not new_scope.HasField('node_name'): new_scope.node_name = old_scope.node_name # nested scope should inherit the extra_info and merged it with new extra_info if old_scope and hasattr(old_scope, 'extra_info'): new_scope.extra_info.extend(old_scope.extra_info) new_scope.extra_info.sort() _threadlocal_scope.devicescope = new_scope try: yield finally: assert _threadlocal_scope.devicescope == new_scope, \ "The device scope is changed from outside DeviceScope() calls." _threadlocal_scope.devicescope = old_scope @contextlib.contextmanager def EmptyNameScope(): """ Allow users to 'disable' the name scope behaviour. This sets the CurrentNameScope() to None, so that the field is not set in CreateOperator(...), etc. """ old_scope = CurrentNameScope() try: _threadlocal_scope.namescope = '' yield finally: _threadlocal_scope.namescope = old_scope return @contextlib.contextmanager def EmptyDeviceScope(): """ Allow users to 'disable' the device scope behaviour (so it can be controlled at a NetDef::DeviceOption level, not overridden at OperatorDef::DeviceOption level). This sets the CurrentDeviceScope() to None, so that the field is not set in CreateOperator(...), etc. """ old_scope = CurrentDeviceScope() try: _threadlocal_scope.devicescope = None yield finally: _threadlocal_scope.devicescope = old_scope return