# mypy: allow-untyped-defs """This file exports ONNX ops for opset 9. Opset 9 is supported by ONNX release 1.4.1 release on 01/23/19 """ from __future__ import annotations import builtins import functools import math import sys import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union import torch import torch._C._onnx as _C_onnx import torch.nn.modules.utils import torch.onnx from torch import _C # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper from torch.onnx._globals import GLOBALS from torch.onnx._internal import _beartype, jit_utils, registration from torch.types import Number # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in README.md __all__ = [ "abs", "acos", "add", "addcmul", "addmm", "alias", "amax", "amin", "aminmax", "arange", "argmax", "argmin", "as_strided", "as_tensor", "asin", "atan", "atan2", "baddbmm", "batch_norm", "bernoulli", "bitwise_not", "bitwise_or", "bmm", "broadcast_tensors", "broadcast_to", "bucketize", "cat", "cdist", "ceil", "clamp_max", "clamp_min", "clamp", "clone", "constant_pad_nd", "contiguous", "conv_tbc", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d", "conv1d", "conv2d", "conv3d", "convert_element_type", "convolution", "cos", "cosine_similarity", "cross", "cumsum", "detach", "dim", "div", "dot", "dropout", "elu", "embedding_bag", "embedding", "empty_like", "empty", "eq", "erf", "exp", "expand_as", "expand", "eye", "fill", "flatten", "floor_divide", "floor", "floordiv", "frobenius_norm", "full_like", "full", "gather", "ge", "gelu", "get_pool_ceil_padding", "glu", "group_norm", "gt", "hann_window", "hardshrink", "hardsigmoid", "hardswish", "hardtanh", "index_add", "index_copy", "index_fill", "index_put", "index_select", "index", "instance_norm", "is_floating_point", "is_pinned", "isnan", "item", "kl_div", "layer_norm", "le", "leaky_relu", "lerp", "lift", "linalg_cross", "linalg_matrix_norm", "linalg_norm", "linalg_vector_norm", "linear", "linspace", "log_sigmoid", "log_softmax", "log", "log10", "log1p", "log2", "logical_and", "logical_not", "logical_or", "logical_xor", "logit", "logsumexp", "lstm_cell", "lstm", "lt", "masked_fill", "masked_fill_", "matmul", "max_pool1d_with_indices", "max_pool2d_with_indices", "max_pool3d_with_indices", "max", "maximum", "meshgrid", "min", "minimum", "mish", "mm", "movedim", "mse_loss", "mul", "multinomial", "mv", "narrow", "native_layer_norm", "ne", "neg", "new_empty", "new_full", "new_ones", "new_zeros", "nonzero_numpy", "nonzero", "norm", "numel", "numpy_T", "one_hot", "ones_like", "ones", "onnx_placeholder", "pad", "pairwise_distance", "permute", "pixel_shuffle", "pixel_unshuffle", "pow", "prelu", "prim_constant_chunk", "prim_constant_split", "prim_constant", "prim_data", "prim_device", "prim_dtype", "prim_if", "prim_layout", "prim_list_construct", "prim_list_unpack", "prim_loop", "prim_max", "prim_min", "prim_shape", "prim_tolist", "prim_tuple_construct", "prim_type", "prim_unchecked_cast", "prim_uninitialized", "rand_like", "rand", "randint_like", "randint", "randn_like", "randn", "reciprocal", "reflection_pad", "relu", "relu6", "remainder", "repeat_interleave", "repeat", "replication_pad", "reshape_as", "reshape", "roll", "rrelu", "rsqrt", "rsub", "scalar_tensor", "scatter_add", "scatter", "select", "selu", "sigmoid", "sign", "silu", "sin", "size", "slice", "softmax", "softplus", "softshrink", "sort", "split_with_sizes", "split", "sqrt", "square", "squeeze", "stack", "std_mean", "std", "sub", "t", "take", "tan", "tanh", "tanhshrink", "tensor", "threshold", "to", "topk", "transpose", "true_divide", "type_as", "unbind", "unfold", "unsafe_chunk", "unsafe_split_with_sizes", "unsafe_split", "unsqueeze", "unsupported_complex_operators", "noop_complex_operators", "unused", "var_mean", "var", "view_as", "view", "where", "wrap_logical_op_with_cast_to", "wrap_logical_op_with_negation", "zeros_like", "zeros", "zero", ] _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) def _export(name: str): """Exports the function in the current global namespace.""" def wrapper(func): globals()[name] = func __all__.append(name) return func return wrapper @_beartype.beartype def unused(g): """Represents "missing" optional inputs.""" n = g.op("prim::Constant") n.setType(_C.OptionalType.ofTensor()) return n @_onnx_symbolic("aten::_shape_as_tensor") @_beartype.beartype def _shape_as_tensor(g: jit_utils.GraphContext, input): return g.op("Shape", input) @_onnx_symbolic("aten::_reshape_from_tensor") @_beartype.beartype def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape): if isinstance(shape, list): shape = g.op("Concat", *shape, axis_i=0) return reshape(g, input, shape) @_onnx_symbolic("aten::reshape") @symbolic_helper.quantized_args(True) @_beartype.beartype def reshape(g: jit_utils.GraphContext, self, shape): return symbolic_helper._reshape_helper(g, self, shape) @_onnx_symbolic("aten::reshape_as") @symbolic_helper.quantized_args(True) @_beartype.beartype def reshape_as(g: jit_utils.GraphContext, self, other): shape = g.op("Shape", other) return reshape(g, self, shape) @_onnx_symbolic("aten::add") @_beartype.beartype def add(g: jit_utils.GraphContext, self, other, alpha=None): if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): return symbolic_helper._onnx_opset_unsupported_detailed( "Add", 9, 11, "Add between list of tensors not supported", self ) if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: other = g.op("Mul", other, alpha) return g.op("Add", self, other) @_onnx_symbolic("aten::sub") @_beartype.beartype def sub(g: jit_utils.GraphContext, self, other, alpha=None): if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: other = g.op("Mul", other, alpha) return g.op("Sub", self, other) @_onnx_symbolic("aten::rsub") @_beartype.beartype def rsub(g: jit_utils.GraphContext, self, other, alpha=None): return sub(g, other, self, alpha=alpha) @_onnx_symbolic("aten::mul") @_beartype.beartype def mul(g: jit_utils.GraphContext, self, other): if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): # ONNX Mul doesn't support Boolean, so use And as an equivalent operator. return g.op("And", self, other) else: return g.op("Mul", self, other) @_onnx_symbolic("aten::div") @_beartype.beartype def div(g: jit_utils.GraphContext, self, other, *args): if len(args) == 0: return true_divide(g, self, other) else: return _div_rounding_mode(g, self, other, *args) @_onnx_symbolic("aten::addcmul") @symbolic_helper.parse_args("v", "v", "v", "f") @_beartype.beartype def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0): value_tens = g.op("Constant", value_t=torch.tensor([value])) return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) @symbolic_helper.parse_args("v", "v", "s") @_beartype.beartype def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): if rounding_mode is None: return true_divide(g, self, other) elif rounding_mode == "floor": return _floor_divide(g, self, other) elif rounding_mode == "trunc": return _trunc_divide(g, self, other) else: raise errors.SymbolicValueError( f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"', self, ) @_beartype.beartype def _trunc_divide(g: jit_utils.GraphContext, self, other): out = g.op("Div", self, other) # the correct operation is truncate, which is not supported in ONNX, # we cannot call floor since it will behave differently for negative numbers # (eg. -0.1 should become -0 ) # - if scalar_type information are not available, assume that # we need to call floor (treat as float) out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64) # Matching PyTorch's behavior: # - if self is fp the output's type is self's type # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT # - self is not fp and other is not fp, the output's type is self's output type # - the output type defaults to Float scalar_type = _type_utils.JitScalarType.from_value( self, _type_utils.JitScalarType.UNDEFINED ) if scalar_type != _type_utils.JitScalarType.UNDEFINED: if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other): out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) else: out = g.op( "Cast", out, to_i=scalar_type.onnx_type(), ) else: out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) return out @_beartype.beartype def _floor_divide(g: jit_utils.GraphContext, self, other): if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): out = true_divide(g, self, other) return g.op("Floor", out) else: # Integer division does trunction rounding div = g.op("Div", self, other) # Division is negative if: self < 0 != other < 0 zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) negative = g.op( "Xor", symbolic_helper._lt_helper(g, self, zero), symbolic_helper._lt_helper(g, other, zero), ) # For negative numbers with self % other != 0, subtract 1 to round down instead of up mod = g.op("Sub", self, g.op("Mul", div, other)) fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) fixup = g.op("Mul", fixup_mask, one) return g.op("Sub", div, fixup) @_onnx_symbolic("aten::floor_divide") @_beartype.beartype def floor_divide(g: jit_utils.GraphContext, self, other): # Deprecated behavior, floor_divide actually truncates return _trunc_divide(g, self, other) @_onnx_symbolic("aten::floordiv") @_beartype.beartype def floordiv(g: jit_utils.GraphContext, self, other): return floor_divide(g, self, other) @_onnx_symbolic("aten::true_divide") @_beartype.beartype def true_divide(g: jit_utils.GraphContext, self, other): """Division where both inputs are cast to floating types If both inputs are floating, performs div as usual If only one input is a floating type, the other input is cast to its type If neither input is a floating type, both inputs are cast to the default scalar type """ # Case 1: either values are floating # Performs div as usual. # Implicit casting will be handled in scalar type analysis pass. if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): return g.op("Div", self, other) # Case 2: neither is floating # Casts both inputs to the default scalar type scalar_type = torch.get_default_dtype() onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT assert scalar_type is torch.float or scalar_type is torch.double if torch.get_default_dtype() is torch.double: onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE self = g.op("Cast", self, to_i=onnx_scalar_type) other = g.op("Cast", other, to_i=onnx_scalar_type) return g.op("Div", self, other) @_onnx_symbolic("aten::reciprocal") @_beartype.beartype def reciprocal(g: jit_utils.GraphContext, self): # torch.reciprocal implicitly casts to float, so we do the same. if not symbolic_helper._is_fp(self): self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) return g.op("Reciprocal", self) @_onnx_symbolic("aten::cat") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def cat(g: jit_utils.GraphContext, tensor_list, dim): """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. Parameters: g (jit_utils.GraphContext): Graph context. tensor_list (List[torch.Tensor]): List of tensors to concatenate. dim (int): Dimension along which to concatenate the tensors. Returns: ONNX graph node representing the concatenated tensor. """ tensors = symbolic_helper._unpack_list(tensor_list) # torch.cat ignores empty tensors such as `torch.Tensor([])` # These needs to be removed as input from ONNX's concat too, otherwise shape inference # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else) nonempty_tensors = [] for t in tensors: if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size( t, 0 ): continue nonempty_tensors.append(t) assert len(nonempty_tensors) > 0 assert all( symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None or symbolic_helper._get_tensor_rank(t) is None or symbolic_helper._get_tensor_rank(t) == symbolic_helper._get_tensor_rank(nonempty_tensors[0]) for t in nonempty_tensors ) tensor_list.node().removeAllInputs() for t in nonempty_tensors: tensor_list.node().addInput(t) tensors = symbolic_helper._unpack_list(tensor_list) return g.op("Concat", *tensors, axis_i=dim) @_onnx_symbolic("aten::stack") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def stack(g: jit_utils.GraphContext, tensor_list, dim): unsqueezed = [ symbolic_helper._unsqueeze_helper(g, t, [dim]) for t in symbolic_helper._unpack_list(tensor_list) ] return g.op("Concat", *unsqueezed, axis_i=dim) @_onnx_symbolic("aten::list") @_beartype.beartype def _list(g: jit_utils.GraphContext, self): return self @_onnx_symbolic("aten::mm") @_beartype.beartype def mm(g: jit_utils.GraphContext, self, other): # Create a dummy C tensor. Only needed for API purposes, the value is # since beta = 0 C = g.op("Constant", value_t=torch.tensor([1])) return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) @_onnx_symbolic("aten::bmm") @_beartype.beartype def bmm(g: jit_utils.GraphContext, self, other): return g.op("MatMul", self, other) @_onnx_symbolic("aten::matmul") @_beartype.beartype def matmul(g: jit_utils.GraphContext, self, other): return g.op("MatMul", self, other) @_onnx_symbolic("aten::addmm") @symbolic_helper.parse_args("v", "v", "v", "t", "t") @_beartype.beartype def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): scalar_type = None self_scalar_type = symbolic_helper._try_get_scalar_type(self) mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1) mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2) if self_scalar_type is not None: scalar_type = self_scalar_type elif mat1_scalar_type is not None: scalar_type = mat1_scalar_type elif mat2_scalar_type is not None: scalar_type = mat2_scalar_type mat1_rank = symbolic_helper._get_tensor_rank(mat1) mat2_rank = symbolic_helper._get_tensor_rank(mat2) def is_not_none_nor(v, u): return v is not None and v != u if scalar_type is not None and ( is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2) ): res1 = g.op("MatMul", mat1, mat2) res2 = self alpha = symbolic_helper._scalar(alpha) beta = symbolic_helper._scalar(beta) if alpha != 1: alpha = g.op( "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype()) ) res1 = g.op("Mul", res1, alpha) if beta != 1: beta = g.op( "Constant", value_t=torch.tensor( symbolic_helper._scalar(beta), dtype=scalar_type.dtype() ), ) res2 = g.op("Mul", res2, beta) return g.op("Add", res1, res2) return g.op( "Gemm", mat1, mat2, self, beta_f=symbolic_helper._scalar(beta), alpha_f=symbolic_helper._scalar(alpha), ) @_onnx_symbolic("aten::neg") @_beartype.beartype def neg(g: jit_utils.GraphContext, self): return g.op("Neg", self) @_onnx_symbolic("aten::sqrt") @_beartype.beartype def sqrt(g: jit_utils.GraphContext, self): if _type_utils.JitScalarType.from_value( self, _type_utils.JitScalarType.UNDEFINED ) in { _type_utils.JitScalarType.UINT8, _type_utils.JitScalarType.INT8, _type_utils.JitScalarType.INT16, _type_utils.JitScalarType.INT, _type_utils.JitScalarType.INT64, }: # torch converts all int inputs to sqrt to float self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) return g.op("Sqrt", self) @_onnx_symbolic("aten::rsqrt") @_beartype.beartype def rsqrt(g: jit_utils.GraphContext, self): return g.op( "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self) ) @_onnx_symbolic("aten::tanh") # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp @symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128) @_beartype.beartype def tanh(g: jit_utils.GraphContext, self): return g.op("Tanh", self) @_onnx_symbolic("aten::sin") @_beartype.beartype def sin(g: jit_utils.GraphContext, self): return g.op("Sin", self) @_onnx_symbolic("aten::cos") @_beartype.beartype def cos(g: jit_utils.GraphContext, self): return g.op("Cos", self) @_onnx_symbolic("aten::tan") @_beartype.beartype def tan(g: jit_utils.GraphContext, self): return g.op("Tan", self) @_onnx_symbolic("aten::asin") @_beartype.beartype def asin(g: jit_utils.GraphContext, self): return g.op("Asin", self) @_onnx_symbolic("aten::acos") @_beartype.beartype def acos(g: jit_utils.GraphContext, self): return g.op("Acos", self) @_onnx_symbolic("aten::atan") @_beartype.beartype def atan(g: jit_utils.GraphContext, self): return g.op("Atan", self) @_onnx_symbolic("aten::atan2") @_beartype.beartype def atan2(g: jit_utils.GraphContext, self, other): # self is y, and other is x on coordinate slope = g.op("Div", self, other) atan = g.op("Atan", slope) const_zero = g.op("Constant", value_t=torch.tensor(0)) const_pi = g.op("Constant", value_t=torch.tensor(math.pi)) condition_second_or_third_quadrant = g.op("Greater", self, const_zero) second_third_quadrant = g.op( "Where", condition_second_or_third_quadrant, g.op("Add", atan, const_pi), g.op("Sub", atan, const_pi), ) condition_14_or_23_quadrant = g.op("Less", other, const_zero) result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan) return result @_onnx_symbolic("aten::sigmoid") # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) @_beartype.beartype def sigmoid(g: jit_utils.GraphContext, self): """Converts the corresponding PyTorch function into ONNX operators. It is not meant to be called directly by a user. Args: g (jit_utils.GraphContext): Graph context. self (Tensor): the input tensor. Returns: ONNX operator """ return g.op("Sigmoid", self) @_onnx_symbolic("aten::sign") @_beartype.beartype def sign(g: jit_utils.GraphContext, self): return g.op("Sign", self) @symbolic_helper.quantized_args(True) @_beartype.beartype def _slice(g: jit_utils.GraphContext, input, axes, starts, ends): assert len(starts) == len(ends) if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX: return input return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends) @_onnx_symbolic( "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")] ) @_onnx_symbolic( "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] ) # torch.prod does not support multidimensional "dim" @_onnx_symbolic( "aten::prod", decorate=[ symbolic_helper._apply_params( "ReduceProd", "prod", allow_multi_dim_support=False ) ], ) @_beartype.beartype def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): return symbolic_helper._reduce_with_dtype_helper( onnx_op, name, allow_multi_dim_support ) @_onnx_symbolic("aten::cumsum") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def cumsum(g: jit_utils.GraphContext, input, dim, dtype): if symbolic_helper.is_caffe2_aten_fallback(): if dtype.node().kind() != "prim::Constant": return symbolic_helper._unimplemented("cumsum", "dtype", dtype) return g.at("cumsum", input, dim_i=dim) symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) @_onnx_symbolic("aten::_sample_dirichlet") @_beartype.beartype def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): if symbolic_helper.is_caffe2_aten_fallback(): if not symbolic_helper._is_none(generator): return symbolic_helper._unimplemented( "_sample_dirichlet", "We are not able to export generator", self ) return g.at("_sample_dirichlet", self) return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) @_onnx_symbolic("aten::_standard_gamma") @_beartype.beartype def _standard_gamma(g: jit_utils.GraphContext, self, generator): if symbolic_helper.is_caffe2_aten_fallback(): if not symbolic_helper._is_none(generator): return symbolic_helper._unimplemented( "_standard_gamma", "not able to export generator", self ) return g.at("_standard_gamma", self) return symbolic_helper._onnx_unsupported("_standard_gamma", self) @_onnx_symbolic("aten::t") @_beartype.beartype def t(g: jit_utils.GraphContext, self): rank = symbolic_helper._get_tensor_rank(self) if rank is None or rank < 2: # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior # clearly and onnxruntime fails on these cases. So we add an Identity node to # mirror the behavior of eager mode. return g.op("Identity", self) return g.op("Transpose", self, perm_i=(1, 0)) @_onnx_symbolic("aten::numpy_T") @symbolic_helper.quantized_args(True) @_beartype.beartype def numpy_T(g: jit_utils.GraphContext, input): ndim = symbolic_helper._get_tensor_rank(input) assert ndim is not None perm = list(reversed(range(0, ndim))) return g.op("Transpose", input, perm_i=perm) @_onnx_symbolic("aten::expand") @symbolic_helper.quantized_args(True) @_beartype.beartype def expand(g: jit_utils.GraphContext, self, size, implicit): """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" size = symbolic_helper._maybe_get_const(size, "is") if not symbolic_helper._is_value(size): size = g.op("Constant", value_t=torch.LongTensor(size)) elif symbolic_helper._is_packed_list(size): # Expand with -1 dim value means dim is unchanged. # Since onnx::expand supports two-way broadcasting, # -1 dim value can be exported to onnx as 1 size = symbolic_helper._reshape_helper( g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) ) dtype = _type_utils.JitScalarType.INT64 ones = ones_like(g, size, dtype) neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) size = where(g, g.op("Equal", size, neg_ones), ones, size) return g.op("Expand", self, size) @_onnx_symbolic("aten::broadcast_to") @symbolic_helper.quantized_args(True) @_beartype.beartype def broadcast_to(g: jit_utils.GraphContext, self, size): size = symbolic_helper._maybe_get_const(size, "is") if not symbolic_helper._is_value(size): size = g.op("Constant", value_t=torch.LongTensor(size)) elif symbolic_helper._is_packed_list(size): # Expand with -1 dim value means dim is unchanged. # Since onnx::expand supports two-way broadcasting, # -1 dim value can be exported to onnx as 1 size = symbolic_helper._reshape_helper( g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) ) dtype = _type_utils.JitScalarType.INT64 ones = ones_like(g, size, dtype) neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) size = where(g, g.op("Equal", size, neg_ones), ones, size) return g.op("Expand", self, size) @_onnx_symbolic("aten::expand_as") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def expand_as(g: jit_utils.GraphContext, self, other): self_t = symbolic_helper._maybe_get_const(self, "t") if isinstance(self_t, torch.Tensor): orig_type = self_t.dtype self_t = self_t.to(torch.double) dims = [] for d in range(self_t.dim()): if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): dims.append(d) self = g.op( "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) ) shape = g.op("Shape", other) return g.op("Expand", self, shape) @_onnx_symbolic("aten::embedding") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "i", "b", "v") @_beartype.beartype def embedding( g: jit_utils.GraphContext, weight, indices, padding_idx, scale_grad_by_freq, sparse, ): if scale_grad_by_freq and GLOBALS.export_training: raise errors.SymbolicValueError( "Unsupported: ONNX export of embedding with scale_grad_by_freq=True " "for training mode. ONNX does not support scaling the gradients.", weight, ) if padding_idx >= 0 and GLOBALS.export_training: warnings.warn( "Warning: ONNX export of embedding with padding_idx >= 0 " "for training mode. " "ONNX does not support not updating the embedding vector at padding_idx during training." ) return g.op("Gather", weight, indices) @_onnx_symbolic("aten::embedding_bag") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") @_beartype.beartype def embedding_bag( g: jit_utils.GraphContext, embedding_matrix, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, ): if not symbolic_helper._is_none(per_sample_weights): return symbolic_helper._onnx_unsupported( "embedding_bag with per_sample_weights" ) if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "embedding_bag", embedding_matrix, indices, offsets, outputs=4, scale_grad_by_freq_i=scale_grad_by_freq, mode_i=mode, sparse_i=sparse, include_last_offset_i=include_last_offset, padding_idx_i=padding_idx, ) return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) @_onnx_symbolic("aten::size") @symbolic_helper.quantized_args(True, quantize_output=False) @_beartype.beartype def size(g: jit_utils.GraphContext, self, dim=None): if dim is None: return g.op("Shape", self) if symbolic_helper._maybe_get_const(dim, "i") < 0: rank = symbolic_helper._get_tensor_rank(self) if rank is not None: dim = symbolic_helper._maybe_get_const(dim, "i") + rank dim = g.op("Constant", value_t=torch.tensor(dim)) return symbolic_helper._size_helper(g, self, dim) @_onnx_symbolic("aten::transpose") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def transpose(g: jit_utils.GraphContext, self, dim0, dim1): if dim0 == dim1: # micro-optimization return self # NB: Transpose in ONNX is actually a Permute rank = symbolic_helper._get_tensor_rank(self) if rank is not None: axes = list(range(rank)) axes[dim0], axes[dim1] = axes[dim1], axes[dim0] return g.op("Transpose", self, perm_i=axes) elif symbolic_helper.is_caffe2_aten_fallback(): # if we don't have dim information we cannot # output a permute so use ATen instead return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1) else: raise errors.SymbolicValueError( "Unsupported: ONNX export of transpose for tensor of unknown rank.", self, ) @_onnx_symbolic("aten::permute") @symbolic_helper.parse_args("v", "is") @_beartype.beartype def permute(g: jit_utils.GraphContext, self, dims): if dims == list(range(0, len(dims))): return self return g.op("Transpose", self, perm_i=dims) @_onnx_symbolic("aten::view") @symbolic_helper.quantized_args(True) @_beartype.beartype def view(g: jit_utils.GraphContext, self, size): return reshape(g, self, size) @_onnx_symbolic("aten::view_as") @_beartype.beartype def view_as(g: jit_utils.GraphContext, self, other): shape = g.op("Shape", other) return reshape(g, self, shape) @_onnx_symbolic("aten::unsafe_chunk") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): if _outputs is None: return symbolic_helper._onnx_opset_unsupported_detailed( "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self ) size = symbolic_helper._get_tensor_dim_size(self, dim) if size is None: return symbolic_helper._unimplemented( "unsafe_chunk", "unknown dimension size", self ) split_size = (size + chunks - 1) // chunks splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: splits.append(leftover) return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) @_onnx_symbolic("aten::split") @symbolic_helper.parse_args("v", "v", "i", "i") @_beartype.beartype def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): return symbolic_helper._onnx_opset_unsupported_detailed( "split", 9, 11, "Dynamic number of outputs not supported", self ) split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") if split_val.dim() > 0: return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs) split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") size = symbolic_helper._get_tensor_dim_size(self, dim) if size is None: if _outputs is not None: size = split_size * _outputs else: return symbolic_helper._onnx_opset_unsupported_detailed( "split", 9, 11, "Unknown dimension size not supported", self ) splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: splits.append(leftover) return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) @_onnx_symbolic("aten::unsafe_split") @_beartype.beartype def unsafe_split( g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None ): return split(g, self, split_size_or_sizes, dim, _outputs) @_onnx_symbolic("aten::split_with_sizes") @symbolic_helper.parse_args("v", "is", "i", "i") @_beartype.beartype def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): if not symbolic_helper._is_split_static(split_sizes, _outputs): return symbolic_helper._onnx_opset_unsupported_detailed( "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self ) return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) @_onnx_symbolic("aten::unsafe_split_with_sizes") @_beartype.beartype def unsafe_split_with_sizes( g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None ): return split_with_sizes(g, self, split_sizes, dim, _outputs) @_onnx_symbolic("aten::unbind") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): if _outputs is None: return symbolic_helper._onnx_opset_unsupported_detailed( "unbind", 9, 11, "Dynamic number of outputs not supported", self ) outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) outputs = [outputs] if _outputs == 1 else outputs squeezed_outputs = [ symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs ] return squeezed_outputs @_onnx_symbolic("aten::select") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "i", "v") @_beartype.beartype def select(g: jit_utils.GraphContext, self, dim, index): """Implement the select functionality for a pytorch tensor in ONNX. Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. """ index = symbolic_helper._maybe_get_scalar(index) if (not symbolic_helper._is_value(index)) and (index < 0): if index == -1: end_index = _constants.INT64_MAX else: end_index = index + 1 slice_node = symbolic_helper._slice_helper( g, self, axes=[dim], starts=[index], ends=[end_index] ) return symbolic_helper._squeeze_helper(g, slice_node, [dim]) else: # FIXME(justinchuby): can index be an int and not a value? return g.op("Gather", self, index, axis_i=dim) @_onnx_symbolic("aten::square") @_beartype.beartype def square(g: jit_utils.GraphContext, self): return g.op("Mul", self, self) @_onnx_symbolic("aten::squeeze") @_beartype.beartype def squeeze(g: jit_utils.GraphContext, self, dim=None): if dim is None: return g.op("Squeeze", self) squeeze_dim = symbolic_helper._get_const(dim, "i", "dim") # Handle negative dims if squeeze_dim < 0: rank = symbolic_helper._get_tensor_rank(self) if rank is not None: warnings.warn( "ONNX export squeeze with negative axis " + str(squeeze_dim) + " might cause the onnx model to be incorrect. " + "Negative axis is not supported in ONNX. " + "Axis is converted to " + str(squeeze_dim + rank) + " based on input shape at export time. " + "Passing an tensor of different rank in execution will be incorrect." ) squeeze_dim += rank else: return symbolic_helper._unimplemented( "squeeze", "negative axis with unknown input rank", self ) dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim) if dim_size is None: warnings.warn( "This model contains a squeeze operation on dimension " + str(squeeze_dim) + " on an input " + "with unknown shape. Note that if the size of dimension " + str(squeeze_dim) + " of the input " + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + "non-singleton dimensions, it is recommended to export this model using opset " + "version 11 or higher." ) return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) if dim_size > 1: warnings.warn( "This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". The size of " + "this dimension in the given input is " + str(dim_size) + ". The model will " + "be exported without the squeeze node. If the model is intended to be used with dynamic " + "input shapes, please use opset version 11 to " + "export the model." ) return self warnings.warn( "This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". If the model is " + "intended to be used with dynamic input shapes, please use opset version 11 to export the model." ) return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) @_onnx_symbolic("aten::prelu") @_beartype.beartype def prelu(g: jit_utils.GraphContext, self, weight): self_rank = symbolic_helper._get_tensor_rank(self) weight_sizes = symbolic_helper._get_tensor_sizes(weight) weight_rank = len(weight_sizes) if self_rank is not None: if self_rank > 2: # make weight unidirectional broadcastable weight = symbolic_helper._unsqueeze_helper( g, weight, list(range(1, self_rank - 1)) ) elif self_rank == 0 and weight_sizes == [1]: # self and weight are both scalar but weight has rank == 1, squeeze weight. weight = symbolic_helper._squeeze_helper(g, weight, [0]) weight_rank = 0 if self_rank is not None and weight_rank is not None: assert ( self_rank >= weight_rank ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" return g.op("PRelu", self, weight) @_onnx_symbolic("aten::silu") @_beartype.beartype def silu(g: jit_utils.GraphContext, input): return g.op("Mul", input, g.op("Sigmoid", input)) @_onnx_symbolic("aten::mish") @_beartype.beartype def mish(g: jit_utils.GraphContext, input): return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) @_onnx_symbolic("aten::relu") @symbolic_helper.quantized_args(True) @_beartype.beartype def relu(g: jit_utils.GraphContext, input): return symbolic_helper._op_with_optional_float_cast( g, "Relu", input, opset_before=14 ) @_onnx_symbolic("aten::relu6") @symbolic_helper.quantized_args(True) @_beartype.beartype def relu6(g: jit_utils.GraphContext, input): return clamp(g, input, 0, 6) @_onnx_symbolic("aten::ceil") @_beartype.beartype def ceil(g: jit_utils.GraphContext, input): return g.op("Ceil", input) @_onnx_symbolic("aten::floor") @_beartype.beartype def floor(g: jit_utils.GraphContext, input): return g.op("Floor", input) @_onnx_symbolic("aten::len") @_beartype.beartype def _len(g: jit_utils.GraphContext, self): sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) return symbolic_helper._squeeze_helper(g, sz_0, [0]) @_onnx_symbolic("aten::threshold") @symbolic_helper.parse_args("v", "t", "t") @_beartype.beartype def threshold(g: jit_utils.GraphContext, self, threshold, value): # See Note [Export inplace] if symbolic_helper._scalar(threshold) != 0: return symbolic_helper._unimplemented("threshold", "non-zero threshold", self) if symbolic_helper._scalar(value) != 0: return symbolic_helper._unimplemented("threshold", "non-zero value", self) return g.op("Relu", self) @_onnx_symbolic("aten::leaky_relu") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "f", "b") @_beartype.beartype def leaky_relu( g: jit_utils.GraphContext, input: _C.Value, negative_slope: float, inplace: bool = False, ): # See Note [Export inplace] return g.op("LeakyRelu", input, alpha_f=negative_slope) @_onnx_symbolic("aten::glu") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def glu(g: jit_utils.GraphContext, input, dim): dim_size = symbolic_helper._get_tensor_dim_size(input, dim) if dim_size is not None: assert dim_size % 2 == 0 first, second = g.op("Split", input, axis_i=dim, outputs=2) return g.op("Mul", first, g.op("Sigmoid", second)) @_onnx_symbolic("aten::softmax") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): # Softmax does normalization at vector level. # PyTorch and ONNX use different strategies to split the input tensor into vectors. # Thus dim and axis have different meanings. # PyTorch slices the input tensor into vectors along the `dim`-th dimension. # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced. # If input is a 2 x 3 tensor: # input = [[1.0, 1.0, 1.0], # [1.0, 1,0, 1,0]] # with dim = 0, the result is: # result = [[0.5, 0.5, 0.5], # [0.5, 0.5, 0.5]] # with axis = 0, the result is: # result = [[0.167, 0.167, 0.167], # [0.167, 0.167, 0.167]] # So only when dim and axis both equal to ndim - 1 (the last dimension), # their semantics are equivalent. # So use softmax when dim and axis both equal to ndim - 1, # otherwise transpose the input to put the vectors to be normalized to the last dimension. # When input rank is not known at export time we compute softmax using a subgraph # with other operators input_dim = symbolic_helper._get_tensor_rank(input) if input_dim is not None: # TODO: remove this as onnx opset 11 spec allows negative axes if dim < 0: dim = input_dim + dim is_transpose_required = input_dim != dim + 1 if is_transpose_required: axes = list(range(input_dim)) axes[dim], axes[-1] = axes[-1], axes[dim] input = g.op("Transpose", input, perm_i=axes) dim = input_dim - 1 softmax = g.op("Softmax", input, axis_i=dim) if dtype and dtype.node().kind() != "prim::Constant": parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") softmax = g.op( "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(), ) if is_transpose_required: softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined] return softmax # Apply max normalization. input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)) exp = g.op("Exp", input) sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim]) softmax = g.op("Div", exp, sum) if dtype and dtype.node().kind() != "prim::Constant": parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") softmax = g.op( "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() ) return softmax @_onnx_symbolic("aten::softplus") @_beartype.beartype def softplus(g: jit_utils.GraphContext, self, beta, threshold): beta_const = symbolic_helper._maybe_get_const(beta, "f") if beta_const != 1: return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta) return g.op("Softplus", self) @_onnx_symbolic("aten::get_pool_ceil_padding") @_beartype.beartype def get_pool_ceil_padding(input, kernel_size, stride, padding): # TODO(justinchuby): Looks like this op is deprecated in torch sizes = symbolic_helper._get_tensor_sizes(input) dim = sizes[-len(padding) :] if sizes is not None else None if dim is None or any(i is None for i in dim): return symbolic_helper._unimplemented( "get_pool_ceil_padding", "input size not accessible", input ) ceiled_output_dim = [ int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + 1 for i in range(0, len(padding)) ] # ensure last pooling starts inside ceiled_output_dim = [ ( ceiled_output_dim[i] - 1 if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) else ceiled_output_dim[i] ) for i in range(0, len(ceiled_output_dim)) ] padding_ceil = [ ( 0 if (stride[i] == 1) else ( kernel_size[i] - ( dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1) ) ) ) for i in range(0, len(padding)) ] # ensure padding is not > kernel_size padding_ceil = [ ( ( int(padding_ceil[i]) if padding_ceil[i] < kernel_size[i] - 1 else int(kernel_size[i] - 1) ) if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) else int(padding_ceil[i]) ) for i in range(0, len(padding_ceil)) ] return padding_ceil @_onnx_symbolic( "aten::max_pool1d", decorate=[ symbolic_helper._apply_params( "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False ), _export("max_pool1d"), ], ) @_onnx_symbolic( "aten::max_pool2d", decorate=[ symbolic_helper._apply_params( "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False ), _export("max_pool2d"), ], ) @_onnx_symbolic( "aten::max_pool3d", decorate=[ symbolic_helper._apply_params( "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False ), _export("max_pool3d"), ], ) @_beartype.beartype def _max_pool(name, tuple_fn, ndims, return_indices): @symbolic_helper.quantized_args(True, False, False, False, False, False) @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") @_beartype.beartype def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): if set(tuple_fn(dilation)) != {1}: return symbolic_helper._unimplemented(name, "dilation", input) if not stride: stride = kernel_size padding = tuple(tuple_fn(padding)) if ceil_mode: padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) else: padding = padding * 2 kwargs = { "kernel_shape_i": tuple_fn(kernel_size), "pads_i": padding, "strides_i": tuple_fn(stride), } # easy but hacky way to get flattened indices values # to be used to convert the indices values to non-flattened. # In ONNX the indices are computed as a flatten 1-D tensor, # so the values in indices are in [0, N x C x D1 x ... x Dn). # To convert the indices to the same format used by Pytorch, # we first execute a maxpool with a kernel and stride of 1 on the same input. # This will result in a tensor of indices in which each index will have it's own value. # Using this tensor as a reference, we extract the first index of each axis and subtract # it from each index of this axis in the indices to convert. # This step will result in a tensor were each dimension has values of indices within # the dimension it is in. # For more information : # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 if return_indices: r, indices = g.op("MaxPool", input, outputs=2, **kwargs) _, flattened_indices = g.op( "MaxPool", input, outputs=2, kernel_shape_i=[1 for _ in range(ndims)], strides_i=[1 for _ in range(ndims)], ) # convert indices to have non-flattened indices values s = symbolic_helper._slice_helper( g, flattened_indices, axes=[2 + i for i in range(ndims)], starts=list(tuple_fn(0)), ends=list(tuple_fn(1)), ) indices = sub(g, indices, s) return r, indices else: r = g.op("MaxPool", input, outputs=1, **kwargs) return r return symbolic_fn max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( _max_pool( "max_pool1d_with_indices", torch.nn.modules.utils._single, 1, return_indices=True, ) ) max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( _max_pool( "max_pool2d_with_indices", torch.nn.modules.utils._pair, 2, return_indices=True, ) ) max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( _max_pool( "max_pool3d_with_indices", torch.nn.modules.utils._triple, 3, return_indices=True, ) ) @_onnx_symbolic( "aten::avg_pool1d", decorate=[ symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single), _export("avg_pool1d"), ], ) @_onnx_symbolic( "aten::avg_pool2d", decorate=[ symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair), _export("avg_pool2d"), ], ) @_onnx_symbolic( "aten::avg_pool3d", decorate=[ symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple), _export("avg_pool3d"), ], ) @_beartype.beartype def _avg_pool(name, tuple_fn): @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") @_beartype.beartype def symbolic_fn( g, input: _C.Value, kernel_size: Sequence[int], stride: Sequence[int], padding: Union[int, Sequence[int]], ceil_mode: int, count_include_pad: int, divisor_override=None, ): if not stride: stride = kernel_size padding = symbolic_helper._avgpool_helper( tuple_fn, padding, kernel_size, stride, divisor_override, name ) assert isinstance(padding, tuple) adjusted_padding = padding # Although onnx::AvgPool provides count_include_pad, # The corner case of Average Pooling with ceil_mode on # PyTorch allows sliding window go off bound, which leads to # this accommodation. # More detail on https://github.com/pytorch/pytorch/issues/57178 if count_include_pad: input = symbolic_helper._op_with_optional_float_cast( g, "Pad", input, pads_i=((0,) * 2 + padding) * 2, mode_s="constant", value_f=0.0, opset_before=11, ) adjusted_padding = (0,) * len(padding) if ceil_mode: padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) adjusted_padding = adjusted_padding + tuple( a + b for (a, b) in zip(padding_ceil, adjusted_padding) ) else: adjusted_padding = adjusted_padding * 2 output = g.op( "AveragePool", input, kernel_shape_i=tuple_fn(kernel_size), strides_i=tuple_fn(stride), pads_i=adjusted_padding, ) return output return symbolic_fn @_onnx_symbolic( "aten::adaptive_avg_pool1d", decorate=[ symbolic_helper._apply_params( "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single ), _export("adaptive_avg_pool1d"), ], ) @_onnx_symbolic( "aten::adaptive_avg_pool2d", decorate=[ symbolic_helper._apply_params( "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair ), _export("adaptive_avg_pool2d"), ], ) @_onnx_symbolic( "aten::adaptive_avg_pool3d", decorate=[ symbolic_helper._apply_params( "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple ), _export("adaptive_avg_pool3d"), ], ) @_onnx_symbolic( "aten::adaptive_max_pool1d", decorate=[ symbolic_helper._apply_params( "adaptive_max_pool1d", "MaxPool", torch.nn.modules.utils._single, max_pool1d_with_indices, ), _export("adaptive_max_pool1d"), ], ) @_onnx_symbolic( "aten::adaptive_max_pool2d", decorate=[ symbolic_helper._apply_params( "adaptive_max_pool2d", "MaxPool", torch.nn.modules.utils._pair, max_pool2d_with_indices, ), _export("adaptive_max_pool2d"), ], ) @_onnx_symbolic( "aten::adaptive_max_pool3d", decorate=[ symbolic_helper._apply_params( "adaptive_max_pool3d", "MaxPool", torch.nn.modules.utils._triple, max_pool3d_with_indices, ), _export("adaptive_max_pool3d"), ], ) @_beartype.beartype def _adaptive_pool(name, type, tuple_fn, fn=None): @symbolic_helper.quantized_args(True, False) @_beartype.beartype def symbolic_fn(g, input, output_size): # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, # by executing a GlobalPool. # It is also supported for cases where the output size is a factor of the input size. # For these cases the stride and kernel size are uniform along all the indices of # the same dimension, which makes it possible to export it to ONNX. # for MaxPool, GlobalMaxPool does not return indices, # so we try using max_poolxd_with_indices, and if it is not possible # (input is not a complete tensor or output size not factor of input size) # then we call GlobalAveragePool and return None for the indices output_size_value = output_size try: output_size = symbolic_helper._parse_arg(output_size, "is") except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. return symbolic_helper._onnx_unsupported( "adaptive pooling, since output_size is not constant.", input ) if output_size == [1] * len(output_size) and type == "AveragePool": return g.op("GlobalAveragePool", input) sizes = symbolic_helper._get_tensor_sizes(input) try: dim = sizes[2:] except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. dim = None if dim is None or any(i is None for i in dim): if output_size == [1] * len(output_size): return g.op("GlobalMaxPool", input), None return symbolic_helper._unimplemented( name, "input size not accessible", input ) # verify if output size % input size = 0 for all dim mod = [dim[i] % output_size[i] for i in range(0, len(dim))] if mod != [0] * len(mod): if output_size == [1] * len(output_size): return g.op("GlobalMaxPool", input), None return symbolic_helper._unimplemented( name, "output size that are not factor of input size", output_size_value ) k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] # call max_poolxd_with_indices to get indices in the output if type == "MaxPool": return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False) output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k)) return output return symbolic_fn @_beartype.beartype def _prepare_onnx_paddings(dim: int, pad): """Generate paddings in ONNX order based on pad in pytorch. Args: dim: the dimension of the tensor. pad: the paddings in pytorch. The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... """ # The desired order of paddings is # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. # n is the dimension of input. # assume zero-dimensions in the beginning paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) # reverse order and collate first beginnings and then ends paddings = paddings[-2::-2] + paddings[-1::-2] return paddings @_beartype.beartype def _convert_padding_node(input): padding = symbolic_helper._maybe_get_const(input, "is") if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding): input_list = symbolic_helper._unpack_list(padding) try: padding = [ symbolic_helper._get_const(v, "i", "padding") for v in input_list ] except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. return symbolic_helper._onnx_opset_unsupported_detailed( "Pad", 9, 11, "The sizes of the padding must be constant", input ) return padding @_onnx_symbolic("aten::constant_pad_nd") @_beartype.beartype def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value): mode = "constant" try: value = symbolic_helper._get_const(value, "f", "value") except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. return symbolic_helper._onnx_opset_unsupported_detailed( "Pad", 9, 11, "The value for the padding must be constant", value ) padding = _convert_padding_node(padding) paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) return symbolic_helper._op_with_optional_float_cast( g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 ) @_beartype.beartype def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value): padding = _convert_padding_node(pad) assert len(padding) % 2 == 0 ndim = len(padding) // 2 cur = input for idx in range(ndim): pad_r = padding[-(2 * idx + 1)] pad_l = padding[-(2 * idx + 2)] tensors = [] if pad_l > 0: left = symbolic_helper._slice_helper( g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX] ) tensors.append(left) if pad_l < 0 or pad_r < 0: start = builtins.max(0, -pad_l) end = -(builtins.max(0, -pad_r)) middle = symbolic_helper._slice_helper( g, cur, axes=[2 + idx], starts=[start], ends=[end], ) tensors.append(middle) else: tensors.append(cur) if pad_r > 0: right = symbolic_helper._slice_helper( g, cur, axes=[2 + idx], starts=[0], ends=[pad_r] ) tensors.append(right) cur = g.op("Concat", *tensors, axis_i=(2 + idx)) return cur @_onnx_symbolic("aten::reflection_pad1d") @_onnx_symbolic("aten::reflection_pad2d") @_onnx_symbolic("aten::reflection_pad3d") @_beartype.beartype def reflection_pad(g: jit_utils.GraphContext, input, padding): mode = "reflect" padding = _convert_padding_node(padding) paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) return symbolic_helper._op_with_optional_float_cast( g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 ) @_onnx_symbolic("aten::replication_pad1d") @_onnx_symbolic("aten::replication_pad2d") @_onnx_symbolic("aten::replication_pad3d") @_beartype.beartype def replication_pad(g: jit_utils.GraphContext, input, padding): mode = "edge" padding = _convert_padding_node(padding) paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) return symbolic_helper._op_with_optional_float_cast( g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 ) @_onnx_symbolic("aten::pad") @_beartype.beartype def pad( g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value, mode: _C.Value, value: _C.Value, ): mode = symbolic_helper._parse_arg(mode, "s") if mode == "replicate": return replication_pad(g, input, pad) elif mode == "reflect": return reflection_pad(g, input, pad) elif mode == "constant": return constant_pad_nd(g, input, pad, value) elif mode == "circular": return _pad_circular(g, input, pad) else: raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) @_onnx_symbolic( "aten::upsample_nearest1d", decorate=[ symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"), _export("upsample_nearest1d"), ], ) @_onnx_symbolic( "aten::upsample_nearest2d", decorate=[ symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"), _export("upsample_nearest2d"), ], ) @_onnx_symbolic( "aten::upsample_nearest3d", decorate=[ symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"), _export("upsample_nearest3d"), ], ) @_onnx_symbolic( "aten::upsample_linear1d", decorate=[ symbolic_helper._apply_params("upsample_linear1d", 3, "linear"), _export("upsample_linear1d"), ], ) @_onnx_symbolic( "aten::upsample_bilinear2d", decorate=[ symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"), _export("upsample_bilinear2d"), ], ) @_onnx_symbolic( "aten::upsample_trilinear3d", decorate=[ symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"), _export("upsample_trilinear3d"), ], ) @_beartype.beartype def _interpolate(name: str, dim: int, interpolate_mode: str): def symbolic_fn(g, input, output_size, *args): scales, align_corners = symbolic_helper._get_interpolate_attributes( g, interpolate_mode, args ) symbolic_helper._interpolate_warning(interpolate_mode) align_corners = symbolic_helper._maybe_get_scalar(align_corners) if align_corners: return symbolic_helper._unimplemented(name, "align_corners == True", input) if scales is None: scales = symbolic_helper._interpolate_size_to_scales( g, input, output_size, dim ) return g.op("Upsample", input, scales, mode_s=interpolate_mode) return symbolic_fn @_onnx_symbolic("aten::__interpolate") @_beartype.beartype def __interpolate( g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias, ): scales, mode = symbolic_helper._interpolate_get_scales_and_mode( g, input, size, scale_factor, mode, align_corners ) return g.op("Upsample", input, scales, mode_s=mode) @_onnx_symbolic("aten::bitwise_not") @_beartype.beartype def bitwise_not(g: jit_utils.GraphContext, input): if not symbolic_helper._is_bool(input): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise Not " "for non-boolean input values", input, ) return g.op("Not", input) @_onnx_symbolic("aten::bitwise_or") @_beartype.beartype def bitwise_or(g, self, other): if not symbolic_helper._is_bool(self): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise OR " "for non-boolean input values. self: ", self, ) if not symbolic_helper._is_bool(other): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise OR " "for non-boolean input values. other: ", other, ) return g.op("Or", self, other) @_beartype.beartype def wrap_logical_op_with_cast_to(to_type): def decorator(fn): @functools.wraps(fn) def wrap_with_cast(g, input, other): to_cast_func = globals()[f"_cast_{to_type}"] return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) return wrap_with_cast return decorator @_beartype.beartype def wrap_logical_op_with_negation(func: Callable) -> Callable: @functools.wraps(func) def wrap_with_not(g, input, other): return g.op("Not", func(g, input, other)) return wrap_with_not @_onnx_symbolic("aten::__not_") @_beartype.beartype def __not_(g: jit_utils.GraphContext, self): if not symbolic_helper._is_bool(self): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise Not " "for non-boolean input values", self, ) return g.op("Not", self) @_onnx_symbolic("aten::eq") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def eq(g: jit_utils.GraphContext, self, other): if isinstance(self.type(), _C.DeviceObjType) and isinstance( other.type(), _C.DeviceObjType ): # ONNX doesn't have devices, so consider them all to be equal. # The no-op check for equality will get constant-folded. return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) self_node = self.node() other_node = other.node() if self_node.kind() == other_node.kind() == "onnx::Constant": if self_node.kindOf("value") == other_node.kindOf("value") == "s": # Exporting strings to ONNX is not supported. # If both strings are constant, we can compare them directly. # The no-op check for equality will get constant-folded. return g.op( "Constant", value_t=torch.tensor( self_node.s("value") == other_node.s("value"), dtype=torch.bool, ), ) return g.op("Equal", self, other) @_onnx_symbolic("aten::ne") @symbolic_helper.quantized_args(True, True) @wrap_logical_op_with_negation @_beartype.beartype def ne(g: jit_utils.GraphContext, self, other): return eq(g, self, other) @_onnx_symbolic("aten::gt") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def gt(g: jit_utils.GraphContext, input, other): return _gt_impl(g, input, other) @_beartype.beartype def _gt_impl(g: jit_utils.GraphContext, input, other): if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) return g.op("Greater", input, other) @_onnx_symbolic("aten::lt") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def lt(g: jit_utils.GraphContext, input, other): return _lt_impl(g, input, other) @_beartype.beartype def _lt_impl(g: jit_utils.GraphContext, input, other): if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) return g.op("Less", input, other) @_onnx_symbolic("aten::ge") @symbolic_helper.quantized_args(True, True) @wrap_logical_op_with_negation @_beartype.beartype def ge(g: jit_utils.GraphContext, input, other): return _lt_impl(g, input, other) @_onnx_symbolic("aten::le") @symbolic_helper.quantized_args(True, True) @wrap_logical_op_with_negation @_beartype.beartype def le(g: jit_utils.GraphContext, input, other): return _gt_impl(g, input, other) @_onnx_symbolic("aten::__and_") @_beartype.beartype def __and_(g: jit_utils.GraphContext, input, other): if not symbolic_helper._is_bool(input): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise AND " "for non-boolean input values", input, ) if not symbolic_helper._is_bool(other): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise AND " "for non-boolean input values", other, ) return g.op("And", input, other) @_onnx_symbolic("aten::__or_") @_beartype.beartype def __or_(g: jit_utils.GraphContext, input, other): if not symbolic_helper._is_bool(input): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise OR " "for non-boolean input values", input, ) if not symbolic_helper._is_bool(other): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise OR " "for non-boolean input values", other, ) return g.op("Or", input, other) @_onnx_symbolic("aten::__xor_") @_beartype.beartype def __xor_(g: jit_utils.GraphContext, input, other): if not symbolic_helper._is_bool(input): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise XOR " "for non-boolean input values", input, ) if not symbolic_helper._is_bool(other): raise errors.SymbolicValueError( "ONNX export does NOT support exporting bitwise XOR " "for non-boolean input values", other, ) return g.op("Xor", input, other) @_onnx_symbolic("aten::logical_and") @wrap_logical_op_with_cast_to("Bool") @_beartype.beartype def logical_and(g: jit_utils.GraphContext, input, other): return g.op("And", input, other) @_onnx_symbolic("aten::logical_or") @wrap_logical_op_with_cast_to("Bool") @_beartype.beartype def logical_or(g: jit_utils.GraphContext, input, other): return g.op("Or", input, other) @_onnx_symbolic("aten::logical_xor") @wrap_logical_op_with_cast_to("Bool") @_beartype.beartype def logical_xor(g: jit_utils.GraphContext, input, other): return g.op("Xor", input, other) @_onnx_symbolic("aten::logical_not") @_beartype.beartype def logical_not(g: jit_utils.GraphContext, input): return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) @_onnx_symbolic("aten::__rshift_") @_beartype.beartype def __rshift_(g: jit_utils.GraphContext, self, other): # make sure to cast other to self's type # (when self is long, make sure that other is not float) self_scalar_type = _type_utils.JitScalarType.from_value(self) if ( _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) != self_scalar_type ): other = g.op( "Cast", other, to_i=self_scalar_type.onnx_type(), ) two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) # exponent (same type as self) has to be float or double in onnx::Pow if not symbolic_helper._is_fp(self): other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) two_pow = g.op("Pow", two, other) two_pow = g.op( "Cast", two_pow, to_i=self_scalar_type.onnx_type(), ) rshift = g.op("Div", self, two_pow) return rshift @_onnx_symbolic("aten::__lshift_") @_beartype.beartype def __lshift_(g: jit_utils.GraphContext, self, other): # make sure to cast other to self's type # (when self is long, make sure that other is not float) self_scalar_type = _type_utils.JitScalarType.from_value(self) if ( _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) != self_scalar_type ): other = g.op( "Cast", other, to_i=self_scalar_type.onnx_type(), ) two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) # exponent (same type as self) has to be float or double in onnx::Pow if not symbolic_helper._is_fp(self): other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) two_pow = g.op("Pow", two, other) two_pow = g.op( "Cast", two_pow, to_i=self_scalar_type.onnx_type(), ) lshift = g.op("Mul", self, two_pow) return lshift @_onnx_symbolic("aten::where") @symbolic_helper.parse_args("v", "v", "v", "i") @_beartype.beartype def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): # Assumes that torch.where's first argument takes only Bool and Byte tensors. if not symbolic_helper._is_bool(condition): condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) if self is None: condition = nonzero(g, condition) return symbolic_helper._unbind_helper( g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs ) return g.op("Where", condition, self, other) @_onnx_symbolic("aten::log_softmax") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): # PyTorch dim and ONNX axis have different meanings. # See Softmax comment for details. # TODO: remove this as onnx opset 11 spec allows negative axes input_dim = symbolic_helper._get_tensor_rank(input) if input_dim is None: return symbolic_helper._unimplemented( "dim", "ONNX and PyTorch use different strategies to split the input. " "Input rank must be known at export time.", ) if dim < 0: dim = input_dim + dim is_transpose_required = input_dim != dim + 1 # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. if is_transpose_required: axes = list(range(input_dim)) axes[dim], axes[-1] = axes[-1], axes[dim] input = g.op("Transpose", input, perm_i=axes) dim = input_dim - 1 return_op = g.op("LogSoftmax", input, axis_i=dim) if dtype and dtype.node().kind() != "prim::Constant": parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") return_op = g.op( "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() ) if is_transpose_required: return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] return return_op @_onnx_symbolic("aten::_log_softmax") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): if ( half_to_float and _type_utils.JitScalarType.from_value( input, _type_utils.JitScalarType.UNDEFINED ) == _type_utils.JitScalarType.HALF ): input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) return log_softmax(g, input, dim) @_onnx_symbolic("aten::_convolution") @symbolic_helper.parse_args( "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" ) @_beartype.beartype def _convolution( g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32=None, ): weight_size = symbolic_helper._get_tensor_sizes(weight) try: kernel_shape = weight_size[2:] except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. kernel_shape = None if kernel_shape is None or any(i is None for i in kernel_shape): raise errors.SymbolicValueError( "Unsupported: ONNX export of convolution for kernel of unknown shape.", input, ) args = [input, weight] # ONNX only supports 1D bias if ( not symbolic_helper._is_none(bias) and symbolic_helper._get_tensor_rank(bias) == 1 ): args.append(bias) kwargs = { "kernel_shape_i": weight_size[2:], "strides_i": stride, # NB: ONNX supports asymmetric padding, whereas PyTorch supports only # symmetric padding "pads_i": padding + padding, "dilations_i": dilation, "group_i": groups, } if any(o != 0 for o in output_padding): # ONNX supports both output_shape and output_padding. they are equivalent expressive. # output_padding is more straightforward, so we use it here. # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 assert transposed assert len(stride) == len(output_padding) kwargs["output_padding_i"] = output_padding n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) if ( not symbolic_helper._is_none(bias) and symbolic_helper._get_tensor_rank(bias) != 1 ): return g.op("Add", n, bias) else: return n @_onnx_symbolic("aten::_convolution_mode") @symbolic_helper.parse_args( "v", "v", "v", "is", "s", "is", "i", ) @_beartype.beartype def _convolution_mode( g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups, ): weight_size = symbolic_helper._get_tensor_sizes(weight) try: kernel_shape = weight_size[2:] except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. kernel_shape = None if kernel_shape is None or any(i is None for i in kernel_shape): raise errors.SymbolicValueError( "Unsupported: ONNX export of convolution for kernel of unknown shape.", input, ) args = [input, weight] # ONNX only supports 1D bias if ( not symbolic_helper._is_none(bias) and symbolic_helper._get_tensor_rank(bias) == 1 ): args.append(bias) if padding == "valid": padding = "VALID" elif padding == "same": padding = "SAME_UPPER" kwargs = { "kernel_shape_i": weight_size[2:], "strides_i": stride, "auto_pad_s": padding, "dilations_i": dilation, "group_i": groups, } n = g.op("Conv", *args, **kwargs) if ( not symbolic_helper._is_none(bias) and symbolic_helper._get_tensor_rank(bias) != 1 ): return g.op("Add", n, bias) else: return n @_onnx_symbolic("aten::convolution") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") @_beartype.beartype def convolution( g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, ): return _convolution( g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, None, None, None, None, ) @_onnx_symbolic("aten::conv1d") @symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") @_beartype.beartype def conv1d( g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups ): str_padding = symbolic_helper._parse_arg(padding, "s") if str_padding in ["valid", "same"]: return _convolution_mode( g, input, weight, bias, stride, str_padding, dilation, groups, ) else: padding = symbolic_helper._parse_arg(padding, "is") return _convolution( g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None, ) @_onnx_symbolic("aten::conv2d") @symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") @_beartype.beartype def conv2d( g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups ): str_padding = symbolic_helper._parse_arg(padding, "s") if str_padding in ["valid", "same"]: return _convolution_mode( g, input, weight, bias, stride, str_padding, dilation, groups, ) else: padding = symbolic_helper._parse_arg(padding, "is") return _convolution( g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None, ) @_onnx_symbolic("aten::conv3d") @symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") @_beartype.beartype def conv3d( g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups ): str_padding = symbolic_helper._parse_arg(padding, "s") if str_padding in ["valid", "same"]: return _convolution_mode( g, input, weight, bias, stride, str_padding, dilation, groups, ) else: padding = symbolic_helper._parse_arg(padding, "is") return _convolution( g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None, ) @_onnx_symbolic("aten::conv_transpose1d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") @_beartype.beartype def conv_transpose1d( g: jit_utils.GraphContext, input, weight, bias, stride, padding, output_padding, groups, dilation, ): return _convolution( g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None, ) @_onnx_symbolic("aten::conv_transpose2d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") @_beartype.beartype def conv_transpose2d( g: jit_utils.GraphContext, input, weight, bias, stride, padding, output_padding, groups, dilation, ): return _convolution( g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None, ) @_onnx_symbolic("aten::conv_transpose3d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") @_beartype.beartype def conv_transpose3d( g: jit_utils.GraphContext, input, weight, bias, stride, padding, output_padding, groups, dilation, ): return _convolution( g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None, ) @_onnx_symbolic("aten::batch_norm") @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") @_beartype.beartype def batch_norm( g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled, ): symbolic_helper.check_training_mode(training, "batch_norm") if ( torch.is_autocast_enabled() and not symbolic_helper.args_have_same_dtype( [input, weight, bias, running_mean, running_var] ) and GLOBALS.export_onnx_opset_version < 15 ): return symbolic_helper._onnx_opset_unsupported_detailed( "BatchNormalization", 9, 15, "All input tensors must have the same `dtype`." " Turn off Autocast or export using opset version 15.", input, ) weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( g, input, weight, bias, running_mean, running_var ) out = g.op( "BatchNormalization", input, weight, bias, running_mean, running_var, epsilon_f=eps, momentum_f=1 - momentum, outputs=1 if not training else 5, ) if not training: return out else: res, new_running_mean, new_running_var, saved_mean, saved_var = out new_running_mean.setType(running_mean.type()) new_running_var.setType(running_var.type()) saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) return res @_onnx_symbolic("aten::native_layer_norm") @symbolic_helper.quantized_args(True, False, False, False) @symbolic_helper.parse_args("v", "is", "v", "v", "f") @_beartype.beartype def native_layer_norm( g: jit_utils.GraphContext, input: _C.Value, normalized_shape: Sequence[int], weight: _C.Value, bias: _C.Value, eps: float, ) -> Tuple[_C.Value, _C.Value, _C.Value]: axes = [-i for i in range(len(normalized_shape), 0, -1)] two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) eps_cst = symbolic_helper._generate_wrapped_number(g, eps) if g.opset < 18: mean = g.op("ReduceMean", input, axes_i=axes) else: mean = g.op( "ReduceMean", input, g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), ) numerator = sub(g, input, mean) # Cast it to eps dtype to avoid precision loss is_type_half = ( _type_utils.JitScalarType.from_value(numerator) == _type_utils.JitScalarType.HALF ) if is_type_half: eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) numerator = g.op( "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() ) # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula if g.opset < 18: variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) else: variance = g.op( "ReduceMean", pow(g, numerator, two_cst), g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), ) denominator = sqrt(g, g.op("Add", variance, eps_cst)) normalized = g.op("Div", numerator, denominator) # Cast back to input type as eps related ops are all done if is_type_half: input_dtype = _type_utils.JitScalarType.from_value(input) normalized = g.op( "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() ) if not (weight is None or symbolic_helper._is_none(weight)): normalized = mul(g, normalized, weight) if not (bias is None or symbolic_helper._is_none(bias)): normalized = add(g, normalized, bias) # rdenominator := 1 / sqrt(variance + eps) # According to aten::native_layer_norm, rdenominator should have the same dtype as input, # mean and normalized, so we need to Cast it back if is_type_half: denominator = g.op( "Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() # type: ignore[possibly-undefined] ) rdenominator = g.op("Reciprocal", denominator) else: rdenominator = reciprocal(g, denominator) return normalized, mean, rdenominator @_onnx_symbolic("aten::layer_norm") @symbolic_helper.quantized_args(True, False, False, False) @symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") @_beartype.beartype def layer_norm( g: jit_utils.GraphContext, input: _C.Value, normalized_shape: Sequence[int], weight: _C.Value, bias: _C.Value, eps: float, cudnn_enable: bool, ) -> _C.Value: if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "layer_norm", input, weight, bias, normalized_shape_i=normalized_shape, eps_f=eps, cudnn_enable_i=cudnn_enable, ) normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) return normalized @_onnx_symbolic("aten::instance_norm") @symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") @_beartype.beartype def instance_norm( g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var, use_input_stats: bool, momentum: Number, eps: Number, cudnn_enabled: bool, ): symbolic_helper.check_training_mode(use_input_stats, "instance_norm") channel_size = symbolic_helper._get_tensor_dim_size(input, 1) if weight is None or symbolic_helper._is_none(weight): if channel_size is None: raise errors.SymbolicValueError( "Unsupported: ONNX export of instance_norm for unknown channel size.", input, ) weight_value = torch.tensor( [1.0] * channel_size, dtype=_type_utils.JitScalarType.from_value(input).dtype(), ) weight = g.op("Constant", value_t=weight_value) if bias is None or symbolic_helper._is_none(bias): if channel_size is None: raise errors.SymbolicValueError( "Unsupported: ONNX export of instance_norm for unknown channel size.", input, ) bias_value = torch.tensor( [0.0] * channel_size, dtype=_type_utils.JitScalarType.from_value(input).dtype(), ) bias = g.op("Constant", value_t=bias_value) if ( running_mean is None or symbolic_helper._is_none(running_mean) or running_var is None or symbolic_helper._is_none(running_var) ): return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) else: input_size = symbolic_helper._get_tensor_sizes(input) # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. # For more information instance_norm(): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 input_size_reshape = input_size.copy() n = input_size[0] if n is None: raise errors.SymbolicValueError( "Unsupported: ONNX export of instance_norm training for unknown " "batch size.", input, ) c = input_size[1] input_size_reshape[0] = 1 input_size_reshape[1] = n * c weight_ = repeat( g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) ) bias_ = repeat( g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) ) running_mean_ = repeat( g, running_mean, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), ) running_var_ = repeat( g, running_var, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), ) input_reshaped = g.op( "Reshape", input, g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), ) out = batch_norm( g, input_reshaped, weight_, bias_, running_mean_, running_var_, use_input_stats, momentum, eps, cudnn_enabled, ) return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) @_onnx_symbolic("aten::unfold") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def unfold(g: jit_utils.GraphContext, input, dimension, size, step): if symbolic_helper.is_caffe2_aten_fallback(): return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step) sizes = symbolic_helper._get_tensor_sizes(input) # FIXME(justinchuby): Get rid of the try catch here to improve readability try: sizedim = sizes[dimension] except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. sizedim = None if sizedim is not None: low_indices = range(0, sizedim, step) hi_indices = range(size, sizedim + 1, step) stack = [ symbolic_helper._slice_helper( g, input, axes=[dimension], starts=[low], ends=[hi] ) for low, hi in zip(low_indices, hi_indices) ] ndim = len(sizes) perm = list(range(0, ndim)) perm.append(perm.pop(dimension)) unsqueeze = [ symbolic_helper._unsqueeze_helper( g, g.op("Transpose", t, perm_i=perm), [dimension] ) for t in stack ] return g.op("Concat", *unsqueeze, axis_i=dimension) else: return symbolic_helper._unimplemented( "Unfold", "input size not accessible", input ) @_onnx_symbolic("aten::elu") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "t", "t", "t") @_beartype.beartype def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): if scale and scale != 1.0: return symbolic_helper._unimplemented( "scale", "does not support scale in Elu", scale ) if input_scale and input_scale != 1.0: return symbolic_helper._unimplemented( "input_scale", "does not support input_scale in Elu", input_scale ) # See Note [Export inplace] return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) @_onnx_symbolic("aten::selu") @symbolic_helper.quantized_args(True) @_beartype.beartype def selu(g: jit_utils.GraphContext, input): return g.op("Selu", input) @_onnx_symbolic("aten::index_select") @symbolic_helper.parse_args("v", "i", "v") @_beartype.beartype def index_select(g: jit_utils.GraphContext, self, dim, index): # In case of a scalar index, index_select returns a tensor with the same rank as the input. # To match this behavior in ONNX, we make index a 1D tensor so that the following gather # also produces a tensor with the same rank as the input. return symbolic_helper._select_helper(g, self, dim, index) @_onnx_symbolic("aten::index_put") @_beartype.beartype def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): if symbolic_helper._is_packed_list(indices_list_value): indices_list = symbolic_helper._unpack_list(indices_list_value) else: indices_list = [indices_list_value] if symbolic_helper.is_caffe2_aten_fallback(): args = [self] + indices_list + [values, accumulate] return g.at("index_put", *args) accumulate = symbolic_helper._parse_arg(accumulate, "b") if len(indices_list) == 0: if accumulate: return add(g, self, values) return values symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) @_onnx_symbolic("aten::index_fill") @_beartype.beartype def index_fill(g: jit_utils.GraphContext, self, dim, index, value): dim_value = symbolic_helper._parse_arg(dim, "i") if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value, ) expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) value = symbolic_helper._maybe_get_scalar(value) value = symbolic_helper._if_scalar_type_as(value, self) expanded_value = expand(g, value, expanded_index_shape, None) return scatter(g, self, dim, expanded_index, expanded_value) @_onnx_symbolic("aten::index_copy") @_beartype.beartype def index_copy(g: jit_utils.GraphContext, self, dim, index, source): dim_value = symbolic_helper._parse_arg(dim, "i") if symbolic_helper.is_caffe2_aten_fallback(): return g.at("index_copy", self, index, source, dim_i=dim_value) expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) return scatter(g, self, dim, expanded_index, source) @_onnx_symbolic("aten::bucketize") @symbolic_helper.parse_args("v", "v", "b", "b") @_beartype.beartype def bucketize( g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False ): out_type = _C_onnx.TensorProtoDataType.INT64 if out_int32: out_type = _C_onnx.TensorProtoDataType.INT32 # A tensor expanded_boundaries is created such that it # contains a copy of boundaries for each element of self. new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md tensor_rank = symbolic_helper._get_tensor_rank(self) assert tensor_rank is not None unsqueeze_axes = list(range(1, tensor_rank + 1)) expanded_boundaries = expand( g, symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), new_shape, None, ) # Compare each element of self to boundaries to get a tensor # with leading 1s and trailing 0s. # e.g., 4 > [1, 3, 4] = [1, 1, 0] # The index of the last 1 is the bucket where the element should go. if right: cond = ge(g, self, expanded_boundaries) else: cond = gt(g, self, expanded_boundaries) cond_out = g.op("Cast", cond, to_i=out_type) # Sum to get the number of 1s corresponding to each element, # which is the same as the bucket index. # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) @_onnx_symbolic("aten::type_as") @_beartype.beartype def type_as(g: jit_utils.GraphContext, self, other): self_dtype = symbolic_helper._try_get_scalar_type(self) other_dtype = symbolic_helper._try_get_scalar_type(other) if self_dtype == other_dtype and self_dtype is not None: return self if other_dtype is not None: return g.op( "Cast", self, to_i=other_dtype.onnx_type(), ) if symbolic_helper.is_caffe2_aten_fallback(): # We don't know the type of other, bail by emitting ATen return g.at("type_as", self, other) raise errors.SymbolicValueError( "Unsupported: ONNX export of type_as for tensor " "of unknown dtype. Please check if the dtype of the " "parameter passed to the type_as function is correct.", other, ) @_onnx_symbolic("aten::cosine_similarity") @symbolic_helper.parse_args("v", "v", "i", "f") @_beartype.beartype def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): if symbolic_helper.is_caffe2_aten_fallback(): return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps) cross = symbolic_helper._reducesum_helper( g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 ) x1_l2 = symbolic_helper._reducesum_helper( g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 ) x2_l2 = symbolic_helper._reducesum_helper( g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 ) div_tens = max( g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) ) return div(g, cross, div_tens) @_onnx_symbolic("aten::pairwise_distance") @_beartype.beartype def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): if not symbolic_helper._is_value(eps): eps = g.op("Constant", value_t=torch.tensor([eps])) inv_p = div( g, g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), add(g, p, eps), ) summation = symbolic_helper._reducesum_helper( g, pow(g, sub(g, input1, input2), p), axes_i=[-1], keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), ) return pow(g, summation, inv_p) @_onnx_symbolic("aten::clone") # ignore clone operators that are inserted by PyTorch autograd @_beartype.beartype def clone(g: jit_utils.GraphContext, input, unused_memory_format): return input @_onnx_symbolic("aten::abs") @_beartype.beartype def abs(g: jit_utils.GraphContext, self): return g.op("Abs", self) @_onnx_symbolic("aten::log") @_beartype.beartype def log(g: jit_utils.GraphContext, self): return g.op("Log", self) @_onnx_symbolic("aten::log1p") @_beartype.beartype def log1p(g: jit_utils.GraphContext, self): return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) @_onnx_symbolic("aten::log10") @_beartype.beartype def log10(g: jit_utils.GraphContext, self): _ln10 = 2.30258509299404568401 return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10]))) @_onnx_symbolic("aten::pow") @_beartype.beartype def pow(g: jit_utils.GraphContext, self, exponent): f_dtype = _type_utils.JitScalarType.from_value(self) if not symbolic_helper._is_fp(self): f_dtype = _type_utils.JitScalarType.FLOAT self = g.op("Cast", self, to_i=f_dtype.onnx_type()) if not symbolic_helper._is_fp(exponent): exponent = g.op( "Cast", exponent, to_i=f_dtype.onnx_type(), ) pow = g.op("Pow", self, exponent) return pow @_onnx_symbolic("aten::clamp") @_beartype.beartype def clamp(g: jit_utils.GraphContext, self, min, max): # min or max may be None that we need to dispatch to # Clip separately, as ONNX does not have None syntax if symbolic_helper._is_none(min): return clamp_max(g, self, max) elif symbolic_helper._is_none(max): return clamp_min(g, self, min) else: if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): return symbolic_helper._op_with_optional_float_cast( g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12, ) else: return clamp_max(g, clamp_min(g, self, min), max) @_onnx_symbolic("aten::clamp_min") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def clamp_min(g: jit_utils.GraphContext, self, min): if symbolic_helper._is_constant(min): return symbolic_helper._op_with_optional_float_cast( g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 ) else: dtype = _type_utils.JitScalarType.from_value(self) min = g.op("Cast", min, to_i=dtype.onnx_type()) return symbolic_helper._op_with_optional_float_cast( g, "Max", self, min, opset_before=12 ) @_onnx_symbolic("aten::clamp_max") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def clamp_max(g: jit_utils.GraphContext, self, max): if symbolic_helper._is_constant(max): return symbolic_helper._op_with_optional_float_cast( g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 ) else: dtype = _type_utils.JitScalarType.from_value(self) max = g.op("Cast", max, to_i=dtype.onnx_type()) return symbolic_helper._op_with_optional_float_cast( g, "Min", self, max, opset_before=12 ) @_onnx_symbolic("aten::max") # torch.max (same for torch.min) actually has two interfaces smashed together: # torch.max(x, dim, keepdim) and torch.max(x, y) # TODO(justinchuby): Support multiple quantized args in output @_beartype.beartype def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) @_onnx_symbolic("aten::maximum") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def maximum(g: jit_utils.GraphContext, input, other): return max(g, input, dim_or_y=other) @_onnx_symbolic("aten::min") # TODO(justinchuby): Support multiple quantized args in output @_beartype.beartype def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) @_onnx_symbolic("aten::minimum") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def minimum(g: jit_utils.GraphContext, input, other): return min(g, input, dim_or_y=other) @_onnx_symbolic("aten::amax") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "is", "i") @_beartype.beartype def amax(g: jit_utils.GraphContext, self, dim, keepdim): return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) @_onnx_symbolic("aten::amin") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "is", "i") @_beartype.beartype def amin(g: jit_utils.GraphContext, self, dim, keepdim): return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) @_onnx_symbolic("aten::aminmax") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): reduce_kwargs = {"keepdims_i": keepdim} if not symbolic_helper._is_none(dim): dim = symbolic_helper._get_const(dim, "i", "dim") reduce_kwargs["axes_i"] = [dim] return g.op("ReduceMin", self, **reduce_kwargs), g.op( "ReduceMax", self, **reduce_kwargs ) @_onnx_symbolic("aten::exp") @_beartype.beartype def exp(g: jit_utils.GraphContext, self): return g.op("Exp", self) @_onnx_symbolic("aten::dropout_") @_onnx_symbolic("aten::dropout") @symbolic_helper.parse_args("v", "f", "i") @_beartype.beartype def dropout(g: jit_utils.GraphContext, input, p, train): symbolic_helper.check_training_mode(train, "dropout") # if train is False, dropout is no-op if not train: return input r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) return r @_onnx_symbolic( "aten::alpha_dropout_", decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")], ) # See Note [Export inplace] @_onnx_symbolic( "aten::feature_alpha_dropout_", decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")], ) @_onnx_symbolic( "aten::feature_dropout_", decorate=[symbolic_helper._apply_params("aten::feature_dropout_")], ) @_onnx_symbolic( "aten::feature_alpha_dropout", decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")], ) @_onnx_symbolic( "aten::alpha_dropout", decorate=[symbolic_helper._apply_params("aten::alpha_dropout")], ) @_onnx_symbolic( "aten::feature_dropout", decorate=[symbolic_helper._apply_params("aten::feature_dropout")], ) @_beartype.beartype def _unsupported_dropout(name: str): @symbolic_helper.parse_args("v", "none", "b") @_beartype.beartype def feature_dropout(g, input, p, train): # NB: In inference mode, FeatureDropout is exported as an identity op. if train: return symbolic_helper._unimplemented(name, "training mode", input) return input return feature_dropout @_onnx_symbolic("aten::norm") @symbolic_helper.parse_args("v", "t", "is", "i", "v") @_beartype.beartype def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): if p == 1: f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1") elif p == 2: f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2") else: raise errors.SymbolicValueError( "ONNX export only p-norms with p of 1 or 2", self ) result = f(g, self, dim=dim, keepdim=keepdim) if dtype is not None: dtype = symbolic_helper._get_const(dtype, "i", "dtype") result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) return result @_onnx_symbolic("aten::conv_tbc") @symbolic_helper.parse_args("v", "v", "v", "i") @_beartype.beartype def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): if symbolic_helper.is_caffe2_aten_fallback(): return g.at("conv_tbc", input, weight, bias, pad_i=pad) else: # input must have 3 dimensions, see: # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 # input = (time, batch, in_channels) # weight = (kernel_width, in_channels, out_channels) # bias = (out_channels,) input = g.op("Transpose", input, perm_i=[1, 2, 0]) weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) return g.op("Transpose", conv, perm_i=[2, 0, 1]) @_onnx_symbolic("aten::_unique") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "_unique", input, sorted_i=sorted, return_inverse_i=return_inverse, outputs=2, ) else: return symbolic_helper._onnx_unsupported("_unique", input) @_onnx_symbolic("aten::_unique2") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "_unique2", input, sorted_i=sorted, return_inverse_i=return_inverse, return_counts_i=return_counts, outputs=3, ) symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) @_onnx_symbolic("aten::_cast_Byte") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) @_onnx_symbolic("aten::_cast_Char") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) @_onnx_symbolic("aten::_cast_Short") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) @_onnx_symbolic("aten::_cast_Int") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) @_onnx_symbolic("aten::_cast_Long") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) @_onnx_symbolic("aten::_cast_Half") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) @_onnx_symbolic("aten::_cast_Float") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) @_onnx_symbolic("aten::_cast_Double") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) @_onnx_symbolic("aten::_cast_Bool") @_deprecation.deprecated( "2.0", "the future", "Avoid using this function and create a Cast node instead", ) @_beartype.beartype def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) @_onnx_symbolic("aten::empty") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def empty( g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False, memory_format=None, ): return zeros(g, sizes, dtype, layout, device, pin_memory) @_onnx_symbolic("aten::empty_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def empty_like( g: jit_utils.GraphContext, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None, ): return zeros_like(g, input, dtype, layout, device, pin_memory) @_onnx_symbolic("aten::new_empty") @_beartype.beartype def new_empty( g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False ): self_dtype = symbolic_helper._try_get_scalar_type(self) if symbolic_helper._is_none(dtype) and self_dtype is not None: dtype = self_dtype return empty(g, sizes, dtype, layout, device, pin_memory) @_onnx_symbolic("aten::scalar_tensor") @_beartype.beartype def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") if dtype is None: dtype = _type_utils.JitScalarType.FLOAT scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) return scalar @_onnx_symbolic("aten::tensor") @_beartype.beartype def tensor( g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False ): dtype = symbolic_helper._get_const(dtype, "i", "dtype") if symbolic_helper._is_packed_list(data): if dtype is None: dtype = _type_utils.JitScalarType.from_value( symbolic_helper._unpack_list(data)[0] ) input_list = list() for t in symbolic_helper._unpack_list(data): shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) t = symbolic_helper._reshape_helper(g, t, shape_reference) t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) input_list.append(t) return g.op("Concat", *input_list, axis_i=0) else: if dtype is None: dtype = _type_utils.JitScalarType.from_value(data) if symbolic_helper._is_list(data) and ( symbolic_helper._is_tensor_list(data) or symbolic_helper._is_scalar_list(data) ): data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) @_onnx_symbolic("aten::as_tensor") @_beartype.beartype def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): return tensor(g, data, dtype, device) @_onnx_symbolic("aten::zeros") @symbolic_helper.parse_args("v", "i", "v", "v", "v") @_beartype.beartype def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it if dtype is None: scalar_type = _type_utils.JitScalarType.FLOAT else: scalar_type = _type_utils.JitScalarType(dtype) sizes_ = symbolic_helper._maybe_get_const(sizes, "is") if isinstance(sizes_, list) and len(sizes_) == 0: sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) return g.op( "ConstantOfShape", sizes, value_t=torch.tensor([0], dtype=scalar_type.dtype()), ) @_onnx_symbolic("aten::zeros_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def zeros_like( g: jit_utils.GraphContext, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None, ): shape = g.op("Shape", input) if symbolic_helper._is_none(dtype): scalar_type = _type_utils.JitScalarType.from_value( input, _type_utils.JitScalarType.FLOAT ) else: scalar_type = _type_utils.JitScalarType(dtype) return g.op( "ConstantOfShape", shape, value_t=torch.tensor([0], dtype=scalar_type.dtype()), ) @_onnx_symbolic("aten::new_zeros") @_beartype.beartype def new_zeros( g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False ): self_dtype = symbolic_helper._try_get_scalar_type(self) if symbolic_helper._is_none(dtype) and self_dtype is not None: dtype = self_dtype return zeros(g, sizes, dtype, layout, device, pin_memory) @_onnx_symbolic("aten::zero") @_beartype.beartype def zero(g: jit_utils.GraphContext, self): self_dtype = symbolic_helper._try_get_scalar_type(self) return zeros_like(g, self, self_dtype) @_onnx_symbolic("aten::ones") @symbolic_helper.parse_args("v", "i", "v", "v", "v") @_beartype.beartype def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): if dtype is None: scalar_type = _type_utils.JitScalarType.FLOAT else: scalar_type = _type_utils.JitScalarType(dtype) sizes_ = symbolic_helper._maybe_get_const(sizes, "is") if isinstance(sizes_, list) and len(sizes_) == 0: sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) return g.op( "ConstantOfShape", sizes, value_t=torch.tensor([1], dtype=scalar_type.dtype()), ) @_onnx_symbolic("aten::ones_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def ones_like( g: jit_utils.GraphContext, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None, ): shape = g.op("Shape", input) if symbolic_helper._is_none(dtype): scalar_type = _type_utils.JitScalarType.from_value( input, _type_utils.JitScalarType.FLOAT ) else: scalar_type = _type_utils.JitScalarType(dtype) return g.op( "ConstantOfShape", shape, value_t=torch.tensor([1], dtype=scalar_type.dtype()), ) @_onnx_symbolic("aten::new_ones") @_beartype.beartype def new_ones( g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False ): self_dtype = symbolic_helper._try_get_scalar_type(self) if symbolic_helper._is_none(dtype) and self_dtype is not None: dtype = self_dtype return ones(g, sizes, dtype, layout, device, pin_memory) @_onnx_symbolic("aten::full") @_beartype.beartype def full( g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False ): const_value = symbolic_helper._maybe_get_const(value, "t") if symbolic_helper._is_value(const_value): dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype tmp = zeros(g, sizes, dtype, layout, device) return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) else: dtype = symbolic_helper._get_const(dtype, "i", "dtype") if dtype is None: scalar_type = _type_utils.JitScalarType.FLOAT else: scalar_type = _type_utils.JitScalarType(dtype) sizes_ = symbolic_helper._maybe_get_const(sizes, "is") if isinstance(sizes_, list) and len(sizes_) == 0: sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) return g.op( "ConstantOfShape", sizes, value_t=const_value.view(1).to(scalar_type.dtype()), ) @_onnx_symbolic("aten::full_like") @_beartype.beartype def full_like( g: jit_utils.GraphContext, input, fill_value, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None, ): fill_value = symbolic_helper._maybe_get_const(fill_value, "f") dtype = symbolic_helper._get_const(dtype, "i", "dtype") if dtype is None: scalar_type = _type_utils.JitScalarType.from_value( input, _type_utils.JitScalarType.FLOAT ) else: scalar_type = _type_utils.JitScalarType(dtype) if symbolic_helper._is_value(fill_value): tmp = zeros_like(g, input, dtype, layout, device) fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) else: shape = g.op("Shape", input) return g.op( "ConstantOfShape", shape, value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), ) @_onnx_symbolic("aten::new_full") @_beartype.beartype def new_full( g: jit_utils.GraphContext, self, size, fill_value, dtype, layout, device, pin_memory=False, ): self_dtype = symbolic_helper._try_get_scalar_type(self) if symbolic_helper._is_none(dtype) and self_dtype is not None: dtype = self_dtype return full(g, size, fill_value, dtype, layout, device, pin_memory) @_onnx_symbolic("aten::eye") @_beartype.beartype def eye(g: jit_utils.GraphContext, *args): if len(args) == 5: # aten::eye(n, dtype, layout, device, pin_memory) n, dtype, layout, device, pin_memory = args dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) shape = g.op("Concat", dim_size, dim_size, axis_i=0) tensor = zeros(g, shape, dtype, layout, device) return g.op("EyeLike", tensor) if len(args) == 6: # aten::eye(n, m, dtype, layout, device, pin_memory) n, m, dtype, layout, device, pin_memory = args shape = g.op( "Concat", symbolic_helper._unsqueeze_helper(g, n, [0]), symbolic_helper._unsqueeze_helper(g, m, [0]), axis_i=0, ) tensor = zeros(g, shape, dtype, layout, device) return g.op("EyeLike", tensor) return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") @_onnx_symbolic("aten::slice") @_beartype.beartype def slice(g: jit_utils.GraphContext, self, *args): if len(args) == 4: # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor dim, start, end, step = args step = symbolic_helper._parse_arg(step, "i") if step != 1: raise errors.SymbolicValueError("step!=1 is currently not supported", self) is_start_none = start.node().kind() == "prim::Constant" and isinstance( start.type(), _C.NoneType ) is_end_none = end.node().kind() == "prim::Constant" and isinstance( end.type(), _C.NoneType ) is_start_onnx_const = start.node().kind() == "onnx::Constant" is_end_onnx_const = end.node().kind() == "onnx::Constant" if ( ((not is_start_none) and (not is_start_onnx_const)) or ((not is_end_none) and (not is_end_onnx_const)) or dim.node().kind() != "onnx::Constant" ): if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: raise errors.SymbolicValueError( "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " "is a deprecated experimental op. Please use statically allocated " "variables or export to a higher opset version.", self, ) else: start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) return g.op( "DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed, ) else: start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") end = ( _constants.INT64_MAX if is_end_none else symbolic_helper._parse_arg(end, "i") ) dim = symbolic_helper._parse_arg(dim, "i") return symbolic_helper._slice_helper( g, self, axes=[dim], starts=[start], ends=[end] ) elif len(args) == 3: # aten::slice(t[] l, int start, int end, int step) -> t[] start, end, step = args dim = 0 is_start_none = start.node().kind() == "prim::Constant" and isinstance( start.type(), _C.NoneType ) is_end_none = end.node().kind() == "prim::Constant" and isinstance( end.type(), _C.NoneType ) start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") end = ( _constants.INT64_MAX if is_end_none else symbolic_helper._parse_arg(end, "i") ) return symbolic_helper._slice_helper( g, self, axes=[dim], starts=[start], ends=[end] ) return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") @_onnx_symbolic("aten::hardtanh") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "f", "f") @_beartype.beartype def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): return symbolic_helper._op_with_optional_float_cast( g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 ) @_onnx_symbolic("aten::hardswish") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v") @_beartype.beartype def hardswish(g: jit_utils.GraphContext, self): hs = hardsigmoid(g, self) return g.op("Mul", self, hs) @_onnx_symbolic("aten::hardsigmoid") # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) @symbolic_helper.parse_args("v") @_beartype.beartype def hardsigmoid(g: jit_utils.GraphContext, self): # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html return g.op("HardSigmoid", self, alpha_f=1 / 6) @_onnx_symbolic("aten::tanhshrink") @symbolic_helper.parse_args("v") @_beartype.beartype def tanhshrink(g: jit_utils.GraphContext, self): return g.op("Sub", self, tanh(g, self)) @_onnx_symbolic("aten::hardshrink") @symbolic_helper.parse_args("v", "f") @_beartype.beartype def hardshrink(g: jit_utils.GraphContext, self, lambd): scalar_type = _type_utils.JitScalarType.from_value( self, _type_utils.JitScalarType.FLOAT ) lambd_op = g.op( "Constant", value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), ) cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) return g.op( "Where", cond, self, g.op( "Constant", value_t=torch.tensor(0, dtype=scalar_type.dtype()), ), ) @_onnx_symbolic("aten::softshrink") @symbolic_helper.parse_args("v", "f") @_beartype.beartype def softshrink(g: jit_utils.GraphContext, self, lambd): scalar_type = _type_utils.JitScalarType.from_value( self, _type_utils.JitScalarType.FLOAT ) lambd_op = g.op( "Constant", value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), ) gt_cond = gt(g, self, lambd_op) gt_out = g.op( "Where", gt_cond, sub(g, self, lambd_op), g.op( "Constant", value_t=torch.tensor(0, dtype=scalar_type.dtype()), ), ) lt_cond = lt(g, self, neg(g, lambd_op)) lt_out = g.op( "Where", lt_cond, add(g, self, lambd_op), g.op( "Constant", value_t=torch.tensor(0, dtype=scalar_type.dtype()), ), ) return add(g, gt_out, lt_out) @_onnx_symbolic("aten::alias") @_beartype.beartype def alias(g: jit_utils.GraphContext, self): return self @_onnx_symbolic("aten::unsqueeze") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def unsqueeze(g: jit_utils.GraphContext, self, dim): """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" # Handle negative dim if dim < 0: rank = symbolic_helper._get_tensor_rank(self) if rank is not None: warnings.warn( "ONNX export unsqueeze with negative axis " + str(dim) + " might cause the onnx model to be incorrect. " + "Negative axis is not supported in ONNX. " + "Axis is converted to " + str(dim + rank + 1) + " based on input shape at export time. " + "Passing an tensor of different rank in execution will be incorrect." ) dim = dim + rank + 1 else: return symbolic_helper._unimplemented( "unsqueeze", "negative axis with unknown input rank", self ) return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) @_onnx_symbolic("aten::sort") # TODO(justinchuby): Support multiple quantized args in output @symbolic_helper.parse_args("v", "i", "i", "none") @_beartype.beartype def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): if out is not None: symbolic_helper._unimplemented( "Sort", "Out parameter is not supported for sort", self ) self_sizes = symbolic_helper._get_tensor_sizes(self) try: dim_size = self_sizes[dim] except Exception: # FIXME(justinchuby): Avoid catching Exception. # Catch a more specific exception instead. dim_size = None if dim_size is None: return symbolic_helper._unimplemented("Sort", "input size not accessible", self) return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) @_onnx_symbolic("aten::numel") @_beartype.beartype def numel(g: jit_utils.GraphContext, self): return symbolic_helper._numel_helper(g, self) @_onnx_symbolic("aten::topk") # TODO(justinchuby): Support multiple quantized args in output @symbolic_helper.parse_args("v", "i", "i", "i", "i", "none") @_beartype.beartype def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): if out is not None: symbolic_helper._unimplemented( "TopK", "Out parameter is not supported for topk", self ) if not largest: symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self) return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) @_onnx_symbolic("prim::convert_element_type") @_beartype.beartype def convert_element_type(g: jit_utils.GraphContext, self, *args): dtype = symbolic_helper._get_const(args[0], "i", "dtype") return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) @_onnx_symbolic("aten::to") @_beartype.beartype def to(g: jit_utils.GraphContext, self, *args): @_beartype.beartype def is_aten_to_device_only(args): if len(args) == 4: # aten::to(Tensor, Device, bool, bool, memory_format) return ( args[0].node().kind() == "prim::device" or args[0].type().isSubtypeOf(_C.ListType.ofInts()) or isinstance(args[0].type(), _C.DeviceObjType) ) elif len(args) == 5: # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) # When dtype is None, this is a aten::to(device) call dtype = symbolic_helper._get_const(args[1], "i", "dtype") return dtype is None elif len(args) in (6, 7): # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor # When dtype is None, this is a aten::to(device) call dtype = symbolic_helper._get_const(args[0], "i", "dtype") return dtype is None return False # ONNX doesn't have a concept of a device, so we ignore device-only casts if is_aten_to_device_only(args): return self if len(args) == 4: # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() # In this case, the constant value is a tensor not int, # so symbolic_helper._maybe_get_const(args[0], 'i') would not work. dtype = args[0] if ( symbolic_helper._is_value(args[0]) and args[0].node().kind() == "onnx::Constant" ): tval = symbolic_helper._node_get(args[0].node(), "value") if isinstance(tval, torch.Tensor): if len(tval.shape) == 0: tval = tval.item() dtype = int(tval) else: dtype = tval if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): # aten::to(Tensor, Tensor, bool, bool, memory_format) dtype = _type_utils.JitScalarType.from_value(args[0]) return g.op( "Cast", self, to_i=dtype.onnx_type(), ) else: # aten::to(Tensor, ScalarType, bool, bool, memory_format) # memory_format is ignored return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) elif len(args) == 5: # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) dtype = symbolic_helper._get_const(args[1], "i", "dtype") # memory_format is ignored return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) elif len(args) == 6: # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor dtype = symbolic_helper._get_const(args[0], "i", "dtype") # Layout, device and memory_format are ignored return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) elif len(args) == 7: # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor dtype = symbolic_helper._get_const(args[0], "i", "dtype") # Layout, device and memory_format are ignored return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) @_onnx_symbolic("aten::repeat") @_beartype.beartype def repeat(g: jit_utils.GraphContext, self, repeats): dtype = _type_utils.JitScalarType.INT64 shape_ = ones_like(g, repeats, dtype) self = g.op("Expand", self, shape_) return g.op("Tile", self, repeats) @_onnx_symbolic("aten::repeat_interleave") @_beartype.beartype def repeat_interleave( g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None ): repeats_dim = symbolic_helper._get_tensor_rank(repeats) repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) input_sizes = symbolic_helper._get_tensor_sizes(self) if repeats_dim is None: raise errors.SymbolicValueError( "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", self, ) if repeats_sizes is None: raise errors.SymbolicValueError( "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", self, ) if input_sizes is None: raise errors.SymbolicValueError( "Unsupported: ONNX export of repeat_interleave for unknown input size.", self, ) # if dim is None flatten # By default, use the flattened input array, and return a flat output array if symbolic_helper._is_none(dim): self = symbolic_helper._reshape_helper( g, self, g.op("Constant", value_t=torch.tensor([-1])) ) dim = torch.tensor(0, dtype=torch.int64) else: dim = symbolic_helper._maybe_get_scalar(dim) # Handle cases where dim is negative if dim < 0: dim += len(input_sizes) input_sizes_temp = input_sizes.copy() for idx, input_size in enumerate(input_sizes): if input_size is None: input_sizes[idx], input_sizes_temp[idx] = 0, -1 # Cases where repeats is an int or single value tensor if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): if input_sizes[dim] == 0: return symbolic_helper._onnx_opset_unsupported_detailed( "repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size", self, ) return symbolic_helper._repeat_interleave_single_value_repeat_helper( g, self, repeats, dim ) # Cases where repeats is a 1 dim Tensor elif repeats_dim == 1: if input_sizes[dim] == 0: return symbolic_helper._onnx_opset_unsupported_detailed( "repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size", self, ) if repeats_sizes[0] is None: return symbolic_helper._onnx_opset_unsupported_detailed( "repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats", self, ) assert ( repeats_sizes[0] == input_sizes[dim] ), "repeats must have the same size as input along dim" reps = repeats_sizes[0] else: raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) final_splits = list() r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) input_sizes[dim], input_sizes_temp[dim] = -1, 1 for idx, r_split in enumerate(r_splits): i_split = unsqueeze(g, i_splits[idx], dim + 1) r_concat = [ g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), r_split, g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), ] r_concat = g.op("Concat", *r_concat, axis_i=0) i_split = expand(g, i_split, r_concat, None) i_split = symbolic_helper._reshape_helper( g, i_split, g.op("Constant", value_t=torch.LongTensor(input_sizes)), allowzero=0, ) final_splits.append(i_split) return g.op("Concat", *final_splits, axis_i=dim) @_onnx_symbolic("aten::pixel_shuffle") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): dims = symbolic_helper._get_tensor_sizes(self) if len(dims) != 4: return symbolic_helper._unimplemented( "pixel_shuffle", "only support 4d input", self ) if any(i is None for i in dims[1:]): after_view = symbolic_helper._reshape_helper( g, symbolic_helper._unsqueeze_helper(g, self, [2, 3]), g.op( "Constant", value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), ), allowzero=0, ) after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) # For dynamic input shapes, two reshapes are performed reshape_h = symbolic_helper._reshape_helper( g, after_transpose, g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), allowzero=0, ) reshape_w = symbolic_helper._reshape_helper( g, reshape_h, g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), allowzero=0, ) return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) else: output_channel = dims[1] // upscale_factor // upscale_factor after_view = symbolic_helper._reshape_helper( g, self, g.op( "Constant", value_t=torch.tensor( [ -1, output_channel, upscale_factor, upscale_factor, dims[2], dims[3], ] ), ), allowzero=0, ) after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) return symbolic_helper._reshape_helper( g, after_transpose, g.op( "Constant", value_t=torch.tensor( [ -1, output_channel, dims[2] * upscale_factor, dims[3] * upscale_factor, ] ), ), allowzero=0, ) @_onnx_symbolic("aten::pixel_unshuffle") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): dims = symbolic_helper._get_tensor_sizes(self) if len(dims) != 4: return symbolic_helper._unimplemented( "pixel_shuffle", "only support 4d input", self ) if any(i is None for i in dims[1:]): # For dynamic input shapes, two reshapes are performed reshape_h = symbolic_helper._reshape_helper( g, symbolic_helper._unsqueeze_helper(g, self, [3]), g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), allowzero=0, ) reshape_w = symbolic_helper._reshape_helper( g, reshape_h, g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), allowzero=0, ) after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) final_reshape = symbolic_helper._reshape_helper( g, after_transpose, g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), allowzero=0, ) return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) else: output_channel = dims[1] * downscale_factor * downscale_factor after_view = symbolic_helper._reshape_helper( g, self, g.op( "Constant", value_t=torch.tensor( [ -1, dims[1], dims[2] // downscale_factor, downscale_factor, dims[3] // downscale_factor, downscale_factor, ] ), ), allowzero=0, ) after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) return symbolic_helper._reshape_helper( g, after_transpose, g.op( "Constant", value_t=torch.tensor( [ -1, output_channel, dims[2] // downscale_factor, dims[3] // downscale_factor, ] ), ), allowzero=0, ) @_beartype.beartype def _generic_rnn( g: jit_utils.GraphContext, variant, input, initial_states, all_weights, has_biases, num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None, ): warnings.warn( "Exporting a model to ONNX with a batch_size other than 1, " + "with a variable length with " + variant + " can cause an error " + "when running the ONNX model with a different batch size. " + "Make sure to save the model with a batch size of 1, " + "or define the initial states (h0/c0) as inputs of the model. " ) onnxActivations = [ "Relu", "Tanh", "Sigmoid", "Affine", "LeakyRelu", "ThresholdedRelu", "ScaledTanh", "HardSigmoid", "Elu", "Softsign", "Softplus", ] variantToOnnxActivationMap = dict( zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) ) weights_per_layer = 4 if has_biases else 2 # this means that projections are used inside LSTM, so need to tell user that it's not supported if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( 1 + bidirectional ): return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) layer_weights = [ all_weights[i : i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer) ] if batch_first: # batch, seq, feat -> seq, batch, feat input = g.op("Transpose", input, perm_i=[1, 0, 2]) if dropout and train: return symbolic_helper._unimplemented( "RNN/GRU/LSTM", "dropout in training mode", input ) if variant.startswith("RNN"): nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] variant = "RNN" w_hh = all_weights[1] hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) if hidden_size is None: return symbolic_helper._unimplemented( "RNN/GRU/LSTM", "unknown hidden size", input ) unidirectional = not bidirectional prev_output = input h_outs = [] if variant == "RNN" or variant == "GRU": h0 = initial_states elif variant == "LSTM": h0, c0 = initial_states c_outs = [] sequence_lens = unused(g) if batch_sizes is None else batch_sizes if variant == "GRU": # pytorch is reset, input, hidden # onnx is input, reset, hidden reform_permutation = [(1, 2), (0, 1), (2, 3)] elif variant == "LSTM": # pytorch is input, forget, cell, output. # onnx is input, output, forget, cell. reform_permutation = [(0, 1), (3, 4), (1, 3)] @_beartype.beartype def reform_weights(g, w, n, intervals): slices = [ symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) for x, y in intervals ] return g.op("Concat", *slices, axis_i=0) @_beartype.beartype def transform_weights_no_bias(layer_index): weights = layer_weights[layer_index] if variant == "RNN": weight_ih, weight_hh = weights elif variant == "GRU" or variant == "LSTM": weight_ih, weight_hh = ( reform_weights(g, w, hidden_size, reform_permutation) for w in weights ) return tuple( symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] ) @_beartype.beartype def transform_weights(layer_index): weights = layer_weights[layer_index] if variant == "RNN": weight_ih, weight_hh, bias_ih, bias_hh = weights elif variant == "GRU" or variant == "LSTM": weight_ih, weight_hh, bias_ih, bias_hh = ( reform_weights(g, w, hidden_size, reform_permutation) for w in weights ) bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] return tuple( symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] ) @_beartype.beartype def retrieve_state(x, start, end): return ( x if num_layers == 1 else symbolic_helper._slice_helper( g, x, axes=[0], starts=[start], ends=[end] ) ) for i in range(num_layers): if unidirectional: if weights_per_layer == 4: weight_ih, weight_hh, bias_concat = transform_weights(i) else: weight_ih, weight_hh = transform_weights_no_bias(i) bias_concat = unused(g) state_indices = i, i + 1 else: if weights_per_layer == 4: weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) else: weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) bias_concat = unused(g) weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) state_indices = 2 * i, 2 * i + 2 inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] if variant == "LSTM": inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} if variant == "RNN": if bidirectional: activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] else: activation = [nonlinearity] # type: ignore[possibly-undefined] prev_output, h_out = g.op( "RNN", *inputs, outputs=2, hidden_size_i=hidden_size, activations_s=activation, **extra_kwargs, ) elif variant == "GRU": prev_output, h_out = g.op( "GRU", *inputs, outputs=2, hidden_size_i=hidden_size, linear_before_reset_i=1, **extra_kwargs, ) elif variant == "LSTM": prev_output, h_out, c_out = g.op( "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs ) if bidirectional: # The ONNX RNN/GRU/LSTM produce an output of dimensions # seq_len, num_directions, batch, hidden_size # We have to convert to match pytorch's expected # seq_len, batch, num_directions * hidden_size # by first moving num_directions before hidden_size with # Transpose, and then combining it with hidden_size # with Reshape. prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) prev_output = symbolic_helper._reshape_helper( g, prev_output, g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), allowzero=0, ) else: prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) h_outs.append(h_out) # type: ignore[possibly-undefined] if variant == "LSTM": c_outs.append(c_out) # type: ignore[possibly-undefined] if batch_first: # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] if variant == "RNN" or variant == "GRU": return prev_output, h_outs elif variant == "LSTM": c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] return prev_output, h_outs, c_outs @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") @_beartype.beartype def _lstm_full( g: jit_utils.GraphContext, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first, ): hidden, weight = symbolic_helper._unpack_list( hidden_v ), symbolic_helper._unpack_list(weight_v) return _generic_rnn( g, "LSTM", input, hidden, weight, has_biases, num_layers, dropout, train, bidirectional, batch_first, ) @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") @_beartype.beartype def _lstm_packed( g: jit_utils.GraphContext, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, ): hidden, weight = symbolic_helper._unpack_list( hidden_v ), symbolic_helper._unpack_list(weight_v) return _generic_rnn( g, "LSTM", input, hidden, weight, has_biases, num_layers, dropout, train, bidirectional, batch_sizes=batch_sizes, ) @_onnx_symbolic("aten::lstm") @_beartype.beartype def lstm(g: jit_utils.GraphContext, *args): if symbolic_helper._is_tensor_list(args[3]): return _lstm_packed(g, *args) else: return _lstm_full(g, *args) @_onnx_symbolic("aten::lstm_cell") @_beartype.beartype def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): input = symbolic_helper._unsqueeze_helper(g, self, [0]) hidden = symbolic_helper._unpack_list(hidden) hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] weight = ( (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) ) has_biases = True if symbolic_helper._is_tensor(b_ih) else False _, h_outs, c_outs = _generic_rnn( g, "LSTM", input, hidden, weight, has_biases, num_layers=1, dropout=0, train=0, bidirectional=False, batch_first=False, ) return symbolic_helper._squeeze_helper( g, h_outs, [0] ), symbolic_helper._squeeze_helper(g, c_outs, [0]) @_onnx_symbolic( "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")] ) @_onnx_symbolic( "aten::rnn_tanh", decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")], ) @_onnx_symbolic( "aten::rnn_relu", decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")], ) def _one_hidden_rnn(kind: str): @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") @_beartype.beartype def _rnn_full( g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first, ): weight = symbolic_helper._unpack_list(weight_v) return _generic_rnn( g, kind, input, hidden, weight, has_biases, num_layers, dropout, train, bidirectional, batch_first, ) @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") def _rnn_packed( g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, ): weight = symbolic_helper._unpack_list(weight_v) return _generic_rnn( g, kind, input, hidden, weight, has_biases, num_layers, dropout, train, bidirectional, batch_sizes=batch_sizes, ) def symbolic(g, *args): if symbolic_helper._is_tensor_list(args[3]): return _rnn_packed(g, *args) else: return _rnn_full(g, *args) return symbolic @_onnx_symbolic("aten::_dim_arange") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def _dim_arange(g: jit_utils.GraphContext, like, dim): like_shape = g.op("Shape", like) stop = g.op( "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 ) if symbolic_helper.is_caffe2_aten_fallback(): return g.op("_caffe2::Range", stop) else: # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) return arange(g, stop, 4, None, None, None) @_onnx_symbolic("aten::detach") @_beartype.beartype def detach(g: jit_utils.GraphContext, input): # Erase aten::detach nodes because ONNX is inference only return input @_onnx_symbolic("aten::contiguous") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def contiguous(g: jit_utils.GraphContext, input, memory_format): if memory_format > 2: # allower values are any, preserve and contiguous_format raise errors.SymbolicValueError( "onnx memory_format support is not implemented", input ) return input @_onnx_symbolic("aten::_pack_padded_sequence") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): # Currently there is no PackPadded operator in ONNX. We rely on an # optimization pass to remove this later. It is an error if all # PackPadded operators cannot be optimized out. if batch_first: input = g.op("Transpose", input, perm_i=[1, 0, 2]) if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): raise errors.SymbolicValueError( "'lengths' must be a Tensor for ONNX export", input ) # We know it's a TensorType so this check is now safe. # It's really only necessary because those operators expand to something that # only works with int32 types in Caffe2... if ( _type_utils.JitScalarType.from_value( lengths, _type_utils.JitScalarType.UNDEFINED ) != _type_utils.JitScalarType.INT ): lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) return g.op("prim::PackPadded", input, lengths, outputs=2) @_onnx_symbolic("aten::_pad_packed_sequence") @symbolic_helper.parse_args("v", "v", "i", "t", "v") @_beartype.beartype def _pad_packed_sequence( g: jit_utils.GraphContext, data, batch_sizes, batch_first, padding_value, total_length, ): # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence # It is only useful/used when training using data_parallel model, so # It shouldn't be relevant for ONNX anyway data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) if batch_first: data = g.op("Transpose", data, perm_i=[1, 0, 2]) return data, lengths @_onnx_symbolic("aten::randint") @_beartype.beartype def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") low_i = symbolic_helper._get_const(low, "i", "low") high_i = symbolic_helper._get_const(high, "i", "high") if dtype is None: scalar_type = _type_utils.JitScalarType.INT64 else: scalar_type = _type_utils.JitScalarType(dtype) if low_i is None: raise symbolic_helper._onnx_unsupported("randint", low) if high_i is None: raise symbolic_helper._onnx_unsupported("randint", high) shape = symbolic_helper._maybe_get_const(shapes, "is") if symbolic_helper._is_value(shape): shape_const = g.op( "ConstantOfShape", shapes, value_t=torch.tensor([0], dtype=torch.float), ) randn = g.op( "RandomUniformLike", shape_const, low_f=low_i, high_f=high_i, ) else: randn = g.op( "RandomUniform", shape_i=shape, low_f=low_i, high_f=high_i, ) # cast to integer type int_dtype = _type_utils.JitScalarType.INT64 randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) if int_dtype != scalar_type: randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) return randint @_onnx_symbolic("aten::randint_like") @_beartype.beartype def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") low_i = symbolic_helper._get_const(low, "i", "low") high_i = symbolic_helper._get_const(high, "i", "high") if dtype is None: scalar_type = _type_utils.JitScalarType.INT64 else: scalar_type = _type_utils.JitScalarType(dtype) if low_i is None: raise symbolic_helper._onnx_unsupported("randint", low) if high_i is None: raise symbolic_helper._onnx_unsupported("randint", high) randn = g.op( "RandomUniformLike", self, low_f=low_i, high_f=high_i, ) # cast to integer type int_dtype = _type_utils.JitScalarType.INT64 randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) if int_dtype != scalar_type: randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) return randint @_onnx_symbolic("aten::randn") @_beartype.beartype def randn(g: jit_utils.GraphContext, shapes, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") if dtype is None: scalar_type = _type_utils.JitScalarType.FLOAT else: scalar_type = _type_utils.JitScalarType(dtype) shape = symbolic_helper._maybe_get_const(shapes, "is") if symbolic_helper._is_value(shape): shape_const = g.op( "ConstantOfShape", shapes, value_t=torch.tensor([0], dtype=torch.float), ) return g.op( "RandomNormalLike", shape_const, dtype_i=scalar_type.onnx_type(), ) return g.op( "RandomNormal", shape_i=shape, dtype_i=scalar_type.onnx_type(), ) @_onnx_symbolic("aten::rand") @_beartype.beartype def rand(g: jit_utils.GraphContext, shapes, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") if dtype is None: scalar_type = _type_utils.JitScalarType.FLOAT else: scalar_type = _type_utils.JitScalarType(dtype) shape = symbolic_helper._maybe_get_const(shapes, "is") if symbolic_helper._is_value(shape): shape_const = g.op( "ConstantOfShape", shapes, value_t=torch.tensor([0], dtype=torch.float), ) return g.op( "RandomUniformLike", shape_const, dtype_i=scalar_type.onnx_type(), ) return g.op( "RandomUniform", shape_i=shape, dtype_i=scalar_type.onnx_type(), ) @_onnx_symbolic("aten::randn_like") @_beartype.beartype def randn_like( g: jit_utils.GraphContext, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None, ): dtype = symbolic_helper._get_const(dtype, "i", "dtype") if dtype is None: scalar_type = _type_utils.JitScalarType.from_value( self, _type_utils.JitScalarType.FLOAT ) else: scalar_type = _type_utils.JitScalarType(dtype) return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) @_onnx_symbolic("aten::rand_like") @_beartype.beartype def rand_like( g: jit_utils.GraphContext, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None, ): dtype = symbolic_helper._get_const(dtype, "i", "dtype") if dtype is None: dtype = _type_utils.JitScalarType.from_value( self, _type_utils.JitScalarType.FLOAT ) return g.op( "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() ) @_onnx_symbolic("aten::rrelu") @symbolic_helper.parse_args("v", "f", "f", "i", "none") @_beartype.beartype def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): if not training: slope = (upper + lower) / 2.0 return g.op("LeakyRelu", input, alpha_f=slope) p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) return g.op("PRelu", input, p) @_onnx_symbolic("aten::bernoulli") @_beartype.beartype def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): if out is not None and not symbolic_helper._is_none(out): symbolic_helper._unimplemented( "Bernoulli", "out parameter is not supported for bernoulli", input ) if generator is not None and not symbolic_helper._is_none(generator): symbolic_helper._unimplemented( "Bernoulli", "generator is not supported for bernoulli", input ) dtype = _type_utils.JitScalarType.from_value( input, _type_utils.JitScalarType.UNDEFINED ) if dtype == _type_utils.JitScalarType.UNDEFINED: return symbolic_helper._unimplemented( "Bernoulli", "input dtype not accessible", input ) rands = g.op( "RandomUniformLike", input, high_f=1.0, low_f=0.0, dtype_i=dtype.onnx_type(), ) prob = p if p is not None and not symbolic_helper._is_none(p) else input output = g.op("Less", rands, prob) return g.op("Cast", output, to_i=dtype.onnx_type()) @_onnx_symbolic("aten::log_sigmoid") @symbolic_helper.parse_args("v") @_beartype.beartype def log_sigmoid(g: jit_utils.GraphContext, input): p = g.op("Sigmoid", input) return g.op("Log", p) @_onnx_symbolic("aten::erf") @symbolic_helper.parse_args("v") @_beartype.beartype def erf(g: jit_utils.GraphContext, input): return g.op("Erf", input) @_onnx_symbolic("aten::flatten") @symbolic_helper.quantized_args(True, False, False) @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): dim = symbolic_helper._get_tensor_rank(input) if dim is None: return symbolic_helper._unimplemented( "dim", "ONNX and PyTorch use different strategies to split the input. " "Input rank must be known at export time.", input, ) if dim == 0: return symbolic_helper._reshape_helper(g, input, [1]) if dim == 1: return g.op("Identity", input) # TODO: remove this as onnx opset 11 spec allows negative axes if end_dim < 0: end_dim = dim + end_dim # use ONNX's Flatten operator for cases where the output shape is 2D if start_dim == 1 and end_dim == dim - 1: return g.op("Flatten", input, axis_i=start_dim) if start_dim == 0 and end_dim == dim - 2: return g.op("Flatten", input, axis_i=end_dim + 1) return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) @_onnx_symbolic("aten::nonzero") @symbolic_helper.parse_args("v") @_beartype.beartype def nonzero(g: jit_utils.GraphContext, input): """Emitted from `torch.nonzero(x, as_tuple=False)`""" return t(g, g.op("NonZero", input)) @_onnx_symbolic("aten::nonzero_numpy") # Emitted from `torch.nonzero(x, as_tuple=True)` @_beartype.beartype def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): return unbind(g, nonzero(g, input), 1, _outputs=_outputs) @_onnx_symbolic("aten::isnan") @symbolic_helper.parse_args("v") @_beartype.beartype def isnan(g: jit_utils.GraphContext, input): output = g.op("IsNaN", input) return output @_onnx_symbolic("aten::any") @_beartype.beartype def _any(g: jit_utils.GraphContext, *args): # aten::any(Tensor self) if len(args) == 1: input = args[0] dim, keepdim = None, 0 # aten::any(Tensor self, int[]? dim, bool keepdim) else: input, dim, keepdim = args # Can be int list or single int dim = symbolic_helper._parse_arg(dim, "t") dim = [int(d) for d in dim.view(-1)] keepdim = symbolic_helper._parse_arg(keepdim, "i") input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) input_sum = symbolic_helper._reducesum_helper( g, input, axes_i=dim, keepdims_i=keepdim ) return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) @_onnx_symbolic("aten::all") @_beartype.beartype def _all(g: jit_utils.GraphContext, *args): input = g.op("Not", args[0]) # aten::all(Tensor self) if len(args) == 1: return g.op("Not", _any(g, input)) # aten::all(Tensor self, int[]? dim, bool keepdim) else: return g.op("Not", _any(g, input, args[1], args[2])) @_onnx_symbolic("aten::narrow") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def narrow(g: jit_utils.GraphContext, input, dim, start, length): return symbolic_helper._slice_helper( g, input, axes=[dim], starts=[start], ends=[start + length] ) @_onnx_symbolic("aten::argmax") @symbolic_helper.parse_args("v", "v", "b") @_beartype.beartype def argmax( g: jit_utils.GraphContext, input: torch._C.Value, dim: torch._C.Value, keepdim: bool, ): return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") @_onnx_symbolic("aten::argmin") @symbolic_helper.parse_args("v", "v", "b") @_beartype.beartype def argmin( g: jit_utils.GraphContext, input: torch._C.Value, dim: torch._C.Value, keepdim: bool, ): return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") @_onnx_symbolic("aten::scatter") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter(g: jit_utils.GraphContext, self, dim, index, src): src_type = _type_utils.JitScalarType.from_value( src, _type_utils.JitScalarType.UNDEFINED ) src = symbolic_helper._maybe_get_scalar(src) if symbolic_helper._is_value(src): return g.op("Scatter", self, index, src, axis_i=dim) else: # Check if scalar "src" has same type as self (PyTorch allows different # type for scalar src (but not when src is tensor)). If not, insert Cast node. self_scalar_type = _type_utils.JitScalarType.from_value(self) if self_scalar_type != src_type: src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) @_onnx_symbolic("aten::scatter_add") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): scalar_type = symbolic_helper._try_get_scalar_type(self) if scalar_type is None: return symbolic_helper._unimplemented( "scatter_add", "input dtype not accessible", self ) sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) if sizes: to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) else: to_add = zeros_like(g, self, scalar_type) to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) return add(g, self, to_add) @_onnx_symbolic("aten::log2") @_beartype.beartype def log2(g: jit_utils.GraphContext, self): _ln2 = 0.693147180559945309 return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) @_onnx_symbolic("aten::is_floating_point") @_beartype.beartype def is_floating_point(g: jit_utils.GraphContext, self): if symbolic_helper._is_fp(self): return g.op("Constant", value_t=torch.BoolTensor([1])) return g.op("Constant", value_t=torch.BoolTensor([0])) @_onnx_symbolic("aten::__is_") @_beartype.beartype def __is_(g: jit_utils.GraphContext, self, other): if symbolic_helper._is_none(other): if symbolic_helper._is_none(self): return g.op("Constant", value_t=torch.BoolTensor([1])) return g.op("Constant", value_t=torch.BoolTensor([0])) return eq(g, self, other) @_onnx_symbolic("aten::__isnot_") @wrap_logical_op_with_negation @_beartype.beartype def __isnot_(g: jit_utils.GraphContext, self, other): return __is_(g, self, other) @_onnx_symbolic("aten::one_hot") @_beartype.beartype def one_hot(g: jit_utils.GraphContext, self, num_classes): values = g.op("Constant", value_t=torch.LongTensor([0, 1])) # onnxruntime supports limited type combinations for OneHot. if _type_utils.JitScalarType.from_value( num_classes, _type_utils.JitScalarType.UNDEFINED ) in { _type_utils.JitScalarType.UINT8, _type_utils.JitScalarType.INT8, _type_utils.JitScalarType.INT, _type_utils.JitScalarType.INT16, }: num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) return g.op("OneHot", self, num_classes, values, axis_i=-1) @_onnx_symbolic("aten::gather") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): if symbolic_helper._maybe_get_const(sparse_grad, "i"): return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) # NOTE: This workaround is needed since GatherElement is only supported # since opset 11, and Gather in ONNX is not the same as torch.gather. scalar_type = _type_utils.JitScalarType.from_value(self) values = g.op("Constant", value_t=torch.LongTensor([0, 1])) depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) index = g.op( "Cast", g.op("OneHot", index, depth, values, axis_i=dim), to_i=scalar_type.onnx_type(), ) mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) @symbolic_helper.parse_args("v", "is", "i", "i") @_beartype.beartype def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim) @_onnx_symbolic("aten::std") @_beartype.beartype def std(g: jit_utils.GraphContext, input, *args): var, _ = var_mean(g, input, *args) return g.op("Sqrt", var) @_onnx_symbolic("aten::var") @_beartype.beartype def var(g: jit_utils.GraphContext, input, *args): var, _ = var_mean(g, input, *args) return var @_onnx_symbolic("aten::var_mean") @_beartype.beartype def var_mean(g: jit_utils.GraphContext, input, *args): if len(args) == 1: return _var_mean(g, input, None, args[0], None) else: return _var_mean(g, input, *args) @_onnx_symbolic("aten::std_mean") @_beartype.beartype def std_mean(g: jit_utils.GraphContext, input, *args): var, mean = var_mean(g, input, *args) return g.op("Sqrt", var), mean @_onnx_symbolic("aten::logsumexp") @symbolic_helper.parse_args("v", "is", "i") @_beartype.beartype def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) @_onnx_symbolic("aten::arange") @_beartype.beartype def arange(g: jit_utils.GraphContext, *args): if symbolic_helper.is_caffe2_aten_fallback(): return g.at("arange", *args) @_beartype.beartype def _get_arange_dtype(dtype): dtype = symbolic_helper._maybe_get_const(dtype, "i") return dtype @_beartype.beartype def _float_step_convert(range_tensor): if symbolic_helper._is_fp(range_tensor): range_tensor = g.op( "Cast", g.op("Ceil", range_tensor), to_i=_type_utils.JitScalarType.INT64.onnx_type(), ) return range_tensor if len(args) == 2 or len(args) == 5: if len(args) == 2: # aten::arange(Scalar end, Tensor out) dtype = None else: # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[1]) dtype, end, start, step = symbolic_helper._arange_cast_helper( g, end=args[0], dtype=dtype ) end = symbolic_helper._unsqueeze_helper(g, end, [0]) range_tensor = _float_step_convert(end) arange_tensor = symbolic_helper._squeeze_helper( g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] ) return g.op( "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() ) elif len(args) == 4 or len(args) == 7: if len(args) == 4: # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) dtype = None else: # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[3]) dtype, end, start, step = symbolic_helper._arange_cast_helper( g, start=args[0], end=args[1], step=args[2], dtype=dtype ) step = symbolic_helper._unsqueeze_helper(g, step, [0]) end = symbolic_helper._unsqueeze_helper(g, end, [0]) start = symbolic_helper._unsqueeze_helper(g, start, [0]) range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) arange_tensor = symbolic_helper._squeeze_helper( g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] ) arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) return g.op( "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() ) elif len(args) == 6: # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[2]) dtype, end, start, step = symbolic_helper._arange_cast_helper( g, start=args[0], end=args[1], dtype=dtype ) end = symbolic_helper._unsqueeze_helper(g, end, [0]) start = symbolic_helper._unsqueeze_helper(g, start, [0]) range_tensor = _float_step_convert(g.op("Sub", end, start)) arange_tensor = g.op( "Add", symbolic_helper._squeeze_helper( g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] ), start, ) return g.op( "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() ) return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") @_onnx_symbolic("aten::linspace") @_beartype.beartype def linspace( g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory ): range_tensor = symbolic_helper._arange_helper(g, steps, None) step = div( g, sub(g, end, start), sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), ) return add(g, mul(g, range_tensor, step), start) @_onnx_symbolic("aten::lift") @_beartype.beartype def lift(g: jit_utils.GraphContext, self): # at::lift() is a no-op from the perspective of tracing for onnx return self @_onnx_symbolic("aten::masked_fill") @_beartype.beartype def masked_fill(g: jit_utils.GraphContext, self, mask, value): """Implement the masked_fill functionality available for a pytorch tensor in ONNX. Fills elements of the input tensor with `value` where `mask` is True. """ mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) value = symbolic_helper._maybe_get_scalar(value) return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) @_onnx_symbolic("aten::masked_fill_") @_beartype.beartype def masked_fill_(g: jit_utils.GraphContext, self, mask, value): return masked_fill(g, self, mask, value) @_onnx_symbolic("aten::index") @_beartype.beartype def index(g: jit_utils.GraphContext, self, index): if symbolic_helper.is_caffe2_aten_fallback(): return g.at("index", self, index, overload_name="Tensor") if symbolic_helper._is_packed_list(index): indices = symbolic_helper._unpack_list(index) else: indices = [index] @_beartype.beartype def try_mask_to_index(index): if not symbolic_helper._is_none(index) and ( _type_utils.JitScalarType.from_value( index, _type_utils.JitScalarType.UNDEFINED ) == _type_utils.JitScalarType.UINT8 or symbolic_helper._is_bool(index) ): if g.opset < 9: raise errors.SymbolicValueError( "Exporting masked indices are only supported after ONNX opset 9.", self, ) warnings.warn( "Exporting aten::index operator with indices of type Byte. " "Only 1-D indices are supported. In any other case, " "this will produce an incorrect ONNX graph." ) index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) return index indices = [try_mask_to_index(idx) for idx in indices] if len(indices) == 1: return symbolic_helper._select_helper( g, self, 0, indices[0], apply_reshape=False ) else: # Multiple tensors as indices. Each tensor could either be # 1. prim::Constant() # representing ":" in python indexing. E.g. tensor[:, :] # 2. prim::Constant[value=...] or tensor output # representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. # For more info on advanced indexing, # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing # Consider a general case of # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". # Same results can be achieved through transposing t into # t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t # and process the tensor indices. # t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) # After gather, reshape and transpose back. adv_idx_indices = [ i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) ] if len(adv_idx_indices) == 0: return self elif len(adv_idx_indices) == 1: return index_select( g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] ) else: rank = symbolic_helper._get_tensor_rank(self) if rank is None: return symbolic_helper._unimplemented( "aten::index", "operator of advanced indexing on tensor of unknown rank. " "Try turning on shape inference during export: " "torch.onnx._export(..., onnx_shape_inference=True).", self, ) # TODO: If indexing is supported natively in ONNX in future opsets, # update the warning to recommend exporting with higher opset version. warnings.warn( "Exporting aten::index operator of advanced indexing in opset " f"{GLOBALS.export_onnx_opset_version}" " is achieved by combination of multiple ONNX operators, " "including Reshape, Transpose, Concat, and Gather. " "If indices include negative values, the exported graph will produce incorrect results." ) adv_idx_count = len(adv_idx_indices) shape_tensor = _shape_as_tensor(g, self) dim_tensor_list = [ g.op( "Gather", shape_tensor, g.op("Constant", value_t=torch.LongTensor([dim])), axis_i=0, ) for dim in range(rank) ] self = g.op( "Transpose", self, perm_i=adv_idx_indices + [i for i in range(rank) if i not in adv_idx_indices], ) self = g.op("Flatten", self, axis_i=adv_idx_count) # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. cum_adv_index = indices[adv_idx_indices[-1]] multiplier = dim_tensor_list[adv_idx_indices[-1]] for i in range(adv_idx_count - 2, -1, -1): adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) cum_adv_index = g.op("Add", cum_adv_index, adv_index) multiplier = g.op( "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] ) # perform gather self = index_select(g, self, 0, cum_adv_index) cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) # check if all advanced indices are consecutive. # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing # to understand how the subarray position is decided. if adv_idx_indices == list( range(adv_idx_indices[0], adv_idx_indices[-1] + 1) ): # unfold regular index axes folded_adv_idx_shape_list = [ g.op("Constant", value_t=torch.LongTensor([-1])) ] + [ dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices ] folded_adv_idx_shape = g.op( "Concat", *folded_adv_idx_shape_list, axis_i=0 ) self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) # Transpose folded advanced indexed axis to its original location. adv_idx_permute = ( list(range(1, adv_idx_indices[0] + 1)) + [0] + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1)) ) self = g.op("Transpose", self, perm_i=adv_idx_permute) # unfold advanced index axes final_shape_list = ( [dim_tensor_list[i] for i in range(adv_idx_indices[0])] + [cum_adv_index_shape_tensor] + [ dim_tensor_list[i] for i in range(adv_idx_indices[0], rank) if i not in adv_idx_indices ] ) final_shape = g.op("Concat", *final_shape_list, axis_i=0) else: final_shape = g.op( "Concat", cum_adv_index_shape_tensor, *[ dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices ], axis_i=0, ) return symbolic_helper._reshape_helper(g, self, final_shape) @_onnx_symbolic("aten::linalg_norm") @symbolic_helper.parse_args("v", "v", "is", "b", "v") @_beartype.beartype def linalg_norm( g: jit_utils.GraphContext, self: torch._C.Value, ord: torch._C.Value, dim: Optional[Sequence[int]], keepdim: bool, dtype: torch._C.Value, ): # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html ord_value = None if dim is None: if symbolic_helper._is_none(ord): self = symbolic_helper._reshape_helper(g, self, [-1]) ord = g.op("Constant", value_t=torch.LongTensor([2])) self_dim = symbolic_helper._get_tensor_rank(self) if self_dim is None: return symbolic_helper._unimplemented( "dim", "Input rank must be known at export time.", self ) if self_dim == 1: ord_value = symbolic_helper._parse_arg(ord, "f") else: dim = [0, 1] else: if len(dim) == 1: if symbolic_helper._is_none(ord): ord = g.op("Constant", value_t=torch.LongTensor([2])) ord_value = symbolic_helper._parse_arg(ord, "f") if ord_value: return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype) return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) @_onnx_symbolic("aten::linalg_vector_norm") @symbolic_helper.parse_args("v", "f", "is", "b", "v") @_beartype.beartype def linalg_vector_norm( g: jit_utils.GraphContext, self: torch._C.Value, ord: float, dim: Optional[Sequence[int]], keepdim: bool, dtype: torch._C.Value, ): return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) @_onnx_symbolic("aten::linalg_matrix_norm") @symbolic_helper.parse_args("v", "v", "is", "b", "v") @_beartype.beartype def linalg_matrix_norm( g: jit_utils.GraphContext, self: torch._C.Value, ord: torch._C.Value, dim: List[int], keepdim: bool, dtype: torch._C.Value, ): # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html ord_value = symbolic_helper._parse_arg(ord, "s") if ord_value == "fro": return frobenius_norm(g, self, dim, keepdim) elif ord_value == "nuc": return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self) else: ord_value = symbolic_helper._parse_arg(ord, "f") if ord_value is None: return frobenius_norm(g, self, dim, keepdim) if ord_value == 2 or ord_value == -2: # ord = 2/-2 unimplemented due to lack of operators # used to calculate singular values return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self) # Wrap the dim vector to handle negative dim values self_dim = symbolic_helper._get_tensor_rank(self) if self_dim is None: return symbolic_helper._unimplemented( "linalg.matrix_norm", "Input rank must be known at export time.", self ) # Common implementation for cases with # ord = 1/-1 and ord = inf/-inf if dim[0] < 0: dim[0] += self_dim if dim[1] < 0: dim[1] += self_dim if ord_value == math.inf or ord_value == -math.inf: dim[0], dim[1] = dim[1], dim[0] if dim[1] > dim[0] and not keepdim: dim[1] -= 1 sum = symbolic_helper._reducesum_helper( g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim ) if ord_value > 0: result, indices = max( g, sum, dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), keepdim=keepdim, ) else: result, indices = min( g, sum, dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), keepdim=keepdim, ) return result @_onnx_symbolic("aten::linalg_cross") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1): return cross(g, input, other, dim) @_onnx_symbolic("aten::frobenius_norm") @symbolic_helper.parse_args("v", "is", "b") @_beartype.beartype def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): sqr = g.op("Mul", self, self) sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) return g.op("Sqrt", sumsqr) @_onnx_symbolic("aten::multinomial") @symbolic_helper.parse_args("v", "i", "b", "v") @_beartype.beartype def multinomial( g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None ): if generator is not None and not symbolic_helper._is_none(generator): symbolic_helper._unimplemented( "Multinomial", "generator is not supported for multinomial", input ) if not replacement and num_samples > 1: symbolic_helper._unimplemented( "Multinomial", "replacement=False when num_samples > 1 is not supported for multinomial", input, ) log_input = log(g, input) return g.op( "Multinomial", log_input, dtype_i=_C_onnx.TensorProtoDataType.INT64, sample_size_i=num_samples, ) @_onnx_symbolic("aten::baddbmm") @_beartype.beartype def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha): scalar_type = _type_utils.JitScalarType.from_value(self) batch_mul = matmul(g, batch1, batch2) mul_a = mul( g, batch_mul, g.op("Cast", alpha, to_i=scalar_type.onnx_type()), ) mul_b = mul( g, self, g.op("Cast", beta, to_i=scalar_type.onnx_type()), ) return add(g, mul_a, mul_b) @_onnx_symbolic("aten::meshgrid") @symbolic_helper.parse_args("v", "s") @_beartype.beartype def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: Optional[str] = None): if indexing is None: indexing = "ij" elif indexing not in {"ij", "xy"}: raise errors.SymbolicValueError( f"Unsupported indexing: {indexing}", tensor_list ) unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) if indexing == "xy": unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] tensors = [ symbolic_helper._reshape_helper( g, t, g.op("Constant", value_t=torch.LongTensor([-1])) ) for t in unpacked_tensor_list ] tensors_shape = [g.op("Shape", t) for t in tensors] out_shape = g.op("Concat", *tensors_shape, axis_i=0) out = [] for i, t in enumerate(tensors): shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len( tensors ) shape_i[i] = tensors_shape[i] t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0)) out.append(g.op("Expand", t_reshaped, out_shape)) if indexing == "xy": out[0], out[1] = out[1], out[0] return g.op("prim::ListConstruct", *out) @_onnx_symbolic("aten::remainder") @_beartype.beartype def remainder(g: jit_utils.GraphContext, input, other): div = _floor_divide(g, input, other) quo = g.op("Mul", div, other) return g.op("Sub", input, quo) @_onnx_symbolic("aten::gelu") @symbolic_helper.parse_args("v", "s") @_beartype.beartype def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"): if approximate == "tanh": kBeta = math.sqrt(2 / math.pi) kKappa = 0.044715 beta = torch.tensor(kBeta, dtype=torch.double) kappa = torch.tensor(kKappa, dtype=torch.double) one = torch.tensor(1.0, dtype=torch.double) half = torch.tensor(0.5, dtype=torch.double) self_cube = mul(g, self, mul(g, self, self)) inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube))) return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner)))) else: _sqrt2 = 1.4142135623730951 erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double))) erf_plusone = add( g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) ) return mul( g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), ) @_onnx_symbolic("aten::group_norm") @symbolic_helper.quantized_args(True, False, False, False) @symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") @_beartype.beartype def group_norm( g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled ): if symbolic_helper.is_caffe2_aten_fallback(): return g.at( "group_norm", input, weight, bias, num_groups_i=num_groups, eps_f=eps, cudnn_enabled_i=cudnn_enabled, ) channel_size = symbolic_helper._get_tensor_dim_size(input, 1) if channel_size is not None: assert channel_size % num_groups == 0 input_rank = symbolic_helper._get_tensor_rank(input) if input_rank is None: return symbolic_helper._unimplemented("group_norm", "unknown input rank", input) # 0 in the shape list keeps dimension value unchanged. shape = [0, num_groups, -1] input_reshaped = symbolic_helper._reshape_helper( g, input, g.op("Constant", value_t=torch.LongTensor(shape)) ) # C is always divisible by num_groups # Due to shape difference. we need to apply weight and bias after # instance norm computation and reshape weight_ = g.op( "Constant", value_t=torch.tensor( [1.0] * num_groups, dtype=_type_utils.JitScalarType.from_value(input).dtype(), ), ) bias_ = g.op( "Constant", value_t=torch.tensor( [0.0] * num_groups, dtype=_type_utils.JitScalarType.from_value(input).dtype(), ), ) norm_reshaped = g.op( "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps ) norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) if weight is None or weight.node().mustBeNone(): weight_value = torch.tensor( [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() ) weight = g.op("Constant", value_t=weight_value) if bias is None or bias.node().mustBeNone(): bias_value = torch.tensor( [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() ) bias = g.op("Constant", value_t=bias_value) # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] axes = list(range(1, input_rank - 1)) return add( g, mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), symbolic_helper._unsqueeze_helper(g, bias, axes), ) @_onnx_symbolic("aten::_weight_norm") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): rank = symbolic_helper._get_tensor_rank(weight_v) if rank is not None: # W = g * ((v) / ||v||) # Compute norm_except_dim for l2 norm. dim = None means over all dims # torch's weight_norm module sets dim = -1 if it's None. # This conflicts the logic for negative axes to access dims backwards # TODO: Might need a fix in torch group_norm module axes = list(range(rank)) if dim is not None: if dim < -1: dim += rank if dim != -1: axes.remove(dim) norm_v = norm(g, weight_v, 2, axes, 1) div = g.op("Div", weight_v, norm_v) return g.op("Mul", div, weight_g) if symbolic_helper.is_caffe2_aten_fallback(): return g.at("_weight_norm", weight_v, weight_g, dim_i=dim) raise errors.SymbolicValueError( "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", weight_v, ) @_onnx_symbolic("aten::dim") @_beartype.beartype def dim(g: jit_utils.GraphContext, self): """Implement the dim functionality available for a pytorch tensor in ONNX""" # ONNX does not support dim directly in this opset so we can use 2 ops to get the info shape = g.op("Shape", self) return g.op("Size", shape) @_onnx_symbolic("aten::__contains_") @_beartype.beartype def __contains_(g: jit_utils.GraphContext, self, element): unpacked_list = symbolic_helper._unpack_list(self) if all( symbolic_helper._is_constant(x) for x in unpacked_list ) and symbolic_helper._is_constant(element): return g.op( "Constant", value_t=torch.tensor( symbolic_helper._node_get(element.node(), "value") in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list) ), ) raise errors.SymbolicValueError( "Unsupported: ONNX export of __contains__ for non-constant list or element.", self, ) @_onnx_symbolic("aten::__getitem_") @_beartype.beartype def __getitem_(g: jit_utils.GraphContext, self, i): return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) @_onnx_symbolic("aten::item") @_beartype.beartype def item(g: jit_utils.GraphContext, self): return self @_onnx_symbolic("aten::take") @_beartype.beartype def take(g: jit_utils.GraphContext, self, index): self_flattened = symbolic_helper._reshape_helper( g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) ) out = index_select(g, self_flattened, 0, index) out = reshape_as(g, out, index) return out @_beartype.beartype def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target): diff_ = sub(g, target, input) exp_ = exp(g, target) output = mul(g, exp_, diff_) return output @_beartype.beartype def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target): log_ = log(g, target) diff_ = sub(g, log_, input) output_pos = mul(g, target, diff_) zeros_ = zeros_like(g, output_pos) mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) output = where(g, mask_, output_pos, zeros_) return output @_onnx_symbolic("aten::kl_div") @symbolic_helper.parse_args("v", "v", "i", "b") @_beartype.beartype def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target): if log_target: output = _kl_div_log_target_impl(g, input, target) else: output = _kl_div_non_log_target_impl(g, input, target) if reduction == 0: return output elif reduction == 1: return g.op("ReduceMean", output, keepdims_i=0) elif reduction == 2: return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) else: return symbolic_helper._onnx_unsupported( "kl_div with reduction other than none, mean, or sum.", input ) @_onnx_symbolic("aten::mse_loss") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def mse_loss(g: jit_utils.GraphContext, input, target, reduction): output = mul(g, sub(g, input, target), sub(g, input, target)) if reduction == 0: return output elif reduction == 1: return g.op("ReduceMean", output, keepdims_i=0) elif reduction == 2: return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) else: return symbolic_helper._onnx_unsupported( "mse_loss with reduction other than none, mean, or sum.", input ) @_onnx_symbolic("aten::as_strided") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "is", "i") @_beartype.beartype def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None): sizes = symbolic_helper._maybe_get_const(sizes, "is") rank = len(strides) self_1d = symbolic_helper._reshape_helper( g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) ) ind: Optional[torch.Tensor] if not symbolic_helper._is_value(sizes): ind = torch.tensor([0], dtype=torch.long) for i, (size, stride) in enumerate(zip(sizes, strides)): r_size = [1] * rank r_size[i] = -1 ind = ind + torch.arange(size).view(r_size) * stride if offset: ind = ind + offset return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) else: ind = None for i, stride in enumerate(strides): r_size = [1] * rank r_size[i] = -1 size = select( g, sizes, g.op("Constant", value_t=torch.tensor([0])), g.op("Constant", value_t=torch.tensor(i)), ) tmp_ind = symbolic_helper._reshape_helper( g, arange(g, size, 4, None, None, None), g.op("Constant", value_t=torch.tensor(r_size)), ) tmp_ind = g.op( "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) ) if ind is None: ind = tmp_ind else: ind = g.op("Add", ind, tmp_ind) if offset: ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) return g.op("Gather", self_1d, ind) @_onnx_symbolic("aten::__derive_index") @_beartype.beartype def __derive_index(g: jit_utils.GraphContext, index, start, step): return g.op("Add", start, g.op("Mul", index, step)) @_onnx_symbolic("aten::__range_length") # Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp # if (step > 0 && lo < hi) { # push(stack, 1 + (hi - 1 - lo) / step); # } else if (step < 0 && lo > hi) { # push(stack, 1 + (lo - 1 - hi) / (0 - step)); # } else { # push(stack, 0); # } @_beartype.beartype def __range_length(g: jit_utils.GraphContext, lo, hi, step): sub = g.op("Sub", hi, lo) div = g.op("Ceil", true_divide(g, sub, step)) return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) @_onnx_symbolic("aten::linear") @_beartype.beartype def linear(g: jit_utils.GraphContext, input, weight, bias): rank = symbolic_helper._get_tensor_rank(input) weight = t(g, weight) if rank == 2 and not bias.node().mustBeNone(): alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) output = addmm(g, bias, input, weight, alpha, beta) else: output = matmul(g, input, weight) if not bias.node().mustBeNone(): output = add(g, bias, output) return output @_onnx_symbolic("aten::hann_window") @symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") @_beartype.beartype def hann_window( g: jit_utils.GraphContext, window_length, periodic=True, dtype: Optional[int] = None, layout=None, device=None, pin_memory=None, requires_grad=False, ): if dtype is None: dtype_ = torch.get_default_dtype() if not dtype_ or not dtype_.is_floating_point: dtype_ = torch.float scalar_type = _type_utils.JitScalarType.from_dtype(dtype_) else: scalar_type = _type_utils.JitScalarType(dtype) n_array = arange(g, window_length, 4, None, None, None) output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT) output = mul( g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output ) if periodic is False: window_length = sub( g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) ) output = div(g, output, window_length) output = g.op( "Cast", square(g, sin(g, output)), to_i=scalar_type.onnx_type(), ) return output @_onnx_symbolic("aten::mv") @_beartype.beartype def mv(g: jit_utils.GraphContext, self, vec): return matmul(g, self, vec) @_onnx_symbolic("aten::dot") @_beartype.beartype def dot(g: jit_utils.GraphContext, self, other): return matmul(g, self, other) @_onnx_symbolic("aten::movedim") @symbolic_helper.parse_args("v", "t", "t") @_beartype.beartype def movedim(g: jit_utils.GraphContext, self, source, destination): # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim source = source.view(-1) destination = destination.view(-1) assert source.size() == destination.size() if (source == destination).all(): return self self_rank = symbolic_helper._get_tensor_rank(self) assert self_rank is not None perm = list(range(self_rank)) src_dims = perm.copy() dst_dims = perm.copy() for src, dst in zip(source.tolist(), destination.tolist()): perm[dst] = src src_dims[src] = -1 dst_dims[dst] = -1 src_dims = [dim for dim in src_dims if dim != -1] dst_dims = [dim for dim in dst_dims if dim != -1] for src, dst in zip(src_dims, dst_dims): perm[dst] = src return g.op("Transpose", self, perm_i=perm) @_onnx_symbolic("aten::fill") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def fill(g: jit_utils.GraphContext, self, value): scalar_type = _type_utils.JitScalarType.from_value( self, _type_utils.JitScalarType.FLOAT ) return full_like(g, self, value, scalar_type) @_onnx_symbolic("aten::index_add") @_beartype.beartype def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): warnings.warn( "Warning: ONNX export does not support duplicated values in 'index' field, " + "this will cause the ONNX model to be incorrect." ) # ONNX does not support "alpha" argument, unlike aten index_add # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: return symbolic_helper._unimplemented("index_add", "alpha != 1", self) dim = symbolic_helper._maybe_get_const(dim, "i") if dim is None: raise errors.SymbolicValueError( "ONNX export does NOT support exporting 'index_add_()' function with " "unknown 'dim' value.", self, ) self_dim_rank = symbolic_helper._get_tensor_rank(self) other_dim_rank = symbolic_helper._get_tensor_rank(other) if self_dim_rank is None or other_dim_rank is None: raise errors.SymbolicValueError( "ONNX export does NOT support exporting 'index_add_()' function while " "the rank of self tensor or tensor to be added is unknown.", self, ) if other_dim_rank != self_dim_rank: delta = self_dim_rank - other_dim_rank for i in range(delta): other = symbolic_helper._unsqueeze_helper( g, other, [symbolic_helper._get_tensor_rank(other)] ) other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) if (other_dim_size is not None) and (self_dim_size is not None): if other_dim_size > self_dim_size: raise errors.SymbolicValueError( "ONNX export does not support exporting 'index_add_()' function with " "duplicated values in 'index' parameter yet.", self, ) # Construct a new shape. It's almost as same as self except the size of the 'dim' # dimension is 1, so that we can expand other dimensions as expected. new_shape_axes = list(range(self_dim_rank)) new_shape_starts = [0 for i in range(self_dim_rank)] new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] new_shape = symbolic_helper._slice_helper( g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends ) other = expand_as(g, other, new_shape) for i in range(dim): index = symbolic_helper._unsqueeze_helper(g, index, [0]) for i in range(self_dim_rank - dim - 1): index = symbolic_helper._unsqueeze_helper( g, index, [symbolic_helper._get_tensor_rank(index)] ) return scatter_add(g, self, dim, expand_as(g, index, other), other) @_onnx_symbolic("aten::roll") @symbolic_helper.parse_args("v", "is", "is") @_beartype.beartype def roll(g: jit_utils.GraphContext, self, shifts, dims): assert len(shifts) == len(dims) result = self for i in range(len(shifts)): shapes = [] shape = symbolic_helper._slice_helper( g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] ) shapes.append(shape) shape = symbolic_helper._slice_helper( g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] ) shapes.append(shape) result = g.op("Concat", *shapes, axis_i=dims[i]) return result @_onnx_symbolic("aten::cross") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def cross(g: jit_utils.GraphContext, input, other, dim=None): dim = symbolic_helper._get_dim_for_cross(input, dim) # If we have two tensors such that # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have # After first roll, # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e) roll_x_1 = roll(g, input, [2], [dim]) roll_y_1 = roll(g, other, [1], [dim]) # After second roll, # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) roll_x_2 = roll(g, input, [1], [dim]) roll_y_2 = roll(g, other, [2], [dim]) # cross product is calculated as # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) @_onnx_symbolic("aten::cdist") @_beartype.beartype def cdist( g: jit_utils.GraphContext, x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary", ): # X1.shape = (B * P * D), X2.shape = (B * R * D) # In order to respect numpy style broadcasting as demonstrated in # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md # we unsqueeze both input tensors # Currently we ignore the 'compute_mode' variable as we use default to # using matrix multiplication to calculate the euclidean distance rank = symbolic_helper._get_tensor_rank(x1) assert rank is not None broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) return pairwise_distance( g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False ) @_onnx_symbolic("aten::lerp") @_beartype.beartype def lerp(g: jit_utils.GraphContext, self, end, weight): # Conditional for better numeric. This has been discussed in # https://github.com/pytorch/pytorch/pull/18871 diff = g.op("Sub", end, self) return where( g, g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))), g.op("Add", self, g.op("Mul", weight, diff)), g.op( "Sub", end, g.op( "Mul", diff, g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight), ), ), ) @_onnx_symbolic("aten::broadcast_tensors") @_beartype.beartype def broadcast_tensors(g: jit_utils.GraphContext, self): all_tensors = symbolic_helper._unpack_list(self) t_with_final_shape = zeros_like(g, all_tensors[0]) # Add operator supports multidirectional broadcasting. So we leverage this function # to infer the final shape generated by the broadcast. for t in all_tensors: t_with_final_shape = add(g, t_with_final_shape, t) t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] return g.op("prim::ListConstruct", *t_list) @_onnx_symbolic("aten::is_pinned") def is_pinned(g: jit_utils.GraphContext, self, device=None): # Unused by ONNX. return None @_onnx_symbolic("prim::ConstantSplit") @_beartype.beartype def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim): size = symbolic_helper._get_tensor_dim_size(self, dim) if size is None: return symbolic_helper._unimplemented( "prim::ConstantSplit", "unknown dimension size", self ) splits = [split_size] * (size // split_size) leftover = size % split_size if leftover: splits.append(leftover) return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) # TODO: It would be better to export this as a chunk directly, as this is # less sensitive to changes in input size. # TODO: Once we have proper scoping, stop reimplementing chunk, delete this # method, and use the desugared version @_onnx_symbolic("prim::ConstantChunk") @_beartype.beartype def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): dim_size = symbolic_helper._get_tensor_dim_size(self, dim) if dim_size is None: return symbolic_helper._unimplemented( "prim::ConstantChunk", "unknown dimension size", self ) split_size = (dim_size + chunks - 1) // chunks return prim_constant_split(g, self, split_size, dim) @_onnx_symbolic("prim::shape") @_beartype.beartype def prim_shape(g: jit_utils.GraphContext, self): return g.op("Shape", self) @_onnx_symbolic("prim::max") @_beartype.beartype def prim_max(g: jit_utils.GraphContext, self, other): return symbolic_helper._op_with_optional_float_cast( g, "Max", self, other, opset_before=12 ) @_onnx_symbolic("prim::min") @_beartype.beartype def prim_min(g: jit_utils.GraphContext, self, other=None): if not other: if symbolic_helper._is_packed_list(self): self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) return min(g, self) return min(g, self, other) @_onnx_symbolic("prim::data") @_beartype.beartype def prim_data(g: jit_utils.GraphContext, self): return self @_onnx_symbolic("prim::layout") def prim_layout(g: jit_utils.GraphContext, self): # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'. # Layout class defined in 'c10/core/Layout.h'. return g.op("Constant", value_t=torch.tensor(0)) @_onnx_symbolic("prim::ListConstruct") @_beartype.beartype def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): return None @_onnx_symbolic("prim::ListUnpack") @_beartype.beartype def prim_list_unpack( g: jit_utils.GraphContext, *inputs, **kwargs ) -> Optional[List[_C.Value]]: if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": # Cancel the previous node if it is ListConstruct by returning its inputs # TODO(justinchuby): Use a public method in the helper module return symbolic_helper._unpack_list(inputs[0]) return None @_onnx_symbolic("prim::TupleConstruct") @_beartype.beartype def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): return None @_onnx_symbolic("prim::Uninitialized") @_beartype.beartype def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): return None # exists to refine the type of the Value # if x is an optional Tensor, unchecked_cast will cast # x to Tensor, so the rest of the graph knows that x is a Tensor # this doesn't do anything in runtime and is a noop in ONNX @_onnx_symbolic("prim::unchecked_cast") @_beartype.beartype def prim_unchecked_cast(g: jit_utils.GraphContext, self): return self @_onnx_symbolic("prim::dtype") @_beartype.beartype def prim_dtype(g: jit_utils.GraphContext, self): scalar_type = symbolic_helper._try_get_scalar_type(self) if scalar_type is None: scalar_type = _type_utils.JitScalarType.FLOAT # This node records a torch dtype as int return g.op("Constant", value_t=torch.tensor(scalar_type)) @_onnx_symbolic("prim::tolist") @_beartype.beartype def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val): """tolist is currently supported only for 1D input tensors. dim_val and elem_ty_val represent dimension and type annotations that need to match dimension and type of the input tensor. """ dim = symbolic_helper._maybe_get_const(dim_val, "i") if dim > 1: return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) return input # ----------------------------------------------------------------------------- # Symbolic functions that need extra context # ----------------------------------------------------------------------------- @_onnx_symbolic("prim::device") @_beartype.beartype def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: output_type = g.original_node.output().type() if isinstance(output_type, _C.DeviceObjType): return None return symbolic_helper._unimplemented( "prim::device", f"output type should be 'DeviceObjType', not '{output_type.kind()}'", g.original_node.output(), ) @_onnx_symbolic("prim::Loop") @_beartype.beartype def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]: node = g.original_node env = g.env values_in_env = g.values_in_env params_dict = g.params_dict operator_export_type = GLOBALS.operator_export_type opset_version = GLOBALS.export_onnx_opset_version old_blocks = tuple(node.blocks()) new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) ) for old_block, new_block_context in zip(old_blocks, new_block_contexts): # Copy input metadata to subblock # # prim::Loop(iter, cond, input_1, ..., input_n) # block0(iter, input_1, ..., input_n) # # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. for i, b_in in enumerate(old_block.inputs()): if i == 0 and i < len(inputs): b_in.setType(inputs[i].type()) # For optional block inputs, they may switch between None not-None inside # the loop body, so if the loop input is not optional, the block input may # still need to be optional. if ( i > 0 and (i + 1) < len(inputs) and not isinstance(b_in.type(), _C.OptionalType) ): b_in.setType(inputs[i + 1].type()) torch._C._jit_pass_onnx_block( old_block, new_block_context.block, operator_export_type, env, values_in_env, False, ) fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( new_node, opset_version ) # Run shape type inference for Loop after subblock is converted. if GLOBALS.onnx_shape_inference: torch._C._jit_pass_onnx_node_shape_type_inference( new_node, params_dict, opset_version ) return fixed_outputs @_onnx_symbolic("prim::If") @_beartype.beartype def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]: n = g.original_node block = g.block env = g.env values_in_env = g.values_in_env params_dict = g.params_dict operator_export_type = GLOBALS.operator_export_type opset_version = GLOBALS.export_onnx_opset_version static_if = inputs[0].node().kind() == "onnx::Constant" if static_if: # Fold static if # # The torch IR # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # %21 : Long(device=cpu) = aten::eq(%20, %64) # %22 : Long(device=cpu) = prim::If(%21) # block0(): # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) # -> (%23) # block1(): # -> (%65) # %input.53 : Tensor, %weight : Tensor = prim::If(%22) # block0(): # -> (%embedding_matrix.1, %input.1) # block1(): # -> (%input.1, %embedding_matrix.1) # %26 : int[] = aten::size(%input.53) # # The converted ONNX graph # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() const_value = ( all(input_flag) if isinstance(input_flag, list) else bool(input_flag) ) block_idx = 0 if const_value else 1 current_b = list(n.blocks())[block_idx] env = torch._C._jit_pass_onnx_block( current_b, block, operator_export_type, env, values_in_env, True, ) if_output_list = list(n.outputs()) current_b_list = list(current_b.outputs()) final_b_list = [] for idx in range(len(if_output_list)): if current_b_list[idx] not in env: raise errors.SymbolicValueError( f"The sub block ATen output {current_b_list[idx]} is not in env.", current_b_list[idx], ) # type:ignore[operator] onnx_b = env[current_b_list[idx]] final_b_list.append(onnx_b) return final_b_list else: old_blocks = tuple(n.blocks()) new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) ) for old_block, new_block_context in zip(old_blocks, new_block_contexts): torch._C._jit_pass_onnx_block( old_block, new_block_context.block, operator_export_type, env, values_in_env, False, ) fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( new_node, opset_version ) # Run shape type inference for If after subblock is converted. if GLOBALS.onnx_shape_inference: torch._C._jit_pass_onnx_node_shape_type_inference( new_node, params_dict, opset_version ) return fixed_outputs @_onnx_symbolic("prim::Constant") @_beartype.beartype def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs): node = g.original_node if node.mustBeNone(): return None # This must go before checking for string values, because some device constants # have string values, but we want to keep them as unconverted Device types so # that eq() can work on them. if isinstance(node.output().type(), _C.DeviceObjType): return None if node.kindOf("value") == "t": return g.op("Constant", value_t=symbolic_helper._node_get(node, "value")) if node.kindOf("value") == "s": return g.op("Constant", value_s=symbolic_helper._node_get(node, "value")) if node.output().type().isSubtypeOf( _C.ListType.ofInts() ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()): return g.op( "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value")) ) if node.output().type().isSubtypeOf(_C.ListType.ofStrings()): str_constants = [ g.op("Constant", value_s=s) for s in symbolic_helper._node_get(node, "value") ] return g.op("prim::ListConstruct", *str_constants) raise errors.SymbolicValueError( f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. " f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", node.output(), ) @_onnx_symbolic("prim::type") @_beartype.beartype def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs): if device_value.node().kind() == "prim::device": device = jit_utils.get_device_from_value(device_value.node().input()) if device is not None: return g.op("Constant", value_s=str(device)) return symbolic_helper._unimplemented( "prim::type", "Device type cannot be statically determined.", device_value, ) @_onnx_symbolic("onnx::Placeholder") @_beartype.beartype def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs): node = g.original_node block = g.block env = g.env values_in_env = g.values_in_env return torch._C._jit_onnx_convert_pattern_from_subblock( block, node, env, values_in_env ) @_onnx_symbolic("aten::resolve_conj") @_onnx_symbolic("aten::resolve_neg") @_beartype.beartype def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value): # ONNX does not have operators to *directly* manipulate real/imaginary components # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, # which results in failures due to missing operators for complex numbers # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op return input @_onnx_symbolic("aten::_conj") @_onnx_symbolic("aten::conj_physical") @_beartype.beartype def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value): # ONNX does not have operators to *directly* manipulate real/imaginary components # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, # which results in failures due to missing operators for complex numbers # While `aten::_conj` and `aten::conj_physical` raise exception when input is complex if symbolic_helper.is_complex_value(input): # FIXME(justinchuby): report correct name for symbolic being executed return symbolic_helper._onnx_unsupported( "aten::_conj, aten::conj_physical", input, ) # they can safely be implemented as no-op for real numbers only return noop_complex_operators(g, input) @_onnx_symbolic("aten::logit") @_beartype.beartype def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value): one = g.op("Constant", value_t=torch.tensor(1.0)) if not symbolic_helper._is_none(eps): eps = g.op( "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type() ) one_sub_eps = g.op("Sub", one, eps) self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self) temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps) temporary_self_less_eps = g.op("Less", temporary_self, eps) z = g.op("Where", temporary_self_less_eps, eps, temporary_self) else: z = self sub = g.op("Sub", one, z) div = g.op("Div", z, sub) return g.op("Log", div)