from .scope import enter_scope, exit_scope
from triton.compiler import CompiledKernel, LazyDict

# Name of the scope that is open while a kernel's launch metadata is computed.
COMPUTE_METADATA_SCOPE_NAME = "__proton_launch_metadata"


class TritonHook:
    """Proton hooks that open and close a profiling scope around a Triton launch.

    ``enter`` and ``exit`` are installed as ``CompiledKernel.launch_enter_hook``
    and ``CompiledKernel.launch_exit_hook`` by ``register_triton_hook``.
    """

    flops_width = [8, 16, 32, 64]
    # Metric keys copied from launch metadata onto the kernel's scope:
    # "flops8", "flops16", "flops32", "flops64" and "bytes".
    metrics = [f"flops{width}" for width in flops_width] + ["bytes"]

    @staticmethod
    def enter(lazy_dict: LazyDict) -> None:
        """Open a Triton-op scope named after the launched kernel.

        Args:
            lazy_dict: Lazily evaluated launch metadata. Its ``get()`` result
                must contain a ``"name"`` key. It may also contain any of
                ``TritonHook.metrics``; those present are attached to the scope.

        Raises:
            Whatever ``lazy_dict.get()`` raises. Even then, the temporary
            metadata scope is closed before the exception propagates.
            ``KeyError`` if the metadata has no ``"name"`` key.
        """
        # Evaluate the metadata inside its own scope. Presumably this keeps
        # any work done here separate from the kernel's own scope — confirm.
        enter_scope(COMPUTE_METADATA_SCOPE_NAME)
        try:
            metadata = lazy_dict.get()
        finally:
            # Always balance the enter_scope above, even if get() raised.
            exit_scope()
        fn_metrics = {k: metadata[k] for k in TritonHook.metrics if k in metadata}
        enter_scope(metadata["name"], triton_op=True, metrics=fn_metrics)

    @staticmethod
    def exit(lazy_dict: LazyDict) -> None:
        """Close the Triton-op scope opened by ``enter``.

        ``lazy_dict`` is unused. It is accepted so the hook has the same
        signature as ``enter``.
        """
        exit_scope(triton_op=True)


def register_triton_hook() -> None:
    """Install ``TritonHook`` as the ``CompiledKernel`` launch hooks.

    Nothing is installed if an enter hook is already set, so an existing hook
    is not overwritten.
    """
    if CompiledKernel.launch_enter_hook is None:
        CompiledKernel.launch_enter_hook = TritonHook.enter
        CompiledKernel.launch_exit_hook = TritonHook.exit


def unregister_triton_hook() -> None:
    """Remove the ``CompiledKernel`` launch hooks if ``TritonHook`` installed them.

    Hooks installed by anyone else are left untouched.
    """
    if CompiledKernel.launch_enter_hook is TritonHook.enter:
        CompiledKernel.launch_enter_hook = None
        CompiledKernel.launch_exit_hook = None