#pragma once #include "caffe2/onnx/backend_rep.h" #include "caffe2/onnx/device.h" #include "caffe2/onnx/helper.h" #include "caffe2/proto/caffe2_pb.h" #include "onnx/onnx_pb.h" #include #include #include #include constexpr int kKnownOpsetVersion = 9; namespace caffe2 { namespace onnx { using ::ONNX_NAMESPACE::AttributeProto; using ::ONNX_NAMESPACE::GraphProto; using ::ONNX_NAMESPACE::ModelProto; using ::ONNX_NAMESPACE::NodeProto; using ::ONNX_NAMESPACE::TensorProto; using ::ONNX_NAMESPACE::ValueInfoProto; using ValueInfoMap = std::unordered_map; class TORCH_API ConversionContext { public: ConversionContext(const ValueInfoMap& value_infos, int opset_version) : value_infos_(value_infos), opset_version_(opset_version) {} const ValueInfoMap& value_infos() const { return value_infos_; } int opset_version() const { return opset_version_; } private: const ValueInfoMap& value_infos_; const int opset_version_; }; // \brief This struct holds the converted ops after the onnx->c2 conversion. // Notice that for RNN ops, it may create ops in init_net. Hence we have the // `init_ops` field. struct TORCH_API Caffe2Ops { ::google::protobuf::RepeatedPtrField init_ops; ::google::protobuf::RepeatedPtrField ops; ::google::protobuf::RepeatedPtrField interface_blobs; }; // A convenient class to query attributes of a NodeProto. Note that the // NodeProto can not be modified during the query of OnnxAttributes object class TORCH_API OnnxAttributes { public: OnnxAttributes(const NodeProto& node); bool HasAttribute(const std::string& key) const { return onnx_attrs_.count(key); } AttributeProto* AddRewrittenAttribute(const std::string& key) { auto tmp = rewritten_onnx_attrs_.emplace(key, AttributeProto()); auto& attr = tmp.first->second; attr.set_name(key); return &attr; } ::google::protobuf::RepeatedPtrField OnnxAttrToCaffe2Arg( std::function mapper) const; // Get attribute given attribute name, specialied on data type T. Note that // the return value is copied template T get(const std::string& key) const; template T get(const std::string& key, const T& default_value) const { if (onnx_attrs_.count(key)) { return get(key); } else { return default_value; } } const AttributeProto* remove(const std::string& key) { const AttributeProto* result = nullptr; auto iter = onnx_attrs_.find(key); if (iter != onnx_attrs_.end()) { result = iter->second; onnx_attrs_.erase(iter); } return result; } private: std::unordered_map onnx_attrs_; std::unordered_map rewritten_onnx_attrs_; }; template <> int64_t OnnxAttributes::get(const std::string& key) const; template <> float OnnxAttributes::get(const std::string& key) const; template <> ::google::protobuf::RepeatedPtrField OnnxAttributes::get( const std::string& key) const; template <> ::google::protobuf::RepeatedField<::google::protobuf::int64> OnnxAttributes::get(const std::string& key) const; template <> ::google::protobuf::RepeatedField OnnxAttributes::get( const std::string& key) const; template <> const TensorProto* OnnxAttributes::get(const std::string& key) const; // convenient class for onnx node struct TORCH_API OnnxNode { OnnxNode(const NodeProto& node_in) : node(node_in), attributes(node_in) {} const NodeProto& node; OnnxAttributes attributes; }; class TORCH_API Caffe2Backend { public: // Since we still have this Python-C++ hybrid flow, we will need to take the // DummyName generator from Python as a pointer. In this case, Python env owns // the DummyName object and we don't need to keep track of its life time in // C++. Therefore in this case, we use a null dtor to prevent C++ shared_ptr // from releasing the object Caffe2Backend(DummyName* dummy = nullptr) { if (dummy) { dummy_ = std::shared_ptr(dummy, [](DummyName*) {}); } else { dummy_ = std::make_shared(); } } Caffe2BackendRep* Prepare( const std::string& onnx_model_str, const std::string& device, const std::vector& extras); bool SupportOp(const std::string tyep) const; Caffe2Ops ConvertNode( const std::string& node_str, const ConversionContext& ctx); void BuildTensorFillingOp( caffe2::OperatorDef* c2_op, const TensorProto& onnx_tensor, const std::string& output_name = "", const std::string& shape_name = ""); private: using SpecialOpConverter = Caffe2Ops (Caffe2Backend::*)(OnnxNode*, const ConversionContext&); void OnnxToCaffe2( caffe2::NetDef* init_net, caffe2::NetDef* pred_net, const ModelProto& onnx_model, const std::string& device, int opset_version, bool include_initializers, const std::vector& extras); void CheckOpSchemaArguments( const caffe2::OpSchema& schema, const caffe2::OperatorDef& op); Caffe2Ops OnnxNodeToCaffe2Ops( const ModelProto& init_model, const ModelProto& pred_model, const ConversionContext& ctx, OnnxNode* onnx_node); std::unordered_set AllNamesInGraph(const GraphProto& graph); Caffe2Ops CommonOnnxNodeToCaffe2Ops( OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateArgMaxMin(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateCast(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateConstant(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateConstantOfShape( OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateConvPoolOpBase( OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreatePadPool(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateReshape(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateGather(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateGemm(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreatePad(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateConcat(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateLogSoftmax(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateSlice(OnnxNode* onnx_node, const ConversionContext& ctx); std::string PreprocessSliceIndexTensor( OnnxNode* onnx_node, Caffe2Ops& ret, std::string indices_tensor, std::string axes_tensor, std::string rank_tensor, std::string zero_tensor, std::string one_tensor, int default_value); Caffe2Ops CreateDynamicSlice( OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateSplit(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateReciprocal(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateRandomNormal( OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateNonZeroOp(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateMultinomialOp( OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateBatchNormalization( OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateMatMul(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateUpsample(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateDropout(OnnxNode* onnx_node, const ConversionContext& ctx); Caffe2Ops CreateLRN(OnnxNode* onnx_node, const ConversionContext& ctx); // LUT related getters const std::unordered_map& get_renamed_operators() const; const std::unordered_set& get_rnn_operators() const; const std::unordered_map& get_broken_operators() const; const std::unordered_map& get_renamed_attrs() const; const std:: unordered_map>& get_per_op_renamed_attrs() const; const std::unordered_map& get_special_operators() const; // Dummy name generator std::shared_ptr dummy_; }; } // namespace onnx } // namespace caffe2