/** * Copyright (c) 2016-present, Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef CAFFE2_OPT_FUSION_H_ #define CAFFE2_OPT_FUSION_H_ #include "caffe2/core/workspace.h" #include "nomnigraph/Representations/NeuralNet.h" namespace caffe2 { namespace opt { using namespace nom; TORCH_API void fuseConvBN(repr::NNModule* nn, caffe2::Workspace* ws); // Generic activation fusion helper. // // \tparam OperationT The operator to be fused. // \tparam ActivationT The activation to be fused. // \param nn Neural network module to be modified in place // \param should_fuse Given a conv op, check whether we want to fuse it with // subsequent relu or not // \param postprocess Functor to postprocess the conv node, // attaching additional attributes if necessary template C10_EXPORT void fuseActivation( repr::NNModule* nn, std::function should_fuse, std::function postprocess) { for (auto node_pair : repr::nn::dataIterator(nn->dataFlow)) { repr::NNGraph::NodeRef conv_node; OperationT* conv; std::tie(conv, conv_node) = node_pair; // Check topological feasibility auto conv_outputs = repr::nn::getOutputs(conv_node); if (conv_outputs.size() != 1) { continue; } auto conv_output = conv_outputs.front(); auto consumers = repr::nn::getConsumers(conv_output); if (consumers.size() != 1) { continue; } if (!repr::nn::is(consumers.front())) { continue; } auto relu_node = consumers.front(); auto relu_outputs = repr::nn::getOutputs(relu_node); if (relu_outputs.size() != 1) { continue; } // Check feasibility with application specific logic if (!should_fuse(*conv)) { continue; } // Ready to fuse auto relu_output = relu_outputs.front(); auto output_tensor = repr::nn::get(relu_output); auto output_node = relu_output; auto input_tensor = repr::nn::get(repr::nn::getInputs(conv_node).front()); // Conv cannot be in-place if (output_tensor->getName() != input_tensor->getName()) { nn->dataFlow.replaceNode(conv_output, relu_output); nn->dataFlow.deleteNode(relu_node); nn->dataFlow.deleteNode(conv_output); } else { nn->dataFlow.replaceNode(relu_output, conv_output); output_tensor = repr::nn::get(conv_output); output_node = conv_output; nn->dataFlow.deleteNode(relu_node); nn->dataFlow.deleteNode(relu_output); } // We may have accidentally made the next op in-place // In future iterations of transformations this won't be an issue, // but current caffe2 predictor usage requires things like // external_input and output to be unchanged. bool rectify_inplace = false; for (auto& consumer : repr::nn::getConsumers(output_node)) { for (auto& consumer_output : repr::nn::getOutputs(consumer)) { auto co_name = repr::nn::get(consumer_output)->getName(); if (co_name == output_tensor->getName()) { rectify_inplace = true; } } } if (rectify_inplace) { auto new_output = nn->dataFlow.createNode( make_unique(output_tensor->getName() + "_fusion_fix")); nn->dataFlow.replaceNode(output_node, new_output); } // Application specific logic for postprocessing the conv node postprocess(conv_node); } } } // namespace opt } // namespace caffe2 #endif // CAFFE2_OPT_FUSION_H_