# mypy: allow-untyped-defs
"""
PythonDispatcher class is a thin python-binding to C++ dispatcher and it
is designed to show how dispatcher precompute works. In particular,
it shows for a certain op `foo`, what the computed dispatch table looks
like after users register their kernels to certain dispatch keys.

In the real C++ dispatcher we support many dispatch keys for different
functionalities. For simplicity PythonDispatcher only supports dispatch
keys for a single example of each use case. These use cases are listed below:

- CPU/AutogradCPU: represents in-tree backends for which we usually have
    dedicated inference & autograd kernels in the pytorch core library.
    E.g. CPU, CUDA
- FPGA/AutogradOther: represents in-tree backends for which we usually have
    backend-specific inference kernels, but they share the same autograd
    kernel specified in AutogradOther.
    E.g. FPGA, SparseCsrCPU
- XLA/AutogradXLA: represents out-of-tree backends for which we have neither
    inference nor autograd kernels defined in the pytorch core library.
    The backend owner is responsible for registering both inference & autograd
    kernels in their extensions (e.g. torch-xla) for the operators they support.
    E.g. XLA, XPU, MPS
- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
    Kernels registered to this key MUST work for inference for all backends.
- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
    Kernels registered to this key MUST work for autograd for all backends.
- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
    Kernels registered to this key MUST work for both inference + autograd for all backends.

Note we only allow registrations to alias keys inside the pytorch core library. E.g.
you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
kernel from the torch-xla extension; instead you should upstream the kernel into
the pytorch/pytorch repo so that it's available for all backends and continuously
tested even without the extension.

Usage:
  dispatcher = PythonDispatcher()
  dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
  print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
  # For more debugging information
  # print(dispatcher.keys())
  # print(dispatcher.registrations())
  # print(dispatcher.rawRegistrations())
  # print(dispatcher.rawDispatchTable())

PythonDispatcher calls the C++ dispatcher under the hood to precompute the
dispatch table. This file only provides the simplified API for developers;
the relevant test code is located in test/test_dispatch.py
"""

import re

import torch._C as C


# Matches the "registered at <path>/FallbackKernel.cpp:<line> [" prefix that the
# C++ table dump prints for backend-fallback entries; dispatchTable() collapses
# it to just "[" so the simplified table shows e.g. "[backend fallback]".
_FALLBACK_KERNEL_RE = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")


class PythonDispatcher:
    """Simplified Python front end to the C++ dispatcher for the demo op ``__test__::foo``."""

    namespace = "__test__"
    name = "foo"
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self) -> None:
        """Create a ``FRAGMENT`` library in ``namespace`` and define ``foo(Tensor x) -> Tensor``."""
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        self.ref.def_("foo(Tensor x) -> Tensor")

    def keys(self) -> list[str]:
        """Return the dispatch keys supported by PythonDispatcher.

        You can register kernels to these keys. A fresh list is returned so
        that callers cannot accidentally mutate the class-level key list.
        """
        return list(self.supported_keys)

    def register(self, dispatchKeys) -> None:
        """Register an auto-generated kernel for each of the target dispatch keys.

        Args:
            dispatchKeys: an iterable of dispatch keys (str) that you want to
                register your own kernel to. You don't need to write the kernel
                yourself; e.g. for the CPU key a kernel named ``fn_CPU`` is
                generated and registered automatically.

        Raises:
            RuntimeError: if ``dispatchKeys`` contains duplicates, contains both
                CompositeImplicitAutograd and CompositeExplicitAutograd, or
                contains a key not in ``supported_keys``. All checks run before
                any kernel is registered, so a rejected call registers nothing.
        """
        keys = list(dispatchKeys)
        # Overriding is not supported and triggers a warning in the C++ dispatcher.
        if len(set(keys)) != len(keys):
            raise RuntimeError(
                f"Overriden is not allowed but found duplicates in {keys}."
            )
        # We currently forbid this in codegen instead of C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in keys
            and "CompositeExplicitAutograd" in keys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        # Validate every key up front so an unsupported key does not leave the
        # op with a partial set of registered kernels.
        for key in keys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
        for key in keys:
            self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)

    def _format_line(self, key: str, kernel: str) -> str:
        """Format one ``(key, kernel)`` table row, padding the key to 15 columns."""
        return f"{key:<15} {kernel}\n"

    def _format_header(self, header: str) -> str:
        """Return a table header: blank line, title, column names and a rule line."""
        s = f"\n{header}\n"
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    def rawRegistrations(self) -> str:
        """Return the raw C++ dump of all registrations, for debugging only.

        Use registrations() for a simplified version.
        """
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def rawDispatchTable(self) -> str:
        """Return the raw C++ dump of the computed dispatch table, for debugging only.

        Use dispatchTable() for a simplified version.
        """
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def registrations(self) -> str:
        """Return a table (str) of all the registrations from users.

        Note this includes registrations to both runtime keys and alias keys.
        """
        rows = [self._format_header("Registered Kernels")]
        # startswith() rather than equality: alias entries are dumped with a
        # suffix, e.g. "CompositeImplicitAutograd[alias]".
        key_prefixes = tuple(self.supported_keys)
        for line in self.rawRegistrations().split("\n"):
            first = line.split(":")[0]
            if first.startswith(key_prefixes):
                # Text before "::" looks like "<key>: <kernel> "; take the kernel name.
                kernel = line.split("::")[0].split(" ")[1]
                rows.append(self._format_line(first, kernel))
        return "".join(rows)

    def dispatchTable(self) -> str:
        """Return the computed dispatch table (str).

        Note this only includes runtime keys; registrations to alias keys have
        been decoded to their mapped runtime keys.
        """
        rows = [self._format_header("Computed Dispatch Table")]
        for line in self.rawDispatchTable().split("\n"):
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = _FALLBACK_KERNEL_RE.sub("[", line)
                # Split only on the first ": " so a kernel description that
                # itself contains ": " is kept intact.
                rows.append(self._format_line(k, entry.split(": ", 1)[1]))
        return "".join(rows)