import os
import re
import subprocess
from abc import ABCMeta, abstractmethod, abstractclassmethod
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class GPUTarget(object):
    """Immutable description of the GPU that a backend compiles for."""

    # Target backend, e.g., cuda, hip
    backend: str
    # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip)
    arch: Union[int, str]
    warp_size: int


class BaseBackend(metaclass=ABCMeta):
    """Abstract interface implemented by every compiler backend.

    A concrete backend must provide `supports_target`, `hash`,
    `parse_options`, `add_stages` and `load_dialects`.
    """

    def __init__(self, target: GPUTarget) -> None:
        """Store `target` and check that this backend supports it.

        Raises:
            AssertionError: if `supports_target(target)` is falsy.
        """
        self.target = target
        assert self.supports_target(target)

    @staticmethod
    def _path_to_binary(binary: str):
        """Locate `binary` and return `(path, version)`.

        Candidates are tried in order: the `TRITON_<BINARY>_PATH` environment
        variable, then `third_party/cuda/bin/<binary>` relative to the parent
        of this file's directory. A candidate may carry extra text after its
        first space; only the first space-separated token is checked and
        executed, but the full candidate string is what gets returned.

        The version is the `X.Y` number following the word "release" in the
        output of `<binary> --version`.

        Raises:
            RuntimeError: if no candidate exists as a file and reports a
                parsable release version.
            subprocess.CalledProcessError: if an existing candidate exits with
                a non-zero status when asked for its version.
        """
        base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
        candidates = [
            os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
            os.path.join(base_dir, "third_party", "cuda", "bin", binary),
        ]
        for candidate in candidates:
            executable = candidate.split(" ")[0]
            # isfile() is already False for missing paths (including the empty
            # string used when the environment variable is unset).
            if not os.path.isfile(executable):
                continue
            output = subprocess.check_output([executable, "--version"], stderr=subprocess.STDOUT)
            version = re.search(r".*release (\d+\.\d+).*", output.decode("utf-8"), flags=re.MULTILINE)
            if version is not None:
                return candidate, version.group(1)
        raise RuntimeError(f"Cannot find {binary}")

    # Declared as an abstract *static* method: the signature takes only
    # `target`, and `__init__` calls it as `self.supports_target(target)`.
    # The previous `@abstractclassmethod` (deprecated) would have bound the
    # class to `target` and made any direct call fail with a TypeError.
    @staticmethod
    @abstractmethod
    def supports_target(target: GPUTarget):
        """Return whether this backend can compile for `target`."""
        raise NotImplementedError

    @abstractmethod
    def hash(self) -> str:
        """Returns a unique identifier for this backend"""
        raise NotImplementedError

    @abstractmethod
    def parse_options(self, options: dict) -> object:
        """
        Converts an `options` dictionary into an arbitrary object and returns it.
        This function may contain target-specific heuristics and check the legality of the provided options
        """
        raise NotImplementedError

    @abstractmethod
    def add_stages(self, stages: dict, options: object) -> None:
        """
        Populates `stages` dictionary with entries of the form:
        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
        The value of each entry may populate a `metadata` dictionary.
        Stages will be run sequentially (in insertion order) and can communicate using `metadata`.
        All stages are expected to return a `str` object, except for the last stage
        which returns a `bytes` object for execution by the launcher.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dialects(self, context):
        """
        Load additional MLIR dialects into the provided `context`
        """
        raise NotImplementedError