mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			52 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			52 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  ConvolutionIntFactory.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2018/08/06.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include "backend/cpu/compute/ConvolutionIntFactory.hpp"
 | |
| #include "backend/cpu/compute/ConvolutionGroup.hpp"
 | |
| #include "backend/cpu/compute/IdstConvolutionInt8.hpp"
 | |
| 
 | |
| namespace MNN {
 | |
| Execution *ConvolutionIntFactory::createUnit(const Tensor *input, const Tensor *output, const MNN::Op *op,
 | |
|                                              Backend *backend, const ConvolutionCommon::Int8Common *common, const float *bias,
 | |
|                                              size_t biasSize) {
 | |
|     auto conv2d = op->main_as_Convolution2D();
 | |
|     return new IdstConvolutionInt8(conv2d->common(), backend, common, bias, biasSize);
 | |
| }
 | |
| 
 | |
| Execution *ConvolutionIntFactory::create(const Tensor *input, const Tensor *output, const MNN::Op *op, Backend *backend,
 | |
|                                          const ConvolutionCommon::Int8Common *common) {
 | |
|     auto conv2d = op->main_as_Convolution2D();
 | |
|     int group            = conv2d->common()->group();
 | |
|     if (conv2d->common()->inputCount() != input->channel() && conv2d->common()->inputCount() > 0) {
 | |
|         group = input->channel()/ conv2d->common()->inputCount();
 | |
|     }
 | |
|     if (1 == group) {
 | |
|         return createUnit(input, output, op, backend, common, conv2d->bias()->data(), conv2d->bias()->size());
 | |
|     }
 | |
|     MNN_ASSERT(common->weight.get() != nullptr);
 | |
| 
 | |
|     // Split
 | |
|     std::vector<std::shared_ptr<Execution>> subConvolution;
 | |
|     auto groupOutputCount = conv2d->common()->outputCount() / group;
 | |
|     auto groupWeightSize  = common->weight.size() / group;
 | |
|     for (int i = 0; i < group; ++i) {
 | |
|         auto subCommon = std::make_shared<ConvolutionCommon::Int8Common>();
 | |
|         subCommon->alpha.reset(groupOutputCount);
 | |
|         ::memcpy(subCommon->alpha.get(), common->alpha.get() + groupOutputCount * i, groupOutputCount * sizeof(float));
 | |
|         subCommon->quan = common->quan;
 | |
|         subCommon->weight.reset(groupWeightSize);
 | |
|         ::memcpy(subCommon->weight.get(), common->weight.get() + groupWeightSize * i, groupWeightSize * sizeof(int8_t));
 | |
|         subConvolution.push_back(
 | |
|             std::shared_ptr<Execution>(createUnit(input, output, op, backend, subCommon.get(),
 | |
|                                                   conv2d->bias()->data() + groupOutputCount * i, groupOutputCount)));
 | |
|     }
 | |
|     return new ConvolutionGroup(backend, subConvolution);
 | |
| }
 | |
| 
 | |
| } // namespace MNN
 |