from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes, llvm, amd
from dataclasses import dataclass
from typing import Any, Tuple
import hashlib
import tempfile
import os
import re
import subprocess
import functools
from pathlib import Path


@dataclass(frozen=True)
class HIPOptions:
    """Compilation options for the HIP backend.

    After construction, ``__post_init__`` adds a derived ``warp_size``
    attribute, registers the default ``ocml``/``ockl`` bitcode libraries and
    replaces ``extern_libs`` with a tuple of ``(name, path)`` pairs.
    """
    num_warps: int = 4
    waves_per_eu: int = 1
    num_stages: int = 0
    num_ctas: int = 1
    extern_libs: dict = None
    cluster_dims: tuple = (1, 1, 1)
    debug: bool = False
    arch: str = None
    allow_fp8e4nv: bool = False
    allow_fp8e4b15: bool = False
    default_dot_input_precision: str = "ieee"
    allowed_dot_input_precisions: Tuple[str] = ("ieee", )
    enable_fp_fusion: bool = True
    matrix_instr_nonkdim: int = 0
    kpack: int = 1
    allow_flush_denorm: bool = False
    max_num_imprecise_acc_default: int = 0
    backend_name: str = 'hip'

    def __post_init__(self):
        """Derive ``warp_size``, normalise ``extern_libs`` and validate ``num_warps``.

        Raises:
            AssertionError: if ``num_warps`` is not a positive power of two.
        """
        default_libdir = Path(__file__).parent / 'lib'
        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
        # Ignore user-defined warp size for gfx9
        # NOTE(review): `arch` must be a string here; if it is left at the
        # dataclass default of None the membership tests below raise TypeError.
        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch else 64
        # The dataclass is frozen, so attributes are written via object.__setattr__.
        object.__setattr__(self, 'warp_size', warp_size)
        # The default libraries are assigned after copying the user mapping, so
        # they replace any user-supplied entries with the same names.
        libs = ["ocml", "ockl"]
        for lib in libs:
            extern_libs[lib] = str(default_libdir / f'{lib}.bc')
        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
               "num_warps must be a power of 2"

    def hash(self):
        """Return the SHA-256 hex digest of all ``name-value`` pairs in the instance ``__dict__``."""
        key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
        return hashlib.sha256(key.encode("utf-8")).hexdigest()


class HIPBackend(BaseBackend):
    """Compiler backend whose stages are ttir -> ttgir -> llir -> amdgcn -> hsaco."""

    @staticmethod
    def supports_target(target: GPUTarget):
        """Return True when ``target.backend`` is ``'hip'``."""
        return target.backend == 'hip'

    def __init__(self, target: GPUTarget) -> None:
        """Initialise the backend for ``target``; ``target.arch`` must be a string."""
        super().__init__(target)
        assert isinstance(target.arch, str)
        self.binary_ext = "hsaco"

    def parse_options(self, opts) -> Any:
        """Build a ``HIPOptions`` from ``opts``.

        Only keys that are ``HIPOptions`` dataclass fields are forwarded. ``arch``
        defaults to the target architecture and is overridden by ``opts['arch']``
        when present.
        """
        args = {'arch': self.target.arch}
        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts})
        return HIPOptions(**args)

    def pack_metadata(self, metadata):
        """Return ``(num_warps, num_ctas, shared, cluster_dims[0], cluster_dims[1], cluster_dims[2])``."""
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
            metadata.cluster_dims[0],
            metadata.cluster_dims[1],
            metadata.cluster_dims[2],
        )

    def get_codegen_implementation(self):
        """Return the codegen function table, which is empty for this backend."""
        codegen_fns = dict()
        return codegen_fns

    def load_dialects(self, ctx):
        """Load the AMD dialects into ``ctx``."""
        amd.load_dialects(ctx)

    @staticmethod
    def path_to_rocm_lld():
        """Return the path of the first ``ld.lld`` that exists as a file.

        Locations are probed in order: ``$TRITON_HIP_LLD_PATH`` (when set),
        ``llvm/bin/ld.lld`` next to this file, ``/opt/rocm/llvm/bin/ld.lld`` and
        ``/usr/bin/ld.lld``.

        Raises:
            Exception: if none of the candidates is an existing file. The
                message lists every location that was searched.
        """
        candidates = []
        # Check env path for ld.lld
        lld_env_path = os.getenv("TRITON_HIP_LLD_PATH")
        if lld_env_path is not None:
            candidates.append(Path(lld_env_path))
        # Check backend for ld.lld (used for pytorch wheels)
        candidates.append(Path(__file__).parent / "llvm/bin/ld.lld")
        candidates.append(Path("/opt/rocm/llvm/bin/ld.lld"))
        candidates.append(Path("/usr/bin/ld.lld"))
        for lld in candidates:
            if lld.is_file():
                return lld
        # Keep the historical message as a prefix, but also report every path
        # that was probed (including a bad TRITON_HIP_LLD_PATH).
        searched = ", ".join(str(lld) for lld in candidates)
        raise Exception(f"ROCm linker /opt/rocm/llvm/bin/ld.lld not found (searched: {searched})")

    @staticmethod
    def make_ttir(mod, metadata, options):
        """Run the TTIR pass pipeline on ``mod`` and return ``mod``.

        ``metadata`` and ``options`` are accepted for a uniform stage signature
        and are not used here.
        """
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_combine(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_ttgir(mod, metadata, options):
        """Convert ``mod`` to TritonGPU IR, run the TTGIR passes and return ``mod``.

        The conversion runs in its own pass manager; the remaining passes run in
        a second one. ``metadata`` is not used here.
        """
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
                                           options.num_ctas)
        pm.run(mod)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttgpuir.add_coalesce(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        amd.passes.ttgpuir.add_optimize_epilogue(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
        # Stream pipelining is only added when num_stages == 0 and the
        # architecture reports the matrix-core feature.
        if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch):
            amd.passes.ttgpuir.add_stream_pipeline(pm)
            passes.common.add_canonicalizer(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, True)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        if options.num_stages != 0:
            amd.passes.ttgpuir.add_reorder_instructions(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

    @staticmethod
    def make_llir(src, metadata, options):
        """Lower TritonGPU IR to LLVM IR and return it as a string.

        Side effect: stores the ``triton_gpu.shared`` integer attribute of
        ``src`` in ``metadata["shared"]``.
        """
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
        passes.convert.add_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)

        passes.ttgpuir.add_allocate_shared_memory(pm)
        ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
        ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
        ##    of the value of kernel arg `allow_flush_denorm`.
        ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
        ##    depends on the value of kernel arg `allow_flush_denorm`.
        ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
        ##    For now it is used as a controller for developers only.
        __HIP_FTZ = True
        amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)

        passes.convert.add_cf_to_llvmir(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
            passes.llvmir.add_di_scope(pm)
        # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive
        # basic block count caused by predicated loads/stores. In certain kernels, the addition of these blocks can
        # cause the MLIR canonicalizer to never finish when attempting to merge blocks. The permanent solution under
        # consideration involves using MUBUF instructions that have built-in out-of-bounds checks, which would
        # eliminate the need for conditional branching around memory accesses.
        amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm)
        pm.run(mod)

        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        llvm_mod = llvm.to_module(mod, context)

        # Set various control constants on the LLVM module so that device
        # libraries can resolve references to them.
        amd.set_isa_version(llvm_mod, options.arch)
        amd.set_abi_version(llvm_mod, 400)
        amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
        amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
        amd.set_bool_control_constant(llvm_mod, "__oclc_wavefrontsize64", options.warp_size == 64)

        # Set kernel attributes first given this may affect later optimizations.
        fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()]
        # The public kernel should be kernel 0.
        fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL)
        fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}")
        fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
        denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
        fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)

        if options.extern_libs:
            # Link only the libraries that the module is reported to need.
            paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, amd.TARGET_TRIPLE)

        # Get some metadata
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")

        amd.cleanup_bitcode_metadata(llvm_mod)
        return str(llvm_mod)

    @staticmethod
    def make_amdgcn(src, metadata, options):
        """Translate LLVM IR text ``src`` to AMDGCN assembly and return it.

        Side effects: stores the kernel name in ``metadata["name"]`` and prints
        the assembly when the environment variable ``AMDGCN_ENABLE_DUMP`` is
        ``"1"``.

        Raises:
            AssertionError: if ``src`` does not define exactly one
                ``amdgpu_kernel`` function matching the name pattern.
        """
        # Find kernel names (there should only be one)
        # We get the name at the last possible step to accommodate `triton.compile`
        # on user-provided LLVM
        names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src)
        assert len(names) == 1
        metadata["name"] = names[0]
        # llvm -> hsaco
        amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
        if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
            print("// -----// AMDGCN Dump //----- //")
            print(amdgcn)
        return amdgcn

    @staticmethod
    def make_hsaco(src, metadata, options):
        """Assemble AMDGCN ``src``, link it with ``ld.lld`` and return the linked bytes.

        ``metadata`` is not used here.

        Raises:
            Exception: if no ``ld.lld`` can be located (see ``path_to_rocm_lld``).
            subprocess.CalledProcessError: if the linker exits with a non-zero status.
        """
        hsaco = amd.assemble_amdgcn(src, options.arch, '')

        rocm_path = HIPBackend.path_to_rocm_lld()
        with tempfile.NamedTemporaryFile() as tmp_out:
            # The input file only needs to exist while the linker runs; the
            # output file stays alive until its contents have been read back.
            with tempfile.NamedTemporaryFile() as tmp_in:
                with open(tmp_in.name, 'wb') as fd_in:
                    fd_in.write(hsaco)
                subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', tmp_in.name, '-o', tmp_out.name])
            with open(tmp_out.name, 'rb') as fd_out:
                ret = fd_out.read()
        return ret

    def add_stages(self, stages, options):
        """Register the five compilation stages in ``stages``, each bound to ``options``."""
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
        stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
        stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)

    @staticmethod
    @functools.lru_cache()
    def _lld_version(lld_path):
        """Return the ``--version`` output of the linker at ``lld_path``.

        The cache is keyed on the linker path rather than on a backend
        instance, so it holds no reference to any backend and the subprocess
        runs at most once per distinct path while the entry stays cached.
        """
        return subprocess.check_output([lld_path, "--version"], encoding='utf-8')

    def hash(self):
        """Return ``'<ld.lld --version output>-<target>'`` identifying this backend."""
        version = HIPBackend._lld_version(HIPBackend.path_to_rocm_lld())
        return f'{version}-{self.target}'