from __future__ import annotations import hashlib import json from .._C.libtriton import get_cache_invalidating_env_vars, ir from ..backends import backends from ..backends.compiler import GPUTarget from .. import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager from ..runtime.driver import driver # TODO: this shouldn't be here from dataclasses import dataclass from .code_generator import ast_to_ttir from pathlib import Path import re import functools import os @dataclass class AttrsDescriptor: divisible_by_16: set = None equal_to_1: set = None def __post_init__(self): if self.divisible_by_16 is None: self.divisible_by_16 = set() if self.equal_to_1 is None: self.equal_to_1 = set() def to_dict(self): return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)} @staticmethod def from_dict(data): return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), equal_to_1=set(data.get('equal_to_1', []))) def hash(self): key = str([sorted(x) for x in self.__dict__.values()]) return hashlib.sha256(key.encode("utf-8")).hexdigest() # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, # and any following whitespace # - (public\s+)? : optionally match the keyword public and any following whitespace # - (@\w+) : match an @ symbol followed by one or more word characters # (letters, digits, or underscores), and capture it as group 1 (the function name) # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing # zero or more arguments separated by commas, and capture it as group 2 (the argument list) # - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" prototype_pattern = { "ttir": mlir_prototype_pattern, "ttgir": mlir_prototype_pattern, "ptx": ptx_prototype_pattern, } mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { "ttir": mlir_arg_type_pattern, "ttgir": mlir_arg_type_pattern, "ptx": ptx_arg_type_pattern, } def convert_type_repr(x): # Currently we only capture the pointer type and assume the pointer is on global memory. # TODO: Capture and support shared memory space match = re.search(r'!tt\.ptr<([^,]+)', x) if match is not None: return '*' + convert_type_repr(match.group(1)) return x def _get_num_warps_from_ir_str(src: str): ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if # e.g. someone has an instruction (not module) attribute named "num-warps". num_warps_matches = re.findall(ttgir_num_warps_pattern, src) assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" num_warps = int(num_warps_matches[0]) return num_warps class ASTSource: def __init__(self, fn, signature, constants=None, attrs=None) -> None: self.fn = fn self.ext = "ttir" self.name = fn.__name__ self.signature = signature self.constants = constants self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} if self.constants is None: self.constants = dict() if self.attrs is None: self.attrs = AttrsDescriptor() def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] # Note - we stringify the keys here to allow sorting to work for cases # where constants have mixed int/str keys. sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() def make_ir(self, options, codegen_fns, context): return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) def parse_options(self): return dict() class IRSource: def __init__(self, path): self.path = path path = Path(path) self.ext = path.suffix[1:] self.src = path.read_text() match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) self.name = match.group(1) signature = match.group(2) types = re.findall(arg_type_pattern[self.ext], signature) self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} def hash(self): return hashlib.sha256(self.src.encode("utf-8")).hexdigest() def make_ir(self, options, codegen_fns, context): module = ir.parse_mlir_module(self.path, context) module.context = context return module def parse_options(self): if self.ext == "ttgir": return {'num_warps': _get_num_warps_from_ir_str(self.src)} return dict() @functools.lru_cache() def triton_key(): import pkgutil TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) contents = [] # frontend with open(__file__, "rb") as f: contents += [hashlib.sha256(f.read()).hexdigest()] # compiler path_prefixes = [ (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), (os.path.join(TRITON_PATH, "backends"), "triton.backends."), ] for path, prefix in path_prefixes: for lib in pkgutil.walk_packages([path], prefix=prefix): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: contents += [hashlib.sha256(f.read()).hexdigest()] # backend libtriton_hash = hashlib.sha256() with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: while True: chunk = f.read(1024**2) if not chunk: break libtriton_hash.update(chunk) contents.append(libtriton_hash.hexdigest()) # language language_path = os.path.join(TRITON_PATH, 'language') for lib in pkgutil.iter_modules([language_path]): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: contents += [hashlib.sha256(f.read()).hexdigest()] return f'{__version__}' + '-'.join(contents) def parse(full_name, ext, context): if ext == "ttir" or ext == "ttgir": module = ir.parse_mlir_module(full_name, context) module.context = context return module if ext == "llir" or ext == "ptx": return Path(full_name).read_text() if ext == "cubin": return Path(full_name).read_bytes() def filter_traceback(e: BaseException): """ Removes code_generator.py and related files from tracebacks. These are uninteresting to the user -- "just show me *my* code!" """ if e.__cause__ is not None: filter_traceback(e.__cause__) if e.__context__ is not None: filter_traceback(e.__context__) # If a user has a file that matches one of these, they're out of luck. BAD_FILES = [ "/triton/compiler/code_generator.py", "/ast.py", ] tb = e.__traceback__ frames = [] while tb is not None: if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): frames.append(tb) tb = tb.tb_next for (cur_frame, next_frame) in zip(frames, frames[1:]): cur_frame.tb_next = next_frame if not frames: e.__traceback__ = None else: frames[-1].tb_next = None e.__traceback__ = frames[0] def compile(src, target=None, options=None): if target is None: target = driver.active.get_current_target() assert isinstance(target, GPUTarget), "target must be of GPUTarget type" backend = make_backend(target) ir_source = not isinstance(src, ASTSource) # create backend if ir_source: assert isinstance(src, str), "source must be either AST or a filepath" src = IRSource(src) extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) # create cache manager env_vars = get_cache_invalidating_env_vars() key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" hash = hashlib.sha256(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) # For dumping/overriding only hash the source as we want it to be independent of triton # core changes to make it easier to track kernels by hash. enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" fn_override_manager = get_override_manager(src.hash()) if enable_override else None fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None metadata_filename = f"{src.name}.json" metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" if not always_compile and metadata_path is not None: # cache hit! metadata = json.loads(Path(metadata_path).read_text()) return CompiledKernel(src, metadata_group, hash) # initialize metadata metadata = { "hash": hash, "target": target, **options.__dict__, **env_vars, } # run compilation pipeline and populate metadata stages = dict() backend.add_stages(stages, options) first_stage = list(stages.keys()).index(src.ext) # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. if ir_source: first_stage += 1 context = ir.context() ir.load_dialects(context) backend.load_dialects(context) codegen_fns = backend.get_codegen_implementation() try: module = src.make_ir(options, codegen_fns, context) except Exception as e: filter_traceback(e) raise use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1" for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) ir_filename = f"{src.name}.{ext}" metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) if fn_dump_manager is not None: fn_dump_manager.put(next_module, ir_filename) if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): print(f"\nOverriding kernel with file {ir_filename}") full_name = fn_override_manager.get_file(ir_filename) next_module = parse(full_name, ext, context) # use an env variable to parse ttgir from file if use_ttgir_loc and ext == "ttgir": ttgir_full_name = fn_cache_manager.get_file(ir_filename) next_module.create_location_snapshot(ttgir_full_name) print(f"Create new locations for {ttgir_full_name}") module = next_module # write-back metadata metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) # return handle to compiled kernel return CompiledKernel(src, metadata_group, hash) def make_backend(target): actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] if len(actives) != 1: raise RuntimeError( f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") return actives[0](target) class LazyDict: def __init__(self, data): self.data = data self.extras = [] def get(self) -> None: for func, args in self.extras: self.data = self.data | func(*args) self.extras.clear() return self.data def add(self, func, args): self.extras.append((func, args)) class CompiledKernel: # Hooks for external tools to monitor the execution of triton kernels # TODO: move out of this namespace since it's a runtime thing launch_enter_hook = None launch_exit_hook = None def __init__(self, src, metadata_group, hash): from collections import namedtuple metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) metadata = json.loads(metadata_path.read_text()) metadata['cluster_dims'] = tuple(metadata['cluster_dims']) # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. target = metadata['target'] metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) self.metadata = KernelMetadata(**metadata) backend = make_backend(self.metadata.target) self.packed_metadata = backend.pack_metadata(self.metadata) self.src = src self.hash = hash self.name = self.metadata.name # stores the text of each level of IR that was generated during compilation asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] binary_ext = backend.binary_ext self.asm = { file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() for file in asm_files } self.kernel = self.asm[binary_ext] # binaries are lazily initialized # because it involves doing runtime things # (e.g., checking amount of shared memory on current device) self.module = None self.function = None def _init_handles(self): if self.module is not None: return device = driver.active.get_current_device() # create launcher self.run = driver.active.launcher_cls(self.src, self.metadata) # not enough shared memory to run the kernel max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] if self.metadata.shared > max_shared: raise OutOfResources(self.metadata.shared, max_shared, "shared memory") # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( self.name, self.kernel, self.metadata.shared, device) def __getattribute__(self, name): if name == 'run': self._init_handles() return super().__getattribute__(name) def launch_metadata(self, grid, stream, *args): if CompiledKernel.launch_enter_hook is None: return None ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: return ret arg_dict = {} arg_idx = 0 for i, arg_name in enumerate(self.src.fn.arg_names): if i in self.src.fn.constexprs: arg_dict[arg_name] = self.src.constants[arg_name] else: arg_dict[arg_name] = args[arg_idx] arg_idx += 1 ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) return ret def __getitem__(self, grid): self._init_handles() def runner(*args, stream=None): if stream is None: device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) launch_metadata = self.launch_metadata(grid, stream, *args) self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) return runner