# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 20.

Note [ONNX Operators that are added/updated in opset 20]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New operators:
    AffineGrid
    ConstantOfShape
    DFT
    Gelu
    GridSample
    ImageDecoder
    IsInf
    IsNaN
    ReduceMax
    ReduceMin
    RegexFullMatch
    StringConcat
    StringSplit
"""

import functools

import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = ["_grid_sampler", "_affine_grid_generator"]


def convert_grid_sample_mode(mode_s):
    """Translate an interpolation mode name to the spelling emitted for opset 20.

    ``"bilinear"`` becomes ``"linear"`` and ``"bicubic"`` becomes ``"cubic"``;
    any other value is returned unchanged.
    """
    if mode_s == "bilinear":
        return "linear"
    if mode_s == "bicubic":
        return "cubic"
    return mode_s


def _name_for_enum(name_to_enum, enum_value):
    """Return the name that ``name_to_enum`` maps to ``enum_value``.

    The mapping is inverted on every call and then indexed, so an
    ``enum_value`` that is not among the mapping's values raises ``KeyError``.
    """
    enum_to_name = {enum: name for name, enum in name_to_enum.items()}
    return enum_to_name[enum_value]


# Every symbolic registered through this partial is bound to opset 20.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)


@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    """Emit a ``GridSample`` op for ``aten::grid_sampler``.

    Args:
        g: Graph context used to create the op.
        input: Value passed as the first operand of ``GridSample``.
        grid: Value passed as the second operand of ``GridSample``.
        mode_enum: Integer looked up among the values of
            ``F.GRID_SAMPLE_INTERPOLATION_MODES`` to recover the mode name.
        padding_mode_enum: Integer looked up among the values of
            ``F.GRID_SAMPLE_PADDING_MODES`` to recover the padding mode name.
        align_corners: Flag forwarded as the integer ``align_corners`` attribute.

    Returns:
        Whatever ``g.op`` returns for the created ``GridSample`` op.

    Raises:
        KeyError: If either enum value is not present in its lookup table.
    """
    # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
    interpolation_name = convert_grid_sample_mode(
        _name_for_enum(F.GRID_SAMPLE_INTERPOLATION_MODES, mode_enum)
    )
    padding_name = _name_for_enum(F.GRID_SAMPLE_PADDING_MODES, padding_mode_enum)
    align_flag = int(align_corners)
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=align_flag,
        mode_s=interpolation_name,
        padding_mode_s=padding_name,
    )


@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    """Emit an ``AffineGrid`` op for ``aten::affine_grid_generator``.

    Args:
        g: Graph context used to create the op.
        theta: Value passed as the first operand of ``AffineGrid``.
        size: Value passed as the second operand of ``AffineGrid``.
        align_corners: Flag forwarded as the integer ``align_corners`` attribute.

    Returns:
        Whatever ``g.op`` returns for the created ``AffineGrid`` op.
    """
    align_flag = int(align_corners)
    return g.op("AffineGrid", theta, size, align_corners_i=align_flag)