#ifndef CAFFE2_OPT_OPT_PASSS_H #define CAFFE2_OPT_OPT_PASSS_H #include "caffe2/core/common.h" #include "caffe2/core/workspace.h" #include "caffe2/proto/caffe2_pb.h" #include "nomnigraph/Representations/NeuralNet.h" using namespace nom::repr; namespace caffe2 { /* This file sets up the optimization pass registry. * * You'll want to either create a class that inherits from OptimizationPass * and implements run or use the REGISTER_OPT_PASS_FROM_FUNC(name, func) * to register a function that takes in an NNModule*. * * If you need access to the workspace in the optimization you'll need to * use a different registry and inherit from WorkspaceOptimizationPass. */ class TORCH_API OptimizationPass { public: OptimizationPass(NNModule* nn) : nn_(nn) {} virtual void run() = 0; virtual ~OptimizationPass() {} protected: NNModule* nn_; }; class TORCH_API WorkspaceOptimizationPass : public OptimizationPass { public: WorkspaceOptimizationPass(NNModule* nn, Workspace* ws) : OptimizationPass(nn), ws_(ws) {} virtual ~WorkspaceOptimizationPass() {} protected: Workspace* ws_; }; C10_DECLARE_REGISTRY( WorkspaceOptimizationPassRegistry, WorkspaceOptimizationPass, NNModule*, Workspace*); #define REGISTER_WS_OPT_PASS(clsname) \ C10_REGISTER_CLASS(WorkspaceOptimizationPassRegistry, clsname, clsname) #define REGISTER_WS_OPT_PASS_FROM_FUNC(passname, funcname) \ class passname : public WorkspaceOptimizationPass { \ public: \ using WorkspaceOptimizationPass::WorkspaceOptimizationPass; \ void run() override { \ funcname(nn_, ws_); \ } \ }; \ REGISTER_WS_OPT_PASS(passname); C10_DECLARE_REGISTRY(OptimizationPassRegistry, OptimizationPass, NNModule*); #define REGISTER_OPT_PASS(clsname) \ C10_REGISTER_CLASS(OptimizationPassRegistry, clsname, clsname) #define REGISTER_OPT_PASS_FROM_FUNC(passname, funcname) \ class passname : public OptimizationPass { \ public: \ using OptimizationPass::OptimizationPass; \ void run() override { \ funcname(nn_); \ } \ }; \ REGISTER_OPT_PASS(passname); } // namespace caffe2 #endif // CAFFE2_OPT_OPT_PASSS_H