#include "caffe2/opt/mobile.h"

#include <string>

#include "caffe2/core/logging.h"
#include "caffe2/opt/converter.h"
#include "caffe2/opt/fusion.h"
#include "caffe2/opt/passes.h"

namespace caffe2 {
namespace opt {

using namespace nom;

// Walks the dataflow graph and switches every eligible Conv operator to the
// NNPACK engine. A Conv is eligible only when it satisfies NNPACK's
// constraints: NCHW layout, all strides == 1, a 2-D kernel, and (for
// efficiency) not a Kx1/1xK kernel. Unless `low_memory` is set, a
// "convolution_transform_strategy" = "PRECOMPUTE" argument is also added so
// NNPACK can pre-transform the weights (trades memory for speed).
void addNNPACK(repr::NNModule* nn, bool low_memory) {
  for (auto node : nn->dataFlow.getMutableNodes()) {
    // Skip blobs — only operator nodes are candidates.
    NOM_REQUIRE_OR_CONT(repr::nn::is<repr::NeuralNetOperator>(node));

    // Check if it is a convolution.
    auto nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
    NOM_REQUIRE_OR_CONT(isa<nom::repr::Conv>(nnOp));

    // Requires X, W, b for NNPACK.
    NOM_REQUIRE_OR_CONT(node->getInEdges().size() >= 3);

    std::string engine = "NNPACK";

    // Now do some specific checks to see if an NNPACK engine is correct.
    bool validTransformCandidate = true;
    auto conv = dyn_cast<nom::repr::Conv>(nnOp);
    NOM_REQUIRE_OR_CONT(conv->getLayout() == nom::repr::Conv::NNLayout::NCHW);

    // NNPACK only supports stride == 1.
    for (auto stride : conv->getStrides()) {
      if (stride != 1) {
        validTransformCandidate = false;
        break;
      }
    }
    NOM_REQUIRE_OR_CONT(validTransformCandidate);

    // NNPACK only supports 2DConv.
    const auto& kernelShape = conv->getKernelShape();
    NOM_REQUIRE_OR_CONT(kernelShape.size() == 2);

    // Kx1 and 1xK convs are inefficient in NNPACK.
    if (kernelShape[0] != kernelShape[1]) {
      NOM_REQUIRE_OR_CONT(kernelShape[0] != 1 && kernelShape[1] != 1);
    }

    // We're good to use our engine. The engine name lives on the underlying
    // Caffe2 OperatorDef, reached through the node's annotation.
    auto annotation = conv->getMutableAnnotation();
    NOM_REQUIRE_OR_CONT(annotation && isa<Caffe2Annotation>(annotation));
    auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
    op->set_engine(engine);

    if (!low_memory) {
      auto* precompute_argument = op->add_arg();
      precompute_argument->set_name("convolution_transform_strategy");
      precompute_argument->set_s("PRECOMPUTE");
    }
  }
}

namespace {

// Heuristic: is fusing a Relu into this NNPACK Conv likely to be fast?
// For the default ("AUTO"/empty) algorithm we require stride <= 1 and all
// kernel dims >= 2; otherwise only the fast transform-based algorithms
// (WINOGRAD / WINOGRAD_FP16 / FT8x8 / FT16x16) qualify.
inline bool isNNPACKConvReluEfficient(
    const std::string& algo,
    const repr::Conv& conv) {
  if (algo == "AUTO" || algo == "") {
    for (auto stride : conv.getStrides()) {
      if (stride > 1) {
        return false;
      }
    }
    for (auto kernel : conv.getKernelShape()) {
      if (kernel < 2) {
        return false;
      }
    }
  } else if (!(
                 algo == "WINOGRAD" || algo == "WINOGRAD_FP16" ||
                 algo == "FT8x8" || algo == "FT16x16")) {
    return false;
  }
  return true;
}

} // namespace

// Fuses Conv -> Relu pairs into a single NNPACK Conv carrying an
// "activation" = "Relu" argument. Only Convs already assigned the NNPACK
// engine and deemed efficient by isNNPACKConvReluEfficient are fused.
void fuseNNPACKConvRelu(repr::NNModule* nn) {
  auto should_fuse = [](const repr::Conv& conv) {
    const auto annotation = conv.getAnnotation();
    if (!annotation || !isa<Caffe2Annotation>(annotation)) {
      return false;
    }
    const auto& op = dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();

    // We only want to fuse for fast NNPACK convs.
    if (op.engine() != "NNPACK") {
      return false;
    }
    caffe2::string algo = "AUTO";
    for (const auto& arg : op.arg()) {
      if (arg.name() == "algo") {
        algo = arg.s();
      }
    }
    if (!isNNPACKConvReluEfficient(algo, conv)) {
      return false;
    }
    return true;
  };

  // After the graph-level fusion, record the activation on the OperatorDef so
  // the NNPACK runtime applies Relu inside the conv kernel.
  auto postprocess = [](repr::NNGraph::NodeRef conv_node) {
    auto conv = repr::nn::get<repr::Conv>(conv_node);
    auto annotation = conv->getMutableAnnotation();
    if (!annotation || !isa<Caffe2Annotation>(annotation)) {
      return;
    }
    auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
    auto* arg = op->add_arg();
    arg->set_name("activation");
    arg->set_s("Relu");
  };

  fuseActivation<repr::Conv, repr::Relu>(nn, should_fuse, postprocess);
}

REGISTER_OPT_PASS_FROM_FUNC(FuseNNPACKConvRelu, fuseNNPACKConvRelu);
REGISTER_OPT_PASS_FROM_FUNC(AddNNPACK, addNNPACK);

} // namespace opt
} // namespace caffe2