mirror of https://github.com/alibaba/MNN.git
54 lines
1.7 KiB
C++
54 lines
1.7 KiB
C++
//
|
|
// TemplateMerge.hpp
|
|
// MNNConverter
|
|
//
|
|
// Created by MNN on 2019/09/16.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#include <MNN/expr/Optimizer.hpp>
|
|
|
|
namespace MNN {
|
|
namespace Express {
|
|
|
|
enum PassPriority : int {
|
|
PASS_PRIORITY_FRONT = 0,
|
|
PASS_PRIORITY_HIGH = 1,
|
|
PASS_PRIORITY_MIDDLE = 2,
|
|
PASS_PRIORITY_LOW = 3,
|
|
PASS_PRIORITY_FINAL = 4,
|
|
};
|
|
|
|
class TemplateMerge : public Optimizer {
|
|
public:
|
|
virtual Cost onMeasure(const std::vector<VARP>& outputs,
|
|
std::shared_ptr<Parameters> parameters = nullptr) override {
|
|
return Cost();
|
|
}
|
|
bool onExecute(const std::vector<VARP>& outputs,
|
|
std::shared_ptr<Parameters> parameters = nullptr) override {
|
|
return onExecute(outputs, PASS_PRIORITY_HIGH);
|
|
}
|
|
bool onExecute(const std::vector<VARP>& outputs, PassPriority priority);
|
|
|
|
static TemplateMerge& getInstance(const std::string& pass);
|
|
|
|
void insertTemplate(std::string key, std::function<bool(EXPRP)> compare, std::function<bool(EXPRP)> transform,
|
|
PassPriority priority = PASS_PRIORITY_HIGH);
|
|
|
|
private:
|
|
TemplateMerge() {
|
|
}
|
|
std::vector<std::vector<std::string>> mPriorities;
|
|
std::map<std::string, std::pair<std::function<bool(EXPRP)>, std::function<bool(EXPRP)>>> mTemplates;
|
|
};
|
|
class TemplateMergeRegister {
|
|
public:
|
|
TemplateMergeRegister(const std::string& pass, std::string key, std::function<bool(EXPRP)> compare, std::function<bool(EXPRP)> transform,
|
|
PassPriority priority) {
|
|
TemplateMerge::getInstance(pass).insertTemplate(key, compare, transform, priority);
|
|
}
|
|
};
|
|
} // namespace Express
|
|
} // namespace MNN
|