mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			52 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			52 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  Optimizer.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2019/08/20.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include <MNN/expr/Optimizer.hpp>
 | |
| #include "MergeOptimizer.hpp"
 | |
| #include "core/Backend.hpp"
 | |
| namespace MNN {
 | |
| namespace Express {
 | |
| Optimizer::Parameters::Parameters(int n) {
 | |
|     MNN_ASSERT(n > 0);
 | |
|     mValue = new float[n];
 | |
|     mSize  = n;
 | |
| }
 | |
| Optimizer::Parameters::~Parameters() {
 | |
|     if (nullptr != mValue) {
 | |
|         delete[] mValue;
 | |
|     }
 | |
| }
 | |
| std::shared_ptr<Optimizer> Optimizer::create(Config config) {
 | |
|     const int numThread = config.numThread;
 | |
|     auto forwardType = config.forwardType;
 | |
|     if (forwardType != MNN_FORWARD_ALL) {
 | |
|         if (MNNGetExtraBackendCreator(forwardType) == nullptr) {
 | |
|             return nullptr;
 | |
|         }
 | |
|         return std::shared_ptr<Optimizer>(new MergeOptimizer(config.forwardType, numThread, nullptr));
 | |
|     }
 | |
| 
 | |
|     auto device = config.device;
 | |
|     if (CPU == device) {
 | |
|         return std::shared_ptr<Optimizer>(new MergeOptimizer(MNN_FORWARD_CPU, numThread, nullptr));
 | |
|     }
 | |
|     if (GPU == device) {
 | |
|         std::vector<MNNForwardType> types {MNN_FORWARD_METAL, MNN_FORWARD_OPENCL, MNN_FORWARD_VULKAN, MNN_FORWARD_OPENGL};
 | |
|         for (auto type : types) {
 | |
|             auto creator = MNNGetExtraBackendCreator(type);
 | |
|             if (nullptr != creator) {
 | |
|                 return std::shared_ptr<Optimizer>(new MergeOptimizer(type, numThread, nullptr));
 | |
|             }
 | |
|         }
 | |
|     }
 | |
|     return nullptr;
 | |
| }
 | |
| 
 | |
| } // namespace Express
 | |
| } // namespace MNN
 |