# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 18.

Note [ONNX Operators that are added/updated in opset 18]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New or updated operators:
    BitwiseAnd
    CenterCropPad
    Col2Im
    Mish
    OptionalGetElement
    OptionalHasElement
    Pad
    Resize
    ScatterElements
    ScatterND
    Split
"""

import functools
from typing import List, Optional, Sequence, Tuple

import torch
from torch import _C
from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = [
    "col2im",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)


@_onnx_symbolic("aten::__and_")
@_onnx_symbolic("aten::bitwise_and")
@_beartype.beartype
def __and_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__and_`` / ``aten::bitwise_and``.

    Both operands are cast to a common promoted type. If that type is bool,
    the logical ONNX ``And`` op is emitted; otherwise ``BitwiseAnd`` is used.
    """
    # do type promotion (scalars don't seem to apply)
    args = [self, other]
    # type promotion doesn't happen with torch.bitwise_and(tensor, scalar), so
    # only operands whose rank is known and non-zero drive the promotion; if
    # none qualify, fall back to promoting over both operands.
    prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)]
    if len(prom_args) == 0:
        prom_args = args
    promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args)
    self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type)
    other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type)
    if promotion_jit_type == _type_utils.JitScalarType.BOOL:
        return g.op("And", self, other)
    return g.op("BitwiseAnd", self, other)


@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(
    g,
    input: _C.Value,
    output_size: _C.Value,
    kernel_size: _C.Value,
    dilation: Sequence[int],
    padding: Sequence[int],
    stride: Sequence[int],
):
    """Export ``aten::col2im`` as the opset-18 ONNX ``Col2Im`` op.

    Args:
        g: Graph context used to emit the ONNX node.
        input: Column-ized input value.
        output_size: Value holding the spatial output shape; its length is the
            number of spatial axes used to size any defaulted attribute.
        kernel_size: Value holding the kernel shape.
        dilation: Per-axis dilation; defaults to all ones when empty.
        padding: Per-axis padding, applied to both sides of each axis;
            defaults to all zeros when empty.
        stride: Per-axis stride; defaults to all ones when empty.

    Returns:
        The output value of the emitted ``Col2Im`` node.
    """
    # torch pads every spatial axis symmetrically by padding[i], while ONNX
    # Col2Im expects pads laid out as
    # [x1_begin, x2_begin, ..., x1_end, x2_end, ...]. The begin half and the
    # end half are therefore both equal to ``padding``:
    # [p0, p1, ..., pn] -> [p0, p1, ..., pn, p0, p1, ..., pn].
    # (Interleaving as [p0, p0, p1, p1, ...] would mix up axes whenever the
    # per-axis paddings differ.)
    padding = list(padding or ())
    adjusted_padding = padding + padding

    dilation = list(dilation or ())
    stride = list(stride or ())

    if not (adjusted_padding and dilation and stride):
        # Defaults are sized by the number of spatial axes, i.e. the length of
        # ``output_size``. Only look it up when a default is actually needed,
        # so a missing static size does not break fully-specified calls.
        num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
        if not adjusted_padding:
            adjusted_padding = [0, 0] * num_dimensional_axis
        if not dilation:
            dilation = [1] * num_dimensional_axis
        if not stride:
            stride = [1] * num_dimensional_axis

    return g.op(
        "Col2Im",
        input,
        output_size,
        kernel_size,
        dilations_i=dilation,
        pads_i=adjusted_padding,
        strides_i=stride,
    )


@_onnx_symbolic(
    "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
@_onnx_symbolic(
    "aten::prod",
    decorate=[
        symbolic_helper._apply_params(
            "ReduceProd", "prod", allow_multi_dim_support=False
        )
    ],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    """Build a reduce symbolic via ``_reduce_with_dtype_helper``.

    Registered for ``aten::mean`` (``ReduceMean``) and ``aten::prod``
    (``ReduceProd``, single-dim only). The ``_apply_params`` decorators
    presumably bind these arguments at registration time; confirm in
    ``symbolic_helper``.
    """
    return symbolic_helper._reduce_with_dtype_helper(
        onnx_op, name, allow_multi_dim_support
    )


@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
@_beartype.beartype
def _native_layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
) -> Tuple[_C.Value, _C.Value, _C.Value]:
    """Export ``aten::native_layer_norm`` by delegating to the opset-9 symbolic."""
    return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps)


@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _glu(g: jit_utils.GraphContext, input, dim):
    """Export ``aten::glu`` as ``first * Sigmoid(second)``.

    ``input`` is split into two halves along ``dim`` with the opset-18
    ``Split`` form that takes ``num_outputs``.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
    if dim_size is not None:
        # GLU needs an even-sized split dimension when the size is known.
        assert dim_size % 2 == 0

    first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2)
    return g.op("Mul", first, g.op("Sigmoid", second))


@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export ``aten::max`` (both overloads) via ``_max_helper``."""
    return symbolic_helper._max_helper(g, self, dim_or_y, keepdim)


@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def maximum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::maximum`` as the two-tensor form of :func:`max`."""
    return max(g, input, dim_or_y=other)


@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export ``aten::min`` (both overloads) via ``_min_helper``."""
    return symbolic_helper._min_helper(g, self, dim_or_y, keepdim)


@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def minimum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::minimum`` as the two-tensor form of :func:`min`."""
    return min(g, input, dim_or_y=other)


@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amax`` as ``ReduceMax`` with ``dim`` passed as an axes input."""
    # Opset 18 takes reduction axes as an int64 tensor input, not an attribute.
    axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    return g.op("ReduceMax", self, axes, keepdims_i=keepdim)


@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amin`` as ``ReduceMin`` with ``dim`` passed as an axes input."""
    # Opset 18 takes reduction axes as an int64 tensor input, not an attribute.
    axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    return g.op("ReduceMin", self, axes, keepdims_i=keepdim)


@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::aminmax`` as a ``(ReduceMin, ReduceMax)`` pair.

    When ``dim`` is given, both reductions run over that single axis;
    otherwise no axes input is passed, and both reduce over every axis.
    """
    if not symbolic_helper._is_none(dim):
        dim = symbolic_helper._get_const(dim, "i", "dim")
        axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
        return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op(
            "ReduceMax", self, axes, keepdims_i=keepdim
        )
    else:
        return g.op("ReduceMin", self, keepdims_i=keepdim), g.op(
            "ReduceMax", self, keepdims_i=keepdim
        )


@_onnx_symbolic("aten::var_mean")
@_beartype.beartype
def _var_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var_mean`` via ``_var_mean_helper``.

    A lone extra argument is forwarded as the helper's third positional
    argument, with ``None`` for the dim and keepdim slots. Presumably this is
    the ``(input, unbiased)`` overload; confirm against the aten schema.
    """
    if len(args) == 1:
        return symbolic_helper._var_mean_helper(g, input, None, args[0], None)
    else:
        return symbolic_helper._var_mean_helper(g, input, *args)


@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
    """Export ``aten::logsumexp`` as ``ReduceLogSumExp``.

    With ``dim`` of None, no axes input is passed, and every axis is reduced.
    ``keepdim`` is honored in both cases.
    """
    if dim is None:
        # Previously hardcoded keepdims_i=0, which silently dropped keepdim.
        return g.op("ReduceLogSumExp", input, keepdims_i=keepdim)
    else:
        axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
        return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim)


@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
@_beartype.beartype
def _linalg_matrix_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: List[int],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_matrix_norm`` by delegating to the opset-9 symbolic."""
    return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)


@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` via ``_embedding_bag_helper``."""
    return symbolic_helper._embedding_bag_helper(
        g,
        embedding_matrix,
        indices,
        offsets,
        scale_grad_by_freq,
        mode,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    )


@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_vector_norm`` via ``_linalg_vector_norm_helper``."""
    return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)