mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			52 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			52 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //
 | ||
|  | //  Optimizer.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2019/08/20.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
|  | #include <MNN/expr/Optimizer.hpp>
 | ||
|  | #include "MergeOptimizer.hpp"
 | ||
|  | #include "core/Backend.hpp"
 | ||
|  | namespace MNN { | ||
|  | namespace Express { | ||
|  | Optimizer::Parameters::Parameters(int n) { | ||
|  |     MNN_ASSERT(n > 0); | ||
|  |     mValue = new float[n]; | ||
|  |     mSize  = n; | ||
|  | } | ||
|  | Optimizer::Parameters::~Parameters() { | ||
|  |     if (nullptr != mValue) { | ||
|  |         delete[] mValue; | ||
|  |     } | ||
|  | } | ||
|  | std::shared_ptr<Optimizer> Optimizer::create(Config config) { | ||
|  |     const int numThread = config.numThread; | ||
|  |     auto forwardType = config.forwardType; | ||
|  |     if (forwardType != MNN_FORWARD_ALL) { | ||
|  |         if (MNNGetExtraBackendCreator(forwardType) == nullptr) { | ||
|  |             return nullptr; | ||
|  |         } | ||
|  |         return std::shared_ptr<Optimizer>(new MergeOptimizer(config.forwardType, numThread, nullptr)); | ||
|  |     } | ||
|  | 
 | ||
|  |     auto device = config.device; | ||
|  |     if (CPU == device) { | ||
|  |         return std::shared_ptr<Optimizer>(new MergeOptimizer(MNN_FORWARD_CPU, numThread, nullptr)); | ||
|  |     } | ||
|  |     if (GPU == device) { | ||
|  |         std::vector<MNNForwardType> types {MNN_FORWARD_METAL, MNN_FORWARD_OPENCL, MNN_FORWARD_VULKAN, MNN_FORWARD_OPENGL}; | ||
|  |         for (auto type : types) { | ||
|  |             auto creator = MNNGetExtraBackendCreator(type); | ||
|  |             if (nullptr != creator) { | ||
|  |                 return std::shared_ptr<Optimizer>(new MergeOptimizer(type, numThread, nullptr)); | ||
|  |             } | ||
|  |         } | ||
|  |     } | ||
|  |     return nullptr; | ||
|  | } | ||
|  | 
 | ||
|  | } // namespace Express
 | ||
|  | } // namespace MNN
 |