mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			62 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			62 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  TemplateMerge.hpp
 | |
| //  MNNConverter
 | |
| //
 | |
| //  Created by MNN on 2019/09/16.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
#pragma once

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <MNN/expr/Optimizer.hpp>
#include "Global.hpp"
#include "config.hpp"
 | |
| 
 | |
// MNN_THROW_CHECK(success, log)
//
// Evaluates `success`; if it is false, reports the failed expression text and
// the stringized `log` argument through MNN_ERROR. Despite its name the macro
// does NOT throw and does not abort: execution continues after the report.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to a
// single statement. This makes it safe in unbraced if/else bodies
// (`if (c) MNN_THROW_CHECK(x, msg); else ...`), which the previous bare-`if`
// expansion broke. Callers must terminate it with ';'.
#define MNN_THROW_CHECK(success, log)                                  \
    do {                                                               \
        if (!(success)) {                                              \
            MNN_ERROR("Check failed: %s ==> %s\n", #success, #log);    \
        }                                                              \
    } while (0)
 | |
| 
 | |
| namespace MNN {
 | |
| namespace Express {
 | |
| 
 | |
| enum PassPriority : int {
 | |
|    PASS_PRIORITY_FRONT = 0,
 | |
|    PASS_PRIORITY_HIGH = 1,
 | |
|    PASS_PRIORITY_MIDDLE = 2,
 | |
|    PASS_PRIORITY_LOW = 3,
 | |
|    PASS_PRIORITY_FINAL = 4,
 | |
| };
 | |
| 
 | |
| class TemplateMerge : public Optimizer {
 | |
| public:
 | |
|     virtual Cost onMeasure(const std::vector<VARP>& outputs,
 | |
|                            std::shared_ptr<Parameters> parameters = nullptr) override {
 | |
|         return Cost();
 | |
|     }
 | |
|     bool onExecute(const std::vector<VARP>& outputs,
 | |
|                    std::shared_ptr<Parameters> parameters = nullptr) override {
 | |
|         return onExecute(outputs, PASS_PRIORITY_HIGH);
 | |
|     }
 | |
|     bool onExecute(const std::vector<VARP>& outputs, PassPriority priority);
 | |
| 
 | |
|     static TemplateMerge& getInstance(const std::string& pass);
 | |
| 
 | |
|     void insertTemplate(std::string key, std::function<bool(EXPRP)> compare, std::function<bool(EXPRP)> transform,
 | |
|                         PassPriority priority = PASS_PRIORITY_HIGH);
 | |
|     void insertTemplateV2(std::string key, std::function<bool(EXPRP)> transform, PassPriority priority = PASS_PRIORITY_HIGH);
 | |
| 
 | |
| private:
 | |
|     TemplateMerge() {
 | |
|     }
 | |
|     std::vector<std::vector<std::string>> mPriorities;
 | |
|     std::map<std::string, std::function<bool(EXPRP)>> mTemplates;
 | |
| };
 | |
| class TemplateMergeRegister {
 | |
| public:
 | |
|     TemplateMergeRegister(const std::string& pass, std::string key, std::function<bool(EXPRP)> compare, std::function<bool(EXPRP)> transform,
 | |
|                           PassPriority priority) {
 | |
|         TemplateMerge::getInstance(pass).insertTemplate(key, compare, transform, priority);
 | |
|     }
 | |
| };
 | |
| } // namespace Express
 | |
| } // namespace MNN
 |