# mypy: allow-untyped-defs
r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. These objects both
input constraints and return transforms, but they have different guarantees on
bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform but is needed for algorithms
    like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""

import numbers

from torch.distributions import constraints, transforms


__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]


class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Internally this is a mapping from a
    :class:`~torch.distributions.constraints.Constraint` subclass to a factory
    callable that builds a transform for an instance of that subclass.
    """

    def __init__(self):
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.arg_constraints)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.
                A ready-made :class:`~torch.distributions.transforms.Transform`
                instance is also accepted; it is then returned as-is for every
                constraint of the registered class. If omitted, a decorator
                is returned instead.

        Returns:
            The ``factory`` argument unchanged (so decorators can be stacked),
            or a decorator when ``factory`` is ``None``.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass
                nor a ``Constraint`` instance.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        if isinstance(factory, transforms.Transform):
            # A Transform instance is callable, but calling it would apply the
            # transform to the constraint object rather than build a transform.
            # Wrap it so lookup returns the transform itself.
            transform = factory
            self._registry[constraint] = lambda _constraint: transform
        else:
            self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            `NotImplementedError` if no transform has been registered.
        """
        # Look up by Constraint subclass (exact type match, no MRO fallback).
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)


biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()


################################################################################
# Registration Table
################################################################################
# NOTE: stacked decorators are applied bottom-up, and registering the same
# constraint class again replaces the earlier factory for that class.


@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Return the shared identity transform; no constraining is needed."""
    return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Wrap the bijection for the base constraint in an IndependentTransform."""
    base_transform = biject_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """Wrap the transform for the base constraint in an IndependentTransform."""
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """Return an ExpTransform."""
    return transforms.ExpTransform()


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Compose ExpTransform with AffineTransform(lower_bound, 1)."""
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.lower_bound, 1),
        ]
    )


@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """Compose ExpTransform with AffineTransform(upper_bound, -1)."""
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.upper_bound, -1),
        ]
    )


@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """Compose SigmoidTransform with an affine map onto the interval's bounds.

    For plain-number bounds equal to 0 and 1, the affine step is skipped.
    """
    # Handle the special case of the unit interval.
    lower_is_0 = (
        isinstance(constraint.lower_bound, numbers.Number)
        and constraint.lower_bound == 0
    )
    upper_is_1 = (
        isinstance(constraint.upper_bound, numbers.Number)
        and constraint.upper_bound == 1
    )
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
    )


@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """Return a StickBreakingTransform."""
    return transforms.StickBreakingTransform()


@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Return a SoftmaxTransform."""
    return transforms.SoftmaxTransform()


# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """Return a LowerCholeskyTransform."""
    return transforms.LowerCholeskyTransform()


@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """Return a PositiveDefiniteTransform."""
    return transforms.PositiveDefiniteTransform()


@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Return a CorrCholeskyTransform."""
    return transforms.CorrCholeskyTransform()


@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """Build a CatTransform from the bijections of each sub-constraint."""
    return transforms.CatTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )


@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """Build a CatTransform from the transforms of each sub-constraint."""
    return transforms.CatTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )


@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """Build a StackTransform from the bijections of each sub-constraint."""
    return transforms.StackTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim
    )


@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """Build a StackTransform from the transforms of each sub-constraint."""
    return transforms.StackTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim
    )