# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import sys
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 10
# Opset 10 is supported by ONNX release 1.5.0
# released on 04/24/19


__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv3d_relu",
    "quantized_conv1d",
    "quantized_conv2d",
    "quantized_conv3d",
    "quantized_conv_transpose1d",
    "quantized_conv_transpose2d",
    "quantized_conv_transpose3d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_linear_relu",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]


# Every symbolic registered through this partial is bound to opset 10.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)


@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Export ``aten::div``.

    With no trailing arguments this is plain true division (delegated to the
    opset 9 symbolic). Any trailing arguments are forwarded to
    ``_div_rounding_mode``, which parses the first of them as the
    rounding-mode string.
    """
    if args:
        # A rounding mode was supplied; let the rounding-aware path handle it.
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)


@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Export division with an explicit ``rounding_mode`` string.

    ``"floor"`` is handled by this opset's ``_floor_divide``; every other mode
    is delegated unchanged to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Export ``aten::_floor_divide`` (division rounded toward negative infinity).

    If either operand is floating point, the result is ``Floor(true_divide)``.
    Otherwise ONNX ``Div`` truncates, so the truncated quotient is decremented
    by one wherever the operands have opposite signs and the division leaves a
    non-zero remainder.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        out = opset9.true_divide(g, self, other)
        return g.op("Floor", out)
    else:
        # Integer division does truncation rounding
        div = g.op("Div", self, other)
        # Division is negative if: self < 0 != other < 0
        # NOTE(review): the 0/1 constants are int64; non-int64 integer inputs
        # presumably rely on the exporter's type promotion — confirm.
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero))

        # For negative numbers with self % other != 0, subtract 1 to round down instead of up
        mod = g.op("Mod", self, other, fmod_i=0)
        fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))

        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        fixup = g.op("Sub", div, one)
        return g.op("Where", fixup_mask, fixup, div)


@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` through ``symbolic_helper._sort_helper``.

    ``decending`` (sic) is forwarded under the keyword name the helper uses.
    """
    return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)


@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` through ``symbolic_helper._topk_helper``.

    ``k`` stays a graph value (parse spec ``"v"``). ``dim``, ``largest`` and
    ``sorted`` are parsed to Python ints.
    """
    return symbolic_helper._topk_helper(
        g, self, k, dim, largest=largest, sorted=sorted, out=out
    )


def _aten_max_pool_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
) -> _C.Value:
    """Emit an ONNX ``MaxPool`` returning only the pooled values.

    Args:
        g: graph context to emit nodes into.
        self: input value.
        kernel_shape, strides, pads, dilations: ONNX ``MaxPool`` attributes,
            already normalized by ``_adjust_attributes_of_max_pool``.
        ceil_mode: ONNX ``ceil_mode`` attribute.
        unbatched_rank: rank of an input without a batch dimension
            (number of spatial dims + 1).

    If the statically known rank of ``self`` equals ``unbatched_rank``, a size-1
    batch axis is inserted before pooling and removed afterwards. If the rank
    is unknown, the input is pooled as-is.
    """
    # Decide on the *static* rank. A graph Value such as Size(Shape(self))
    # never compares equal to a Python int, so it cannot drive this branch.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        # Use the shared helper rather than a hand-built Unsqueeze with an
        # axes *input* (an opset-13 form), so the node fits the export opset.
        self = symbolic_helper._unsqueeze_helper(g, self, [0])
    pool_result, _ = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )
    if is_unbatched:
        # N=1,C,H,W -> C,H,W
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])

    return pool_result


# For MaxPool
def _adjust_attributes_of_max_pool(
    expand_size: int,
    kernel_size: Union[Sequence[int], int],
    stride: Union[Sequence[int], int],
    padding: Union[Sequence[int], int],
    dilation: Union[Sequence[int], int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
    """Adjust attributes of max_pool to match ONNX specification.

    Args:
        expand_size: number of spatial dimensions (1, 2 or 3).
        kernel_size, stride, padding, dilation: aten arguments, each either a
            single int or a sequence.

    Returns:
        ``(kernel_shape, strides, pads, dilations)``.
        - Scalar kernel sizes and dilations are repeated ``expand_size`` times.
        - ``pads`` repeats the per-axis padding so that begin and end values
          are both present. A sequence of more than three entries is passed
          through unchanged.
        - An empty ``stride`` defaults to ``kernel_shape``.
    """
    if isinstance(dilation, int):
        dilation = [dilation] * expand_size

    if isinstance(kernel_size, int):
        kernel_shape = [kernel_size] * expand_size
    else:
        kernel_shape = kernel_size  # type: ignore[assignment]

    if isinstance(padding, int):
        pads = [padding] * expand_size * 2  # type: ignore[operator, assignment]
    elif len(padding) == 1:
        pads = padding * expand_size * 2  # type: ignore[operator, assignment]
    elif len(padding) == 2:
        # 2D padding
        pads = padding * 2  # type: ignore[operator, assignment]
    elif len(padding) == 3:
        # 3D padding
        pads = padding * 2  # type: ignore[operator, assignment]
    else:
        # When padding is already done for all dimensions,
        # we don't need to double it
        # eg: (1, 1, 1, 1, 1, 1)
        pads = padding  # type: ignore[assignment]

    if isinstance(stride, int):
        strides = [stride] * expand_size
    elif not stride:
        strides = kernel_shape
    else:
        strides = stride  # type: ignore[assignment]

    return (kernel_shape, strides, pads, dilation)


def _aten_max_pool_with_indices_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
    n_dims_one: Sequence[int],
    n_dims_zero: Sequence[int],
    n_dims_axes: Sequence[int],
) -> Tuple[_C.Value, _C.Value]:
    """Emit an ONNX ``MaxPool`` returning both pooled values and indices.

    Args:
        g, self, kernel_shape, strides, pads, dilations, ceil_mode,
            unbatched_rank: as for ``_aten_max_pool_onnx``.
        n_dims_one: ``[1] * n_spatial``. Used as the unit kernel/stride and as
            the Slice ends.
        n_dims_zero: ``[0] * n_spatial``. Used as the Slice starts.
        n_dims_axes: spatial axes ``[2, 3, ...]``. Used as the Slice axes.

    Returns:
        ``(pool_result, indices)`` as graph values.

    A second unit-kernel ``MaxPool`` yields the flattened index of every input
    element. The index of the first spatial element of each (N, C) plane is
    subtracted from the pooled indices. This makes the indices relative to
    their plane.
    """
    # Decide on the *static* rank. A graph Value such as Size(Shape(self))
    # never compares equal to a Python int, so it cannot drive this branch.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        self = symbolic_helper._unsqueeze_helper(g, self, [0])
    pool_result, indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )
    # Unit kernel and stride: the indices output enumerates every element.
    _, flatten_indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        dilations_i=dilations,
        kernel_shape_i=n_dims_one,
        strides_i=n_dims_one,
    )

    # Slice [0:1] on each spatial axis: the flattened index at which each
    # (N, C) plane starts.
    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))

    delta = g.op("Slice", flatten_indices, starts, ends, axes)
    indices = g.op("Sub", indices, delta)

    if is_unbatched:
        # N=1,... -> ... for both outputs. The helper emits a valid Squeeze for
        # the export opset. A `value_t` attribute is not a Squeeze attribute.
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])
        indices = symbolic_helper._squeeze_helper(g, indices, [0])

    return (pool_result, indices)


@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool1d_with_indices",
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool2d_with_indices",
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool3d_with_indices",
            3,
            return_indices=True,
        )
    ],
)
@_beartype.beartype
def _max_pool(name: str, expand_size: int, return_indices: bool):
    """Build the symbolic for ``aten::max_pool{1,2,3}d[_with_indices]``.

    Args:
        name: aten op name (bound by ``_apply_params`` in the registrations).
        expand_size: number of spatial dimensions (1, 2 or 3).
        return_indices: whether the symbolic also returns the indices output.

    Returns:
        The per-node symbolic function.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(
        g: jit_utils.GraphContext,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        dilation: Sequence[int],
        ceil_mode: bool,
    ):
        kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
            expand_size, kernel_size, stride, padding, dilation
        )

        # An unbatched input has channel + spatial dims: expand_size + 1.
        if return_indices:
            return _aten_max_pool_with_indices_onnx(
                g,
                input,
                kernel_shape,
                strides,
                pads,
                dilations,
                ceil_mode,
                expand_size + 1,
                ([1] * expand_size),
                ([0] * expand_size),
                ([2 + i for i in range(expand_size)]),
            )
        else:
            return _aten_max_pool_onnx(
                g,
                input,
                kernel_shape,
                strides,
                pads,
                dilations,
                ceil_mode,
                expand_size + 1,
            )

    return symbolic_fn


# For AvgPool
def _adjust_attributes_of_avg_pool(
    expand_size: int,
    kernel_size: Union[Sequence[int], int],
    stride: Union[Sequence[int], int],
    padding: Union[Sequence[int], int],
) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]:
    """Adjust attributes of avg_pool to match ONNX specification.

    Args:
        expand_size: number of spatial dimensions (1, 2 or 3).
        kernel_size, stride, padding: aten arguments, each either a single int
            or a sequence.

    Returns:
        ``(kernel_shape, strides, pads)``.
        - A scalar kernel size is repeated ``expand_size`` times.
        - ``pads`` repeats the padding so begin and end values are both present.
        - An empty ``stride`` defaults to ``kernel_shape``.
    """
    if isinstance(kernel_size, int):
        kernel_shape = [kernel_size] * expand_size
    else:
        kernel_shape = kernel_size  # type: ignore[assignment]
    if isinstance(padding, int):
        pads = [padding] * expand_size * 2
    elif len(padding) == 1:
        pads = padding * expand_size * 2  # type: ignore[operator, assignment]
    elif len(padding) == 2:
        pads = padding * expand_size  # type: ignore[operator, assignment]
    else:
        pads = padding * 2  # type: ignore[operator, assignment]
    if isinstance(stride, int):
        strides = [stride] * expand_size
    elif not stride:
        strides = kernel_shape
    else:
        strides = stride  # type: ignore[assignment]

    return (kernel_shape, strides, pads)


@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[symbolic_helper._apply_params("avg_pool1d", 1)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[symbolic_helper._apply_params("avg_pool2d", 2)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[symbolic_helper._apply_params("avg_pool3d", 3)],
)
@_beartype.beartype
def _avg_pool(name, expand_size):
    """Build the symbolic for ``aten::avg_pool{1,2,3}d`` as ONNX ``AveragePool``.

    Args:
        name: aten op name (bound by ``_apply_params``; not used in the body).
        expand_size: number of spatial dimensions (1, 2 or 3).
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # NOTE(review): divisor_override is accepted but not used by this
        # export; confirm that is acceptable for callers that set it.
        kernel_shape, strides, pads = _adjust_attributes_of_avg_pool(
            expand_size, kernel_size, stride, padding
        )
        result = g.op(
            "AveragePool",
            input,
            ceil_mode_i=ceil_mode,
            count_include_pad_i=count_include_pad,
            kernel_shape_i=kernel_shape,
            pads_i=pads,
            strides_i=strides,
        )
        return result

    return symbolic_fn


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    """Build the symbolic for ``aten::upsample_*`` ops as an ONNX ``Resize``.

    Args:
        name: aten op name, used in the "unimplemented" diagnostic.
        dim: 3, 4 or 5 for the 1d/2d/3d variants. Passed to
            ``_interpolate_size_to_scales``.
        interpolate_mode: ``"nearest"`` or ``"linear"``. Emitted as the Resize
            ``mode`` attribute.
    """

    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        if scales is None:
            # Only an output size was given; derive per-axis scales from it.
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn


@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as an ONNX ``Resize``.

    Scales and mode come from ``_interpolate_get_scales_and_mode``.
    ``recompute_scale_factor`` and ``antialias`` are accepted but not used here.
    """
    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Resize", input, scales, mode_s=mode)


@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    axes: Union[List, torch.Tensor, torch._C.Value],
    starts: Union[List, torch.Tensor, torch._C.Value],
    ends: Union[List, torch.Tensor, torch._C.Value],
    steps: Optional[Union[List, torch.Tensor, torch._C.Value]] = None,
):
    """Emit an ONNX ``Slice`` node with starts/ends/axes(/steps) as inputs.

    Each of ``axes``/``starts``/``ends``/``steps`` may be a Python list, a
    tensor, or a graph Value of rank 0 or 1. Rank-0 values are unsqueezed to
    1-D.

    Defaults:
    - A None start (Python ``None`` or a ``prim::Constant`` None) becomes 0.
    - A None end becomes ``INT64_MAX``.
    - A None-valued ``steps`` Value becomes 1.
    - ``steps=None`` omits the steps input entirely.

    If starts/ends/steps are statically 0 / ``INT64_MAX`` / 1, the slice is a
    no-op and ``input`` is returned unchanged.

    Raises:
        errors.SymbolicValueError: if a Value argument has rank other than 0 or 1.
    """

    def is_none_value(value):
        # True for Python None or a prim::Constant whose type is NoneType.
        if value is None:
            return True
        return (
            isinstance(value, torch._C.Value)
            and value.node().kind() == "prim::Constant"
            and isinstance(value.type(), _C.NoneType)
        )

    def to_slice_input(list_or_value, default_value=None):
        # Convert input param into a 1D torch.Value.
        if is_none_value(list_or_value) and default_value is not None:
            list_or_value = [default_value]

        if isinstance(list_or_value, (list, torch.Tensor)):
            return g.op("Constant", value_t=torch.tensor(list_or_value))

        rank = symbolic_helper._get_tensor_rank(list_or_value)
        if rank == 0:
            return symbolic_helper._unsqueeze_helper(g, list_or_value, [0])
        if rank == 1:
            return list_or_value
        raise errors.SymbolicValueError(
            f"Rank must be 0 or 1, not {rank}", list_or_value
        )

    def get_const_value(list_or_value):
        # Single static int if one is known; None otherwise (or for multi-entry lists).
        if isinstance(list_or_value, (list, torch.Tensor)):
            if len(list_or_value) == 1:
                return list_or_value[0]
            return None
        return symbolic_helper._maybe_get_const(list_or_value, "i")

    # Check if slice is a no-op
    if (
        get_const_value(starts) == 0
        and get_const_value(ends) == _constants.INT64_MAX
        and (steps is None or get_const_value(steps) == 1)
    ):
        return input

    axes = to_slice_input(axes)
    starts = to_slice_input(starts, default_value=0)
    ends = to_slice_input(ends, default_value=_constants.INT64_MAX)
    if steps is None:
        return g.op("Slice", input, starts, ends, axes)
    steps = to_slice_input(steps, default_value=1)
    return g.op("Slice", input, starts, ends, axes, steps)
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::slice`` for both the tensor and the list overloads.

    Raises:
        errors.SymbolicValueError: if the argument count matches neither overload.
    """
    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dims, start, end, step = args
    elif len(args) == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        start, end, step = args
        dims = [0]
    else:
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)

    return symbolic_helper._slice_helper(
        g,
        self,
        axes=dims,
        starts=start,
        ends=end,
        steps=step,
    )


@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
    """Export ``aten::flip`` as a Slice with step -1 along every axis in ``dims``."""
    return symbolic_helper._slice_helper(
        g,
        input,
        axes=dims,
        starts=[-1] * len(dims),
        ends=[-_constants.INT64_MAX] * len(dims),
        steps=[-1] * len(dims),
    )


@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
    """Export ``aten::fmod`` as ONNX ``Mod`` with ``fmod=1``."""
    return g.op("Mod", input, other, fmod_i=1)


@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` by unrolling one Gather + reduction per bag.

    The number of bags is read from the static size of ``offsets``. A subgraph
    is emitted for each bag, and the results are concatenated along axis 0.
    ``mode`` 0 reduces with a sum, 1 with a mean, and any other value with a
    max.

    Returns:
        ``(output, None, None, None)``, or the result of
        ``_onnx_unsupported`` when the offsets size is not static.

    Raises:
        RuntimeError: if ``padding_idx`` is a non-negative index.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )
    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is not None:
        if include_last_offset:
            offset_len = offsets_dim_0 - 1
            offsets_extended = offsets
        else:
            offset_len = offsets_dim_0
            # Append a sentinel end offset so bag i spans [offsets[i], offsets[i+1]).
            offsets_extended = [
                offsets,
                g.op("Constant", value_t=torch.tensor([sys.maxsize])),
            ]
            offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
        list_ = []
        for i in range(offset_len):
            start_ = symbolic_helper._unsqueeze_helper(
                g,
                opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
                [0],
            )
            end_ = symbolic_helper._unsqueeze_helper(
                g,
                opset9.select(
                    g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
                ),
                [0],
            )
            axes_ = g.op("Constant", value_t=torch.tensor([0]))
            indices_row = g.op("Slice", indices, start_, end_, axes_)

            embeddings = g.op("Gather", embedding_matrix, indices_row)
            if not symbolic_helper._is_none(per_sample_weights):
                per_sample_weights_row = g.op(
                    "Slice", per_sample_weights, start_, end_, axes_
                )
                per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                    g, per_sample_weights_row, [1]
                )
                embeddings = g.op("Mul", embeddings, per_sample_weights_row)
            if mode == 0:
                embeddings = symbolic_helper._reducesum_helper(
                    g, embeddings, axes_i=[0], keepdims_i=0
                )
            elif mode == 1:
                embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
            else:
                embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

            embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
            list_.append(embeddings)

        output = g.op("Concat", *list_, axis_i=0)
        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
        return output, None, None, None
    else:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )


@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Export fake quantization as a QuantizeLinear -> DequantizeLinear pair.

    Only the (0, 255) and (-128, 127) ranges with a constant ``scale`` are
    exportable at this opset. ``zero_point`` is cast to UINT8 when
    ``quant_min == 0`` and to INT8 otherwise.

    Raises:
        errors.SymbolicValueError: for an unsupported (quant_min, quant_max).
    """
    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    scale = symbolic_helper._maybe_get_scalar(scale)
    if scale is None:
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    scale = scale.float().data  # Avoid exporter generating double type
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    return g.op(
        "DequantizeLinear",
        g.op("QuantizeLinear", inputs, scale, zero_point),
        scale,
        zero_point,
    )


@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
    """Export ``aten::isinf`` as ``IsInf`` on the input cast to DOUBLE."""
    return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE))


@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
    """Export ``aten::isfinite`` as ``not (isinf(x) or isnan(x))``."""
    inf_node = isinf(g, input)
    nan_node = opset9.isnan(g, input)
    return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node))


@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Export ``aten::quantize_per_tensor``.

    ``zero_point`` is cast to the ONNX type of the constant ``dtype`` and
    ``scale`` to FLOAT before calling ``quantize_helper``.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point = g.op(
        "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type()
    )
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, scale, zero_point)


@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
    """Export ``aten::dequantize``; returns only the dequantized value."""
    return symbolic_helper.dequantize_helper(g, input)[0]


@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Export ``aten::nan_to_num`` as a chain of three ``Where`` nodes.

    NaN is replaced by ``nan`` (default 0.0). +inf is replaced by ``posinf``
    and -inf by ``neginf``; each defaults to the largest/lowest finite value of
    the input dtype. Non-floating-point inputs are returned unchanged.
    """
    # Cannot create a int type tensor with inf/nan values, so we simply
    # return the original tensor
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
    if nan is None:
        nan = 0.0
    nan_cond = opset9.isnan(g, input)
    nan_result = g.op(
        "Where",
        nan_cond,
        g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
        input,
    )

    # For None values of posinf, neginf we use the greatest/lowest finite
    # value representable by input's dtype.
    finfo = torch.finfo(input_dtype)
    if posinf is None:
        posinf = finfo.max
    posinf_cond = opset9.logical_and(
        g,
        isinf(g, nan_result),
        opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
    )
    nan_posinf_result = g.op(
        "Where",
        posinf_cond,
        g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
        nan_result,
    )

    if neginf is None:
        neginf = finfo.min
    neginf_cond = opset9.logical_and(
        g,
        isinf(g, nan_posinf_result),
        opset9.lt(
            g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))
        ),
    )
    return g.op(
        "Where",
        neginf_cond,
        g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
        nan_posinf_result,
    )


# Quantized symbolics ---------------------------------------------------------
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.


def _dequantize_input_weight_bias(g: jit_utils.GraphContext, q_input, q_weight, bias):
    """Dequantize the activation and weight of a quantized linear/conv op.

    ``bias`` is requantized with the input and weight scales
    (``requantize_bias_helper``) and then dequantized.

    Returns:
        ``(input, weight, bias)`` as float graph values.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
    return dq_input, dq_weight, dq_bias


@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear`` as dequantize -> linear -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.linear(g, input, weight, bias)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear_relu`` as dequantize -> linear -> relu -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.linear(g, input, weight, bias)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add`` as dequantize -> add -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    y, _, _, _ = symbolic_helper.dequantize_helper(g, y)

    output = opset9.add(g, x, y)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add_relu`` as dequantize -> add -> relu -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    y, _, _, _ = symbolic_helper.dequantize_helper(g, y)

    output = opset9.add(g, x, y)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::mul`` as dequantize -> mul -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    y, _, _, _ = symbolic_helper.dequantize_helper(g, y)

    output = opset9.mul(g, x, y)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::hardswish`` as dequantize -> hardswish -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = opset9.hardswish(g, x)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::sigmoid`` as dequantize -> sigmoid -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = opset9.sigmoid(g, x)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Export ``quantized::leaky_relu`` as dequantize -> leaky_relu -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = opset9.leaky_relu(g, x, negative_slope, inplace)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::layer_norm`` as dequantize -> layer_norm -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::group_norm`` as dequantize -> group_norm -> quantize."""
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::instance_norm`` as dequantize -> instance_norm -> quantize.

    No running statistics are passed (``None`` for both).
    """
    input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)

    output = opset9.instance_norm(
        g, input, weight, bias, None, None, False, 0.0, eps, False
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d_relu`` as dequantize -> conv1d -> relu -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d_relu`` as dequantize -> conv2d -> relu -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d_relu")
@_beartype.beartype
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv3d_relu`` as dequantize -> conv3d -> relu -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d")
@_beartype.beartype
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d`` as dequantize -> conv1d -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d`` as dequantize -> conv2d -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d")
@_beartype.beartype
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv3d`` as dequantize -> conv3d -> quantize."""
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose1d")
@_beartype.beartype
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose1d`` as dequantize -> conv_transpose1d -> quantize.

    Note the opset 9 symbolic takes ``groups`` before ``dilation``.
    """
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    # Use the 1d symbolic, consistent with the 2d/3d variants below.
    output = opset9.conv_transpose1d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose2d")
@_beartype.beartype
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose2d`` as dequantize -> conv_transpose2d -> quantize.

    Note the opset 9 symbolic takes ``groups`` before ``dilation``.
    """
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv_transpose2d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose3d")
@_beartype.beartype
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose3d`` as dequantize -> conv_transpose3d -> quantize.

    Note the opset 9 symbolic takes ``groups`` before ``dilation``.
    """
    input, weight, bias = _dequantize_input_weight_bias(g, q_input, q_weight, bias)

    output = opset9.conv_transpose3d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Export ``quantized::cat``.

    Each list element is dequantized, the results are concatenated along
    ``dim``, and the output is requantized with ``op_scale``/``op_zero_point``.
    """
    unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
    dequantized = [
        symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
    ]
    concatenated = g.op("Concat", *dequantized, axis_i=dim)
    return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)