from . import core
from . import semantic
from functools import wraps
from typing import List

T = core.TypeVar('T')


def _check_dtype(dtypes: List[str]) -> core.Callable[[T], T]:
    """
    Build a decorator that rejects tensor arguments with an unsupported dtype.

    We follow libdevice's convention for the data types accepted by math
    functions. Supporting every data type is not good practice, because
    accelerators/GPUs lack many float16 and bfloat16 math operations.
    Instead, we tell users which data type they passed so that they can
    insert an explicit cast to a supported one.

    :param dtypes: names of the accepted scalar data types
    :return: a decorator whose wrapper raises :code:`ValueError` when any
        positional or keyword argument that is a :code:`core.tensor` has a
        scalar type name outside :code:`dtypes`, and otherwise forwards the
        call unchanged
    """

    def wrapper(fn):

        @wraps(fn)
        def check(*args, **kwargs):
            # Positional and keyword arguments are validated alike;
            # anything that is not a core.tensor is skipped.
            for arg in (*args, *kwargs.values()):
                if not isinstance(arg, core.tensor):
                    continue
                scalar_name = arg.type.scalar.name
                if scalar_name not in dtypes:
                    raise ValueError(f"Expected dtype {dtypes} but got {scalar_name}")
            return fn(*args, **kwargs)

        return check

    return wrapper


def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]:
    """
    Build a decorator that installs the standard one-argument docstring.

    :param name: human-readable name of the operation, substituted into the
        docstring template
    :return: a decorator that overwrites :code:`__doc__` of the function it
        receives and returns that same function
    """

    def _decorator(func: T) -> T:
        # Each reST field sits on its own line so that it is parsed as a
        # field list rather than as part of the summary sentence.
        docstr = """
    Computes the element-wise {name} of :code:`x`.

    :param x: the input values
    :type x: Block
    """
        func.__doc__ = docstr.format(name=name)
        return func

    return _decorator


def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]:
    """
    Build a decorator that installs the standard two-argument docstring.

    :param name: human-readable name of the operation, substituted into the
        docstring template
    :return: a decorator that overwrites :code:`__doc__` of the function it
        receives and returns that same function
    """

    def _decorator(func: T) -> T:
        docstr = """
    Computes the element-wise {name} of :code:`x` and :code:`y`.

    :param x: the input values
    :type x: Block
    :param y: the input values
    :type y: Block
    """
        func.__doc__ = docstr.format(name=name)
        return func

    return _decorator


def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
    """
    Build a decorator that installs the standard three-argument docstring.

    :param name: human-readable name of the operation, substituted into the
        docstring template
    :return: a decorator that overwrites :code:`__doc__` of the function it
        receives and returns that same function
    """

    def _decorator(func: T) -> T:
        docstr = """
    Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`.

    :param x: the input values
    :type x: Block
    :param y: the input values
    :type y: Block
    :param z: the input values
    :type z: Block
    """
        func.__doc__ = docstr.format(name=name)
        return func

    return _decorator


# NOTE: the builtins below carry no inline docstring on purpose. Their
# ``__doc__`` is assigned by the ``_add_math_*_docstr`` decorator, and
# ``_check_dtype`` (applied afterwards) copies it onto its wrapper via
# ``functools.wraps``.


@core.builtin
@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
@_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
def umulhi(x, y, _builder=None):
    x = core._to_tensor(x, _builder)
    y = core._to_tensor(y, _builder)
    # Legalize both operands together before emitting the op; the result
    # reuses x's (legalized) type.
    x, y = core.binary_op_type_legalization(x, y, _builder)
    return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential")
@core._tensor_member_fn
def exp(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_exp(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential (base 2)")
@core._tensor_member_fn
def exp2(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_exp2(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("natural logarithm")
@core._tensor_member_fn
def log(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_log(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("logarithm (base 2)")
@core._tensor_member_fn
def log2(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_log2(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("cosine")
@core._tensor_member_fn
def cos(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_cos(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("sine")
@core._tensor_member_fn
def sin(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_sin(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("fast square root")
@core._tensor_member_fn
def sqrt(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_sqrt(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_1arg_docstr("precise square root (rounding to nearest)")
@core._tensor_member_fn
def sqrt_rn(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_precise_sqrt(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("inverse square root")
@core._tensor_member_fn
def rsqrt(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_rsqrt(x.handle), x.type)


@core.builtin
@_add_math_1arg_docstr("absolute value")
@core._tensor_member_fn
def abs(x, _builder=None):
    x = core._to_tensor(x, _builder)
    dtype = x.dtype
    # fp8e4b15 must be tested before the generic floating-point branch.
    if dtype.is_fp8e4b15():
        # 0x7F keeps the low seven bits and clears the top bit of each
        # 8-bit element (presumably the sign bit -- confirm against the
        # fp8e4b15 format definition).
        mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder)
        return core.tensor(_builder.create_and(x.handle, mask.handle), x.type)
    if dtype.is_floating():
        return core.tensor(_builder.create_fabs(x.handle), x.type)
    if dtype.is_int_signed():
        return core.tensor(_builder.create_iabs(x.handle), x.type)
    if dtype.is_int_unsigned():
        return x  # no-op
    # Raised explicitly rather than via ``assert False``: assert statements
    # are stripped under ``python -O``, which would make this function fall
    # through and silently return None for an unsupported dtype.
    raise AssertionError(f"Unexpected dtype {dtype}")


@core.builtin
@_add_math_2arg_docstr("fast division")
def fdiv(x, y, ieee_rounding=False, _builder=None):
    # ieee_rounding may arrive wrapped as a constexpr; unwrap it to a plain
    # value before handing it to the semantic layer.
    ieee_rounding = core._constexpr_to_value(ieee_rounding)
    x = core._to_tensor(x, _builder)
    y = core._to_tensor(y, _builder)
    return semantic.fdiv(x, y, ieee_rounding, _builder)


@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_2arg_docstr("precise division (rounding to nearest)")
def div_rn(x, y, _builder=None):
    x = core._to_tensor(x, _builder)
    y = core._to_tensor(y, _builder)
    x, y = core.binary_op_type_legalization(x, y, _builder)
    return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("error function")
@core._tensor_member_fn
def erf(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_erf(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("floor")
@core._tensor_member_fn
def floor(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_floor(x.handle), x.type)


@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("ceil")
@core._tensor_member_fn
def ceil(x, _builder=None):
    x = core._to_tensor(x, _builder)
    return core.tensor(_builder.create_ceil(x.handle), x.type)


@core.builtin
@_add_math_3arg_docstr("fused multiply-add")
def fma(x, y, z, _builder=None):
    x = core._to_tensor(x, _builder)
    y = core._to_tensor(y, _builder)
    z = core._to_tensor(z, _builder)
    # Pairwise legalization, in this exact order: first reconcile x with y,
    # then z with x, then z with y, so that all three operands have been
    # legalized against each other before the op is emitted.
    x, y = core.binary_op_type_legalization(x, y, _builder)
    z, x = core.binary_op_type_legalization(z, x, _builder)
    z, y = core.binary_op_type_legalization(z, y, _builder)
    return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)