"""CUDA driver for Triton: locates libcuda, builds C helper modules, and generates kernel launchers."""
import functools
import os
import hashlib
import subprocess
import tempfile
from pathlib import Path

from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver

dirname = os.path.dirname(os.path.realpath(__file__))
include_dir = [os.path.join(dirname, "include")]
libdevice_dir = os.path.join(dirname, "lib")
libraries = ['cuda']


@functools.lru_cache()
def libcuda_dirs():
    """Return the list of directories that contain ``libcuda.so.1``.

    Resolution order: ``TRITON_LIBCUDA_PATH`` (returned as-is, unchecked), then
    the ``/sbin/ldconfig -p`` cache, then ``LD_LIBRARY_PATH`` entries (only
    consulted when the ldconfig cache yielded nothing).

    :raises AssertionError: if no candidate directory actually holds
        ``libcuda.so.1``.
    """
    env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
    if env_libcuda_path:
        return [env_libcuda_path]

    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line]
    dirs = [os.path.dirname(loc) for loc in locs]
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if env_ld_library_path and not dirs:
        dirs = [d for d in env_ld_library_path.split(":") if os.path.exists(os.path.join(d, "libcuda.so.1"))]

    if not any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs):
        msg = 'libcuda.so cannot be found!\n'
        if locs:
            msg += f'Possible files are located at {locs}. '
            msg += 'Please create a symlink of libcuda.so to any of the files.'
        else:
            msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"'
            msg += ' (requires sudo) to refresh the linker cache.'
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # AssertionError is kept for backward compatibility with callers.
        raise AssertionError(msg)
    return dirs


@functools.lru_cache()
def library_dirs():
    """Return linker search directories: the bundled ``lib`` dir followed by the libcuda dirs."""
    return [libdevice_dir, *libcuda_dirs()]


def compile_module_from_src(src, name):
    """Compile C source ``src`` into an extension module named ``name`` and import it.

    The built shared object is cached under the SHA-256 of ``src``, so identical
    sources are only compiled once.
    """
    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
    cache = get_cache_manager(key)
    cache_path = cache.get_file(f"{name}.so")
    if cache_path is None:
        with tempfile.TemporaryDirectory() as tmpdir:
            src_path = os.path.join(tmpdir, "main.c")
            # Write with the same encoding used to compute the cache key.
            with open(src_path, "w", encoding="utf-8") as f:
                f.write(src)
            so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
            with open(so, "rb") as f:
                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
    import importlib.util
    spec = importlib.util.spec_from_file_location(name, cache_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


# ------------------------
# Utils
# ------------------------


class CudaUtils(object):
    """Process-wide wrapper exposing the helpers compiled from ``driver.c``."""

    def __new__(cls):
        # Single shared instance; note __init__ still runs on every CudaUtils() call.
        if not hasattr(cls, "instance"):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        mod = compile_module_from_src(Path(dirname, "driver.c").read_text(encoding="utf-8"), "cuda_utils")
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
        self.set_printf_fifo_size = mod.set_printf_fifo_size
        self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor


# ------------------------
# Launcher
# ------------------------


def ty_to_cpp(ty):
    """Map a Triton type string to the C type used in the launcher.

    Any pointer type (leading ``*``) maps to ``CUdeviceptr``.

    :raises KeyError: for an unknown scalar type.
    """
    if ty[0] == '*':
        return "CUdeviceptr"
    return {
        "i1": "int32_t",
        "i8": "int8_t",
        "i16": "int16_t",
        "i32": "int32_t",
        "i64": "int64_t",
        "u1": "uint32_t",
        "u8": "uint8_t",
        "u16": "uint16_t",
        "u32": "uint32_t",
        "u64": "uint64_t",
        "fp16": "float",
        "bf16": "float",
        "fp32": "float",
        "f32": "float",
        "fp64": "double",
    }[ty]


def make_launcher(constants, signature, ids):
    """Generate the C source of the ``__triton_launcher`` extension module.

    :param constants: mapping of argument index -> value; these arguments are
        declared in ``_launch`` but left out of the kernel parameter array.
    :param signature: mapping of argument index -> Triton type string; it is
        iterated in order to build every C parameter/argument list.
    :param ids: not used by this function; kept for interface compatibility.
    :returns: the C source text.
    """
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

    def _extracted_type(ty):
        # Pointers arrive from Python as arbitrary objects and are resolved by getPointer().
        if ty[0] == '*':
            return "PyObject*"
        return ty_to_cpp(ty)

    def format_of(ty):
        # PyArg_ParseTuple format unit for each C type.
        return {
            "PyObject*": "O",
            "float": "f",
            "double": "d",
            "long": "l",
            "int8_t": "b",
            "int16_t": "h",
            "int32_t": "i",
            "int64_t": "l",
            "uint8_t": "B",
            "uint16_t": "H",
            "uint32_t": "I",
            "uint64_t": "K",
        }[ty]

    args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
    # gridX, gridY, gridZ (i); stream, function (K); kernel_metadata, launch_metadata,
    # launch_enter_hook, launch_exit_hook (O); then one unit per kernel argument.
    fmt = "iiiKKOOOO" + args_format
    args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

    # generate glue code
    params = [i for i in signature.keys() if i not in constants]
    launch_decl_suffix = ', ' + arg_decls if len(arg_decls) > 0 else ''
    kernel_params = ', '.join(f"&arg{i}" for i in params)
    arg_locals = ' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])
    # Resolve every pointer argument before launching so errors are raised early.
    ptr_checks = "; ".join([
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        if ty[0] == "*" else "" for i, ty in signature.items()
    ])
    launch_args = ''.join(f", ptr_info{i}.dev_ptr" if ty[0] == "*" else f", _arg{i}" for i, ty in signature.items())

    src = f"""
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>
#include <dlfcn.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
  {{
    const char* prefix = "Triton Error [CUDA]: ";
    const char* str = NULL;
    // cuGetErrorString may leave str NULL for unrecognized codes; never format a NULL.
    if (cuGetErrorString(code, &str) != CUDA_SUCCESS || str == NULL)
      str = "unknown error";
    char err[1024] = {{0}};
    snprintf(err, sizeof(err), "%s%s", prefix, str);
    PyGILState_STATE gil_state;
    gil_state = PyGILState_Ensure();
    PyErr_SetString(PyExc_RuntimeError, err);
    PyGILState_Release(gil_state);
  }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

// Raise a RuntimeError from code that may run with the GIL released
// (e.g. inside Py_BEGIN_ALLOW_THREADS).
static void setRuntimeErrorWithGIL(const char* msg)
{{
  PyGILState_STATE gil_state = PyGILState_Ensure();
  PyErr_SetString(PyExc_RuntimeError, msg);
  PyGILState_Release(gil_state);
}}

typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);

static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  if (!handle) {{
    setRuntimeErrorWithGIL("Failed to open libcuda.so.1");
    return NULL;
  }}
  // Clear any existing error
  dlerror();
  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  // Check for errors
  const char *dlsym_error = dlerror();
  if (dlsym_error) {{
    setRuntimeErrorWithGIL("Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{launch_decl_suffix}) {{
  void *params[] = {{ {kernel_params} }};
  // Test each dimension rather than the product: gridX*gridY*gridZ can
  // overflow int for large grids and silently skip the launch.
  if (gridX > 0 && gridY > 0 && gridZ > 0) {{
    if (num_ctas == 1) {{
      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
    }} else {{
      CUlaunchAttribute launchAttr[2];
      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      launchAttr[0].value.clusterDim.x = clusterDimX;
      launchAttr[0].value.clusterDim.y = clusterDimY;
      launchAttr[0].value.clusterDim.z = clusterDimZ;
      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      CUlaunchConfig config;
      config.gridDimX = gridX * clusterDimX;
      config.gridDimY = gridY * clusterDimY;
      config.gridDimZ = gridZ * clusterDimZ;
      config.blockDimX = 32 * num_warps;
      config.blockDimY = 1;
      config.blockDimZ = 1;
      config.sharedMemBytes = shared_memory;
      config.hStream = stream;
      config.attrs = launchAttr;
      config.numAttrs = 2;
      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
      if (cuLaunchKernelExHandle == NULL) {{
        cuLaunchKernelExHandle = getLaunchKernelExHandle();
        if (cuLaunchKernelExHandle == NULL) {{
          // Error already set; launch() reports it via PyErr_Occurred().
          return;
        }}
      }}
      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
    }}
  }}
}}

typedef struct _DevicePtrInfo {{
    CUdeviceptr dev_ptr;
    bool valid;
}} DevicePtrInfo;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!ret) {{
      // data_ptr() raised; keep its exception.
      ptr_info.valid = false;
      return ptr_info;
    }}
    if (!PyLong_Check(ret)) {{
      Py_DECREF(ret);
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
      return ptr_info;
    }}
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
    Py_DECREF(ret);
    if(!ptr_info.dev_ptr)
      return ptr_info;
    uint64_t dev_ptr;
    CUresult status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == CUDA_ERROR_INVALID_VALUE) {{
      PyErr_Format(PyExc_ValueError,
                   "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
      ptr_info.valid = false;
      return ptr_info;
    }} else if (status != CUDA_SUCCESS) {{
      // Any other driver error: dev_ptr was not written, so do not use it.
      CUDA_CHECK(status);
      ptr_info.valid = false;
      return ptr_info;
    }}
    ptr_info.dev_ptr = dev_ptr;
    return ptr_info;
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  ptr_info.valid = false;
  return ptr_info;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  {arg_locals}
  if(!PyArg_ParseTuple(args, \"{fmt}\", &gridX, &gridY, &gridZ, &_stream, &_function,
                                       &kernel_metadata, &launch_metadata,
                                       &launch_enter_hook, &launch_exit_hook {args_list})) {{
    return NULL;
  }}

  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
  if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }}

  // extract launch metadata
  if (launch_enter_hook != Py_None){{
    PyObject* ret = PyObject_CallFunctionObjArgs(launch_enter_hook, launch_metadata, NULL);
    if (!ret)
      return NULL;
    // The hook's return value is unused; release it to avoid a leak.
    Py_DECREF(ret);
  }}

  // raise exception asap
  {ptr_checks};
  Py_BEGIN_ALLOW_THREADS;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{launch_args});
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {{
    return NULL;
  }}

  if(launch_exit_hook != Py_None){{
    PyObject* ret = PyObject_CallFunctionObjArgs(launch_exit_hook, launch_metadata, NULL);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }}

  // return None
  Py_RETURN_NONE;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
"""
    return src


class CudaLauncher(object):
    """Callable that compiles and invokes a signature-specific C launcher."""

    def __init__(self, src, metadata):
        ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
        constants = src.constants if hasattr(src, "constants") else dict()

        def cst_key(i):
            # Normalize argument names to positional indices.
            return src.fn.arg_names.index(i) if isinstance(i, str) else i

        constants = {cst_key(key): value for key, value in constants.items()}
        signature = {cst_key(key): value for key, value in src.signature.items()}
        launcher_src = make_launcher(constants, signature, ids)
        mod = compile_module_from_src(launcher_src, "__triton_launcher")
        self.launch = mod.launch

    def __call__(self, *args, **kwargs):
        self.launch(*args, **kwargs)


class CudaDriver(GPUDriver):
    """Triton GPU driver for NVIDIA CUDA devices."""

    def __init__(self):
        self.utils = CudaUtils()  # TODO: make static
        self.launcher_cls = CudaLauncher
        super().__init__()

    def get_current_target(self):
        """Return the ``GPUTarget`` for the current device (capability as major*10+minor, warp size 32)."""
        device = self.get_current_device()
        capability = self.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
        warp_size = 32
        return GPUTarget("cuda", capability, warp_size)

    @staticmethod
    def is_active():
        """True when torch sees a CUDA device and is not a HIP/ROCm build."""
        import torch
        return torch.cuda.is_available() and (torch.version.hip is None)