import functools
import os
import hashlib
import subprocess
import tempfile
from pathlib import Path
from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver

dirname = os.path.dirname(os.path.realpath(__file__))
include_dir = [os.path.join(dirname, "include")]

# Triton scalar type name -> C type used for that argument in the generated
# launcher. Pointer types (names starting with '*') are handled separately by
# the callers of this table.
_SCALAR_C_TYPES = {
    "i1": "int32_t",
    "i8": "int8_t",
    "i16": "int16_t",
    "i32": "int32_t",
    "i64": "int64_t",
    "u1": "uint32_t",
    "u8": "uint8_t",
    "u16": "uint16_t",
    "u32": "uint32_t",
    "u64": "uint64_t",
    "fp16": "float",
    "bf16": "float",
    "fp32": "float",
    "f32": "float",
    "fp64": "double",
}


@functools.lru_cache()
def _get_path_to_hip_runtime_dylib():
    """Return the path of the HIP runtime dynamic library (libamdhip64.so).

    Candidates are tried in this order and the first existing one wins:

    1. ``TRITON_LIBHIP_PATH`` (must end with the library name and exist).
    2. ``torch/lib`` under every site-packages directory.
    3. Every directory listed in ``LD_LIBRARY_PATH``.
    4. The entries reported by ``/sbin/ldconfig -p``.
    5. ``/opt/rocm/lib``.

    The result is memoized by ``functools.lru_cache``.

    Raises:
        RuntimeError: if ``TRITON_LIBHIP_PATH`` is set but invalid, or if no
            candidate exists (the message lists every attempted path).
    """
    lib_name = "libamdhip64.so"

    # If we are told explicitly what HIP runtime dynamic library to use, obey that.
    env_libhip_path = os.getenv("TRITON_LIBHIP_PATH")
    if env_libhip_path:
        if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path):
            return env_libhip_path
        raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")

    paths = []

    import site
    # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
    # that we run Triton together with PyTorch. This makes sure we use the same dynamic
    # library to avoid version mismatch.
    for path in site.getsitepackages():
        path = os.path.join(path, "torch", "lib", lib_name)
        if os.path.exists(path):
            return path
        paths.append(path)

    # Then try to see if developer provides a HIP runtime dynamic library using LD_LIBRARY_PATH.
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if env_ld_library_path:
        for d in env_ld_library_path.split(":"):
            f = os.path.join(d, lib_name)
            if os.path.exists(f):
                return f
            paths.append(f)

    # Afterwards try to search the loader dynamic library resolution paths.
    # ldconfig may be missing or may fail (e.g. minimal containers); that must
    # not prevent the last-resort lookup below, so treat it as "no results".
    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    except (OSError, subprocess.CalledProcessError):
        libs = ""
    # each line looks like the following:
    # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6
    # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so
    locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)]
    for loc in locs:
        if os.path.exists(loc):
            return loc
        paths.append(loc)

    # As a last resort, guess if we have it in some common installation path.
    common_install_path = os.path.join('/opt/rocm/lib/', lib_name)
    if os.path.exists(common_install_path):
        return common_install_path
    paths.append(common_install_path)

    raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")


def compile_module_from_src(src, name):
    """Build C source ``src`` into an extension module called ``name`` and import it.

    The shared object is looked up in (and stored into) the cache manager
    keyed by the SHA-256 of ``src``, so the build only runs when no cached
    ``{name}.so`` exists for that exact source text.

    Args:
        src: C source code of the extension module.
        name: module name; also the stem of the cached ``.so`` file.

    Returns:
        The imported extension module object.
    """
    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
    cache = get_cache_manager(key)
    cache_path = cache.get_file(f"{name}.so")
    if cache_path is None:
        with tempfile.TemporaryDirectory() as tmpdir:
            src_path = os.path.join(tmpdir, "main.c")
            with open(src_path, "w") as f:
                f.write(src)
            so = _build(name, src_path, tmpdir, [], include_dir, [])
            with open(so, "rb") as f:
                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
    import importlib.util
    spec = importlib.util.spec_from_file_location(name, cache_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


class HIPUtils(object):
    """Wrapper around the helper module compiled from ``driver.c``.

    ``__new__`` stores the first instance on the class and returns it for
    every later construction. Exposes ``load_binary`` and
    ``get_device_properties`` from the compiled module.
    """

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(HIPUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        libhip_path = _get_path_to_hip_runtime_dylib()
        src = Path(os.path.join(dirname, "driver.c")).read_text()
        # Just do a simple search and replace here instead of templates or format strings.
        # This way we don't need to escape-quote C code curly brackets and we can replace
        # exactly once.
        src = src.replace('/*py_libhip_search_path*/', libhip_path, 1)
        mod = compile_module_from_src(src, "hip_utils")
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties


# -------------------- Launcher ----------------------------
def ty_to_cpp(ty):
    """Map a Triton type name to the C type of the ``_launch`` parameter.

    Pointer types (names starting with ``*``) become ``hipDeviceptr_t``.

    Raises:
        KeyError: if ``ty`` is not a pointer and not a known scalar type.
    """
    if ty[0] == '*':
        return "hipDeviceptr_t"
    return _SCALAR_C_TYPES[ty]


def make_launcher(constants, signature, ids, warp_size):
    """Generate the C source of the Python extension that launches a kernel.

    Args:
        constants: mapping whose keys are the argument indices that are
            compile-time constants; those arguments are parsed from Python but
            not forwarded in the kernel parameter array.
        signature: mapping of argument index -> Triton type name.
        ids: accepted for interface compatibility; not used by this function.
        warp_size: number of threads per warp; the launch block size is
            ``warp_size * num_warps``.

    Returns:
        The C source code as a string, defining a module named
        ``__triton_launcher`` with a single ``launch`` function.
    """
    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

    def _extracted_type(ty):
        # C type of the local variable PyArg_ParseTuple writes the argument
        # into. Pointers arrive as arbitrary Python objects and are converted
        # by getPointer() in the generated code.
        if ty[0] == '*':
            return "PyObject*"
        return _SCALAR_C_TYPES[ty]

    def format_of(ty):
        # PyArg_ParseTuple format unit for each C type.
        return {
            "PyObject*": "O",
            "float": "f",
            "double": "d",
            "long": "l",
            "int8_t": "b",
            "int16_t": "h",
            "int32_t": "i",
            "int64_t": "l",
            "uint8_t": "B",
            "uint16_t": "H",
            "uint32_t": "I",
            "uint64_t": "K",
        }[ty]

    args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
    # Fixed prefix: gridX, gridY, gridZ, stream, function, kernel_metadata,
    # launch_metadata, launch_enter_hook, launch_exit_hook.
    parse_format = "iiiKKOOOO" + args_format
    args_list = ', ' + ', '.join(f"&_arg{i}" for i in signature) if len(signature) > 0 else ''

    libhip_path = _get_path_to_hip_runtime_dylib()

    # Pieces of generated C that depend on the signature. They are computed
    # here so the template below stays readable.
    params = [i for i in signature.keys() if i not in constants]
    launch_arg_decls = ', ' + arg_decls if len(arg_decls) > 0 else ''
    kernel_params = ', '.join(f"&arg{i}" for i in params)
    parsed_arg_decls = ' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])
    pointer_checks = "; ".join([
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        if ty[0] == "*" else "" for i, ty in signature.items()
    ])
    launch_args = ', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0] == "*" else f"_arg{i}"
                                   for i, ty in signature.items()) if len(signature) > 0 else ''

    # generate glue code
    src = f"""
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <Python.h>
#include <dlfcn.h>
#include <stdbool.h>
#include <dlfcn.h>

// The list of paths to search for the HIP runtime library. The caller Python
// code should substitute the search path placeholder.
static const char *hipLibSearchPaths[] = {{"{libhip_path}"}};

// The list of HIP dynamic library symbols and their signature we are interested
// in this file.
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN)                     \\
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError)                     \\
  FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f,                     \\
                  unsigned int gridDimX, unsigned int gridDimY,               \\
                  unsigned int gridDimZ, unsigned int blockDimX,              \\
                  unsigned int blockDimY, unsigned int blockDimZ,             \\
                  unsigned int sharedMemBytes, hipStream_t stream,            \\
                  void **kernelParams, void **extra)                          \\
  FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data,                         \\
                  hipPointer_attribute attribute, hipDeviceptr_t ptr)

// The HIP symbol table for holding resolved dynamic library symbols.
struct HIPSymbolTable {{
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...)                             \\
  hipError_t (*hipSymbolName)(__VA_ARGS__);
#define DEFINE_EACH_STR_FIELD(hipSymbolName, ...)                             \\
  const char *(*hipSymbolName)(__VA_ARGS__);

  HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
}};

static struct HIPSymbolTable hipSymbolTable;

bool initSymbolTable() {{
  // Use the HIP runtime library loaded into the existing process if it exits.
  void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
  if (lib) {{
    // printf("[triton] chosen loaded libamdhip64.so in the process\\n");
  }}

  // Otherwise, go through the list of search paths to dlopen the first HIP
  // driver library.
  if (!lib) {{
    int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
    for (int i = 0; i < n; ++i) {{
      void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
      if (handle) {{
        lib = handle;
        // printf("[triton] chosen %s\\n", hipLibSearchPaths[i]);
        break;
      }}
    }}
  }}
  if (!lib) {{
    PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
    return false;
  }}

  // Resolve all symbols we are interested in.
  dlerror(); // Clear existing errors
  const char *error = NULL;
#define QUERY_EACH_FN(hipSymbolName, ...)                                     \\
  *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName);       \\
  error = dlerror();                                                          \\
  if (error) {{                                                               \\
    PyErr_SetString(PyExc_RuntimeError,                                       \\
                    "cannot query " #hipSymbolName " from libamdhip64.so");   \\
    dlclose(lib);                                                             \\
    return false;                                                             \\
  }}

  HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)

  return true;
}}

static inline void gpuAssert(hipError_t code, const char *file, int line)
{{
   if (code != HIP_SUCCESS)
   {{
      const char* prefix = "Triton Error [HIP]: ";
      const char* str = hipSymbolTable.hipGetErrorString(code);
      char err[1024] = {{0}};
      snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
      PyErr_SetString(PyExc_RuntimeError, err);
   }}
}}

#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{launch_arg_decls}) {{
  // printf("_launch hip kernel\\n");
  void *params[] = {{ {kernel_params} }};
  if (gridX*gridY*gridZ > 0) {{
      HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
  }}
}}

typedef struct _DevicePtrInfo {{
    hipDeviceptr_t dev_ptr;
    bool valid;
}} DevicePtrInfo;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!ret) {{
      // data_ptr() raised: keep its exception and abort the launch.
      ptr_info.valid = false;
      return ptr_info;
    }}
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
      Py_DECREF(ret);
      return ptr_info;
    }}
    ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
    // The integer value has been extracted; release the result on every path.
    Py_DECREF(ret);
    if(!ptr_info.dev_ptr)
      return ptr_info;
    uint64_t dev_ptr;
    hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == hipErrorInvalidValue) {{
        PyErr_Format(PyExc_ValueError,
                     "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
        ptr_info.valid = false;
    }}
    ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
    return ptr_info;
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  // Mark the argument invalid so the caller does not launch with a null pointer.
  ptr_info.valid = false;
  return ptr_info;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  // printf("launch\\n");
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  {parsed_arg_decls}
  if(!PyArg_ParseTuple(args, \"{parse_format}\", &gridX, &gridY, &gridZ, &_stream, &_function,
                                           &kernel_metadata, &launch_metadata,
                                           &launch_enter_hook, &launch_exit_hook {args_list})) {{
    return NULL;
  }}

  // extract kernel metadata
  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
    return NULL;
  }}
  // extract launch metadata
  if (launch_enter_hook != Py_None){{
    PyObject* args = Py_BuildValue("(O)", launch_metadata);
    PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
    Py_DECREF(args);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  // raise exception asap
  {pointer_checks};
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{launch_args});

  if(launch_exit_hook != Py_None){{
    PyObject* args = Py_BuildValue("(O)", launch_metadata);
    PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
    Py_DECREF(args);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  if(PyErr_Occurred()) {{
    return NULL;
  }}
  // return None
  Py_INCREF(Py_None);
  return Py_None;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  if (!initSymbolTable()) {{
    return NULL;
  }}
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""
    return src


class HIPLauncher(object):
    """Callable that forwards to the compiled ``launch`` function for one kernel signature."""

    def __init__(self, src, metadata):
        """Generate, compile and load the launcher for ``src``.

        Args:
            src: kernel source object; ``src.signature`` is required, while
                ``src.fn`` and ``src.constants`` are used when present.
            metadata: object providing ``warp_size``.
        """
        ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
        constants = src.constants if hasattr(src, "constants") else dict()

        def cst_key(i):
            # Keys given as argument names are translated to positional indices.
            return src.fn.arg_names.index(i) if isinstance(i, str) else i

        constants = {cst_key(key): value for key, value in constants.items()}
        signature = {cst_key(key): value for key, value in src.signature.items()}
        src = make_launcher(constants, signature, ids, metadata.warp_size)
        mod = compile_module_from_src(src, "__triton_launcher")
        self.launch = mod.launch

    def __call__(self, *args, **kwargs):
        self.launch(*args, **kwargs)


class HIPDriver(GPUDriver):
    """GPU driver implementation for the HIP backend."""

    def __init__(self):
        super().__init__()
        self.utils = HIPUtils()
        self.launcher_cls = HIPLauncher

    @staticmethod
    def is_active():
        """Return True when the installed PyTorch reports a HIP version."""
        import torch
        return torch.version.hip is not None

    def get_current_target(self):
        """Return the ``GPUTarget`` for the current device.

        The architecture is the part of the reported ``arch`` property before
        the first ``:``; the warp size comes from the ``warpSize`` property.
        """
        device = self.get_current_device()
        device_properties = self.utils.get_device_properties(device)
        arch = device_properties['arch']
        warp_size = device_properties['warpSize']
        return GPUTarget("hip", arch.split(':')[0], warp_size)