# mypy: allow-untyped-defs import contextlib import torch # Common testing utilities for use in public testing APIs. # NB: these should all be importable without optional dependencies # (like numpy and expecttest). def wrapper_set_seed(op, *args, **kwargs): """Wrapper to set seed manually for some functions like dropout See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details. """ with freeze_rng_state(): torch.manual_seed(42) output = op(*args, **kwargs) if isinstance(output, torch.Tensor) and output.device.type == "lazy": # We need to call mark step inside freeze_rng_state so that numerics # match eager execution torch._lazy.mark_step() # type: ignore[attr-defined] return output @contextlib.contextmanager def freeze_rng_state(): # no_dispatch needed for test_composite_compliance # Some OpInfos use freeze_rng_state for rng determinism, but # test_composite_compliance overrides dispatch for all torch functions # which we need to disable to get and set rng state with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): rng_state = torch.get_rng_state() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() try: yield finally: # Modes are not happy with torch.cuda.set_rng_state # because it clones the state (which could produce a Tensor Subclass) # and then grabs the new tensor's data pointer in generator.set_state. # # In the long run torch.cuda.set_rng_state should probably be # an operator. # # NB: Mode disable is to avoid running cross-ref tests on thes seeding with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] torch.set_rng_state(rng_state)