| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  PostTreatUtils.cpp
 | 
					
						
							|  |  |  | //  MNNConverter
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/01/31.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "PostTreatUtils.hpp"
 | 
					
						
							|  |  |  | #include <iostream>
 | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  | #include <set>
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | using namespace MNN; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  | static bool _OpNeedConvertContent(OpType type, int index) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     switch (type) { | 
					
						
							|  |  |  |         case OpType_Shape: | 
					
						
							|  |  |  |         case OpType_PriorBox: | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  |         case OpType_Const: | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             return false; | 
					
						
							|  |  |  |         case OpType_Interp: | 
					
						
							|  |  |  |         case OpType_Crop: | 
					
						
							|  |  |  |         case OpType_Reshape: | 
					
						
							|  |  |  |         case OpType_Resize: | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |         case OpType_Padding: | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             if (1 == index) { | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         default: | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | PostTreatUtils::PostTreatUtils(std::unique_ptr<MNN::NetT>& net) : mNet(std::move(net)) { | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const std::set<MNN::OpType> PostTreatUtils::NC4HW4_OPs = { | 
					
						
							|  |  |  |     MNN::OpType_Convolution, | 
					
						
							|  |  |  |     MNN::OpType_ConvolutionDepthwise, | 
					
						
							|  |  |  |     MNN::OpType_Pooling, | 
					
						
							|  |  |  |     MNN::OpType_ROIPooling, | 
					
						
							|  |  |  |     MNN::OpType_Resize, | 
					
						
							|  |  |  |     MNN::OpType_LSTM, | 
					
						
							|  |  |  |     MNN::OpType_SpatialProduct, | 
					
						
							|  |  |  |     MNN::OpType_Deconvolution, | 
					
						
							|  |  |  |     MNN::OpType_DeconvolutionDepthwise, | 
					
						
							|  |  |  |     MNN::OpType_Proposal, | 
					
						
							|  |  |  |     MNN::OpType_PriorBox, | 
					
						
							|  |  |  |     MNN::OpType_DetectionOutput, | 
					
						
							|  |  |  |     MNN::OpType_Eltwise, | 
					
						
							|  |  |  |     MNN::OpType_LRN, | 
					
						
							|  |  |  |     MNN::OpType_Interp, | 
					
						
							|  |  |  |     MNN::OpType_Crop, | 
					
						
							|  |  |  |     MNN::OpType_Scale, | 
					
						
							|  |  |  |     MNN::OpType_TfQuantizedConv2D, | 
					
						
							|  |  |  |     MNN::OpType_QuantizedDepthwiseConv2D, | 
					
						
							|  |  |  |     MNN::OpType_BatchToSpaceND, | 
					
						
							|  |  |  |     MNN::OpType_SpaceToBatchND, | 
					
						
							|  |  |  |     MNN::OpType_BatchNorm, | 
					
						
							|  |  |  |     MNN::OpType_Moments, | 
					
						
							| 
									
										
										
										
											2019-05-05 20:27:57 +08:00
										 |  |  |     MNN::OpType_QuantizedAvgPool, | 
					
						
							|  |  |  |     MNN::OpType_QuantizedAdd, | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  |     MNN::OpType_PReLU, | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const std::set<MNN::OpType> PostTreatUtils::COMPABILITY_OPs = { | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |     MNN::OpType_ReLU, MNN::OpType_ReLU6,         MNN::OpType_Concat,  MNN::OpType_Slice, MNN::OpType_Permute, | 
					
						
							|  |  |  |     MNN::OpType_Selu, MNN::OpType_ConvertTensor, MNN::OpType_Sigmoid, MNN::OpType_Cast,  MNN::OpType_Reshape, | 
					
						
							|  |  |  |     MNN::OpType_TanH, MNN::OpType_ArgMax,        MNN::OpType_Padding}; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | const std::vector<MNN::OpType> PostTreatUtils::DELETE_Ops = { | 
					
						
							|  |  |  |     MNN::OpType_Seq2Out, | 
					
						
							|  |  |  |     MNN::OpType_Dropout, | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::treatIm2Seq() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto& op = *iter; | 
					
						
							|  |  |  |         if (op->type != MNN::OpType_Im2Seq) { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto inputId    = op->inputIndexes[0]; | 
					
						
							|  |  |  |         auto outputId   = op->outputIndexes[0]; | 
					
						
							|  |  |  |         auto outputname = mNet->tensorName[outputId]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // New Reshape
 | 
					
						
							|  |  |  |         MNN::OpT* reshapeT = new MNN::OpT; | 
					
						
							|  |  |  |         reshapeT->name     = "____reshape____" + op->name; | 
					
						
							|  |  |  |         auto reshapeP      = new MNN::ReshapeT; | 
					
						
							|  |  |  |         reshapeP->dims.push_back(0);  // b
 | 
					
						
							|  |  |  |         reshapeP->dims.push_back(-1); // c
 | 
					
						
							|  |  |  |         reshapeP->dims.push_back(1);  // h
 | 
					
						
							|  |  |  |         reshapeP->dims.push_back(0);  // w
 | 
					
						
							|  |  |  |         reshapeT->main.type  = MNN::OpParameter_Reshape; | 
					
						
							|  |  |  |         reshapeT->type       = MNN::OpType_Reshape; | 
					
						
							|  |  |  |         reshapeT->main.value = reshapeP; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Net Tensor
 | 
					
						
							|  |  |  |         mNet->tensorName.push_back(reshapeT->name); | 
					
						
							|  |  |  |         int tempId = mNet->tensorName.size() - 1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         reshapeT->inputIndexes.push_back(inputId); | 
					
						
							|  |  |  |         reshapeT->outputIndexes.push_back(tempId); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         op->inputIndexes[0] = tempId; | 
					
						
							|  |  |  |         op->type            = MNN::OpType_Permute; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto convP     = new MNN::PermuteT; | 
					
						
							|  |  |  |         op->main.type  = MNN::OpParameter_Permute; | 
					
						
							|  |  |  |         op->main.value = convP; | 
					
						
							|  |  |  |         convP->dims.push_back(0); | 
					
						
							|  |  |  |         convP->dims.push_back(3); | 
					
						
							|  |  |  |         convP->dims.push_back(2); | 
					
						
							|  |  |  |         convP->dims.push_back(1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         iter = mNet->oplists.insert(iter, std::unique_ptr<MNN::OpT>(reshapeT)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::deleteUnusefulOp() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto& op          = *iter; | 
					
						
							|  |  |  |         bool shouldDelete = false; | 
					
						
							|  |  |  |         for (int i = 0; i < PostTreatUtils::DELETE_Ops.size(); ++i) { | 
					
						
							|  |  |  |             if (op->type == PostTreatUtils::DELETE_Ops[i]) { | 
					
						
							|  |  |  |                 shouldDelete = true; | 
					
						
							|  |  |  |                 break; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (!shouldDelete) { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Find the next op
 | 
					
						
							|  |  |  |         auto originInput  = op->inputIndexes[0]; | 
					
						
							|  |  |  |         auto originOutput = op->outputIndexes[0]; | 
					
						
							|  |  |  |         iter              = mNet->oplists.erase(iter); | 
					
						
							|  |  |  |         for (auto subIter = mNet->oplists.begin(); subIter != mNet->oplists.end(); subIter++) { | 
					
						
							|  |  |  |             auto& subOp = *subIter; | 
					
						
							|  |  |  |             for (int v = 0; v < subOp->inputIndexes.size(); ++v) { | 
					
						
							|  |  |  |                 if (subOp->inputIndexes[v] == originOutput) { | 
					
						
							|  |  |  |                     subOp->inputIndexes[v] = originInput; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | bool PostTreatUtils::_merge2Convolution(const MNN::OpT* inplaceOp, MNN::OpT* convolutionOp) { | 
					
						
							|  |  |  |     if (inplaceOp->type == MNN::OpType_ReLU && inplaceOp->main.AsRelu()->slope == 0.0f) { | 
					
						
							|  |  |  |         convolutionOp->main.AsConvolution2D()->common->relu = true; | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (inplaceOp->type == MNN::OpType_ReLU6) { | 
					
						
							|  |  |  |         convolutionOp->main.AsConvolution2D()->common->relu6 = true; | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-05-27 15:59:21 +08:00
										 |  |  |     const auto& convCommon = convolutionOp->main.AsConvolution2D()->common; | 
					
						
							|  |  |  |     if (convCommon->relu || convCommon->relu6) { | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     if (inplaceOp->type == MNN::OpType_BatchNorm || inplaceOp->type == MNN::OpType_Scale) { | 
					
						
							|  |  |  |         std::vector<float> alpha; | 
					
						
							|  |  |  |         std::vector<float> bias; | 
					
						
							|  |  |  |         if (inplaceOp->type == MNN::OpType_BatchNorm) { | 
					
						
							|  |  |  |             auto l = inplaceOp->main.AsBatchNorm(); | 
					
						
							|  |  |  |             alpha.resize(l->channels); | 
					
						
							|  |  |  |             bias.resize(l->channels); | 
					
						
							|  |  |  |             const float* slopePtr    = l->slopeData.data(); | 
					
						
							|  |  |  |             const float* meanDataPtr = l->meanData.data(); | 
					
						
							|  |  |  |             const float* varDataPtr  = l->varData.data(); | 
					
						
							|  |  |  |             const float* biasDataPtr = l->biasData.data(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (int i = 0; i < l->channels; i++) { | 
					
						
							|  |  |  |                 float sqrt_var = sqrt(varDataPtr[i]); | 
					
						
							|  |  |  |                 bias[i]        = biasDataPtr[i] - slopePtr[i] * meanDataPtr[i] / sqrt_var; | 
					
						
							|  |  |  |                 alpha[i]       = slopePtr[i] / sqrt_var; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (inplaceOp->type == MNN::OpType_Scale) { | 
					
						
							|  |  |  |             bias  = inplaceOp->main.AsScale()->biasData; | 
					
						
							|  |  |  |             alpha = inplaceOp->main.AsScale()->scaleData; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto conv2D     = convolutionOp->main.AsConvolution2D(); | 
					
						
							|  |  |  |         int outputCount = conv2D->common->outputCount; | 
					
						
							|  |  |  |         for (int i = 0; i < outputCount; ++i) { | 
					
						
							|  |  |  |             conv2D->bias[i] = conv2D->bias[i] * alpha[i] + bias[i]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (nullptr != conv2D->quanParameter.get()) { | 
					
						
							|  |  |  |             for (int i = 0; i < outputCount; ++i) { | 
					
						
							|  |  |  |                 conv2D->quanParameter->alpha[i] *= alpha[i]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             int weightPartSize = conv2D->weight.size() / outputCount; | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |             if (convolutionOp->type == OpType_Deconvolution) { | 
					
						
							|  |  |  |                 int inputCount = | 
					
						
							|  |  |  |                     conv2D->weight.size() / outputCount / conv2D->common->kernelX / conv2D->common->kernelY; | 
					
						
							|  |  |  |                 for (int i = 0; i < inputCount; ++i) { | 
					
						
							|  |  |  |                     auto dstPos = i * outputCount * conv2D->common->kernelY * conv2D->common->kernelX; | 
					
						
							|  |  |  |                     for (int j = 0; j < outputCount; ++j) { | 
					
						
							|  |  |  |                         auto dstPosJ = dstPos + j * conv2D->common->kernelY * conv2D->common->kernelX; | 
					
						
							|  |  |  |                         float a      = alpha[j]; | 
					
						
							|  |  |  |                         for (int k = 0; k < conv2D->common->kernelY * conv2D->common->kernelX; ++k) { | 
					
						
							|  |  |  |                             conv2D->weight[dstPosJ + k] *= a; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 for (int i = 0; i < outputCount; ++i) { | 
					
						
							|  |  |  |                     float a = alpha[i]; | 
					
						
							|  |  |  |                     for (int j = 0; j < weightPartSize; ++j) { | 
					
						
							|  |  |  |                         conv2D->weight[i * weightPartSize + j] *= a; | 
					
						
							|  |  |  |                     } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return false; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | bool PostTreatUtils::_isSingleInputOutput(const MNN::OpT* op) { | 
					
						
							|  |  |  |     if (op->inputIndexes.size() != 1 || op->outputIndexes.size() != 1) { | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::merge2Convolution() { | 
					
						
							|  |  |  |     // Merge Layer
 | 
					
						
							|  |  |  |     std::vector<MNN::OpT*> readyToDelete; | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { | 
					
						
							|  |  |  |         MNN::OpT& currentOp = *(iter->get()); | 
					
						
							|  |  |  |         if (currentOp.type != MNN::OpType_Convolution && currentOp.type != MNN::OpType_Deconvolution && | 
					
						
							|  |  |  |             currentOp.type != MNN::OpType_ConvolutionDepthwise) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         DCHECK(currentOp.outputIndexes.size() == 1) << "Conv output ERROR!"; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // merge Batchnorm/Relu/Relu6 to Convolution
 | 
					
						
							|  |  |  |         std::vector<MNN::OpT*> nextOp = this->_findOpByInputIndex(currentOp.outputIndexes[0]); | 
					
						
							|  |  |  |         while (1) { | 
					
						
							|  |  |  |             if (nextOp.size() != 1) { | 
					
						
							|  |  |  |                 break; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             const int nextOutputIndex = nextOp[0]->outputIndexes[0]; | 
					
						
							|  |  |  |             bool succ                 = _merge2Convolution(nextOp[0], ¤tOp); | 
					
						
							|  |  |  |             if (_isSingleInputOutput(nextOp[0]) && succ) { | 
					
						
							|  |  |  |                 // LOG(INFO) << "Merge " << nextOp[0]->name.c_str()<< " into convolution: " << currentOp.name.c_str();
 | 
					
						
							|  |  |  |                 currentOp.outputIndexes[0] = nextOp[0]->outputIndexes[0]; | 
					
						
							|  |  |  |                 readyToDelete.push_back(nextOp[0]); | 
					
						
							|  |  |  |                 nextOp = this->_findOpByInputIndex(nextOutputIndex); | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 break; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (auto op : readyToDelete) { | 
					
						
							|  |  |  |         _removeOpInNet(op); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::addTensorType() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { | 
					
						
							|  |  |  |         auto& op = *iter; | 
					
						
							|  |  |  |         if (op->type == MNN::OpType_StridedSlice) { | 
					
						
							|  |  |  |             auto parameter = op->main.AsStridedSliceParam(); | 
					
						
							|  |  |  |             auto dataType  = parameter->T; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 int index                = op->inputIndexes[0]; | 
					
						
							|  |  |  |                 auto describe            = std::unique_ptr<MNN::TensorDescribeT>(new MNN::TensorDescribeT); | 
					
						
							|  |  |  |                 describe->index          = index; | 
					
						
							|  |  |  |                 describe->blob           = std::unique_ptr<MNN::BlobT>(new MNN::BlobT); | 
					
						
							|  |  |  |                 describe->blob->dataType = dataType; | 
					
						
							|  |  |  |                 mNet->extraTensorDescribe.push_back(std::move(describe)); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 int index                = op->outputIndexes[0]; | 
					
						
							|  |  |  |                 auto describe            = std::unique_ptr<MNN::TensorDescribeT>(new MNN::TensorDescribeT); | 
					
						
							|  |  |  |                 describe->index          = index; | 
					
						
							|  |  |  |                 describe->blob           = std::unique_ptr<MNN::BlobT>(new MNN::BlobT); | 
					
						
							|  |  |  |                 describe->blob->dataType = dataType; | 
					
						
							|  |  |  |                 mNet->extraTensorDescribe.push_back(std::move(describe)); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (op->type == MNN::OpType_Const) { | 
					
						
							|  |  |  |             auto constP = op->main.AsBlob(); | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 int index                = op->outputIndexes[0]; | 
					
						
							|  |  |  |                 auto describe            = std::unique_ptr<MNN::TensorDescribeT>(new MNN::TensorDescribeT); | 
					
						
							|  |  |  |                 describe->index          = index; | 
					
						
							|  |  |  |                 describe->blob           = std::unique_ptr<MNN::BlobT>(new MNN::BlobT); | 
					
						
							|  |  |  |                 describe->blob->dataType = constP->dataType; | 
					
						
							|  |  |  |                 mNet->extraTensorDescribe.push_back(std::move(describe)); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | void PostTreatUtils::removeInplaceOp() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { | 
					
						
							|  |  |  |         auto& op = *iter; | 
					
						
							|  |  |  |         if (!_isSingleInputOutput(op.get())) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (op->inputIndexes[0] != op->outputIndexes[0]) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto originIndex = op->inputIndexes[0]; | 
					
						
							|  |  |  |         mNet->tensorName.push_back(op->name); | 
					
						
							|  |  |  |         int newIndex         = mNet->tensorName.size() - 1; | 
					
						
							|  |  |  |         op->outputIndexes[0] = newIndex; | 
					
						
							|  |  |  |         for (auto subIter = iter + 1; subIter != mNet->oplists.end(); subIter++) { | 
					
						
							|  |  |  |             auto& subOp = *subIter; | 
					
						
							|  |  |  |             for (int i = 0; i < subOp->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |                 if (subOp->inputIndexes[i] == originIndex) { | 
					
						
							|  |  |  |                     subOp->inputIndexes[i] = newIndex; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for (int i = 0; i < subOp->outputIndexes.size(); ++i) { | 
					
						
							|  |  |  |                 if (subOp->outputIndexes[i] == originIndex) { | 
					
						
							|  |  |  |                     subOp->outputIndexes[i] = newIndex; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         mNet->tensorNumber = mNet->tensorName.size(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  | void PostTreatUtils::turnOnnxPadToTensorflow() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto op = iter->get(); | 
					
						
							|  |  |  |         if (OpType_Padding == op->type && op->main.type == OpParameter_Blob) { | 
					
						
							|  |  |  |             std::unique_ptr<OpT> paddingConst(new OpT); | 
					
						
							|  |  |  |             paddingConst->type          = OpType_Const; | 
					
						
							|  |  |  |             paddingConst->main.type     = OpParameter_Blob; | 
					
						
							|  |  |  |             paddingConst->main.value    = new BlobT(*op->main.AsBlob()); | 
					
						
							|  |  |  |             paddingConst->name          = op->name + "padding"; | 
					
						
							|  |  |  |             paddingConst->outputIndexes = {(int)mNet->tensorName.size()}; | 
					
						
							|  |  |  |             mNet->tensorName.emplace_back(paddingConst->name); | 
					
						
							|  |  |  |             op->inputIndexes = {op->inputIndexes[0], paddingConst->outputIndexes[0]}; | 
					
						
							|  |  |  |             op->main.Reset(); | 
					
						
							|  |  |  |             iter = mNet->oplists.insert(iter, std::move(paddingConst)); | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         iter++; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | void PostTreatUtils::reIndexTensor() { | 
					
						
							|  |  |  |     std::map<int, int> usefulTensorIndexMap; | 
					
						
							|  |  |  |     std::vector<std::string> usefulTensorName; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     std::vector<bool> tensorValid(mNet->tensorName.size(), false); | 
					
						
							|  |  |  |     for (auto& op : mNet->oplists) { | 
					
						
							|  |  |  |         for (auto index : op->inputIndexes) { | 
					
						
							|  |  |  |             tensorValid[index] = true; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto index : op->outputIndexes) { | 
					
						
							|  |  |  |             tensorValid[index] = true; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int i = 0; i < tensorValid.size(); ++i) { | 
					
						
							|  |  |  |         if (tensorValid[i]) { | 
					
						
							|  |  |  |             usefulTensorIndexMap.insert(std::make_pair(i, usefulTensorName.size())); | 
					
						
							|  |  |  |             usefulTensorName.push_back(mNet->tensorName[i]); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Re index
 | 
					
						
							|  |  |  |     for (auto& op : mNet->oplists) { | 
					
						
							|  |  |  |         for (int i = 0; i < op->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |             auto iter = usefulTensorIndexMap.find(op->inputIndexes[i]); | 
					
						
							|  |  |  |             DCHECK(iter != usefulTensorIndexMap.end()) << "ERROR"; | 
					
						
							|  |  |  |             op->inputIndexes[i] = iter->second; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = 0; i < op->outputIndexes.size(); ++i) { | 
					
						
							|  |  |  |             auto iter = usefulTensorIndexMap.find(op->outputIndexes[i]); | 
					
						
							|  |  |  |             DCHECK(iter != usefulTensorIndexMap.end()) << "ERROR"; | 
					
						
							|  |  |  |             op->outputIndexes[i] = iter->second; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mNet->tensorName = usefulTensorName; | 
					
						
							|  |  |  |     for (auto iter = mNet->extraTensorDescribe.begin(); iter != mNet->extraTensorDescribe.end();) { | 
					
						
							|  |  |  |         auto index = (*iter)->index; | 
					
						
							|  |  |  |         if (usefulTensorIndexMap.find(index) == usefulTensorIndexMap.end()) { | 
					
						
							|  |  |  |             iter = mNet->extraTensorDescribe.erase(iter); | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         (*iter)->index = usefulTensorIndexMap.find(index)->second; | 
					
						
							|  |  |  |         iter++; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::addConverterForTensorFlowModel() { | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |     if (mNet->sourceType == MNN::NetSource_CAFFE) { | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |     auto originTensorType = MNN::MNN_DATA_FORMAT_NHWC; | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |     if (mNet->sourceType == MNN::NetSource_ONNX) { | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |         originTensorType = MNN::MNN_DATA_FORMAT_NCHW; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  |     // set the layout of every tensor
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     // Don't support inplace
 | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |     std::map<int, MNN::MNN_DATA_FORMAT> tensorType; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     std::map<std::string, MNN::MNN_DATA_FORMAT> opType; | 
					
						
							|  |  |  |     for (auto& iter : mNet->oplists) { | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  |         // set output tensor layout of this op according to context
 | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |         auto type = originTensorType; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         if (iter->type == MNN::OpType_ConvertTensor) { | 
					
						
							|  |  |  |             type = iter->main.AsTensorConvertInfo()->dest; | 
					
						
							|  |  |  |         } else if (PostTreatUtils::NC4HW4_OPs.find(iter->type) != PostTreatUtils::NC4HW4_OPs.end()) { | 
					
						
							|  |  |  |             type = MNN::MNN_DATA_FORMAT_NC4HW4; | 
					
						
							|  |  |  |         } else if (PostTreatUtils::COMPABILITY_OPs.find(iter->type) != PostTreatUtils::COMPABILITY_OPs.end()) { | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |             int nc4hw4TypeNumber = 0; // NC4HW4 number
 | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |             int originTypeNumber = 0; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |             for (int i = 0; i < iter->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |                 auto index = iter->inputIndexes[i]; | 
					
						
							|  |  |  |                 if (_OpNeedConvertContent(iter->type, i)) { | 
					
						
							|  |  |  |                     if (tensorType[index] == MNN::MNN_DATA_FORMAT_NC4HW4) { | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |                         nc4hw4TypeNumber++; | 
					
						
							|  |  |  |                     } else if (tensorType[index] == originTensorType) { | 
					
						
							|  |  |  |                         originTypeNumber++; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |                     } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |             if (nc4hw4TypeNumber > originTypeNumber) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 type = MNN::MNN_DATA_FORMAT_NC4HW4; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (iter->type == MNN::OpType_Reshape) { | 
					
						
							|  |  |  |                 if (iter->main.AsReshape()->dims.size() != 4) { | 
					
						
							| 
									
										
										
										
											2019-08-07 16:44:09 +08:00
										 |  |  |                     type = originTensorType; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         for (auto index : iter->outputIndexes) { | 
					
						
							|  |  |  |             tensorType[index] = type; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         opType.insert(std::make_pair(iter->name, type)); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto op = iter->get(); | 
					
						
							|  |  |  |         // Insert Pretreat Op if needed
 | 
					
						
							|  |  |  |         if (opType.find(op->name)->second == MNN::MNN_DATA_FORMAT_NHWC) { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (op->type == OpType_Padding) { | 
					
						
							|  |  |  |             // Add Gather op for padding, turn nhwc -> nchw
 | 
					
						
							|  |  |  |             std::unique_ptr<OpT> gatherIndex(new OpT); | 
					
						
							|  |  |  |             gatherIndex->outputIndexes = {(int)mNet->tensorName.size()}; | 
					
						
							|  |  |  |             gatherIndex->type          = OpType_Const; | 
					
						
							|  |  |  |             gatherIndex->name          = op->name + "_Gather_Index"; | 
					
						
							|  |  |  |             mNet->tensorName.emplace_back(gatherIndex->name); | 
					
						
							|  |  |  |             gatherIndex->main.type                 = OpParameter_Blob; | 
					
						
							|  |  |  |             gatherIndex->main.value                = new BlobT; | 
					
						
							|  |  |  |             gatherIndex->main.AsBlob()->dataType   = DataType_DT_INT32; | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |             gatherIndex->main.AsBlob()->dataFormat = originTensorType; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |             gatherIndex->main.AsBlob()->int32s     = {0, 3, 1, 2}; | 
					
						
							|  |  |  |             gatherIndex->main.AsBlob()->dims       = {4}; | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |             opType.insert(std::make_pair(gatherIndex->name, originTensorType)); | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             std::unique_ptr<OpT> gather(new OpT); | 
					
						
							|  |  |  |             gather->outputIndexes = {(int)mNet->tensorName.size()}; | 
					
						
							|  |  |  |             gather->inputIndexes  = {op->inputIndexes[1], gatherIndex->outputIndexes[0]}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             gather->type = OpType_GatherV2; | 
					
						
							|  |  |  |             gather->name = op->name + "_Gather"; | 
					
						
							|  |  |  |             mNet->tensorName.emplace_back(gather->name); | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |             opType.insert(std::make_pair(gather->name, originTensorType)); | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             op->inputIndexes[1]                       = gather->outputIndexes[0]; | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |             tensorType[gather->outputIndexes[0]]      = originTensorType; | 
					
						
							|  |  |  |             tensorType[gatherIndex->outputIndexes[0]] = originTensorType; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             iter = mNet->oplists.insert(iter, std::move(gather)); | 
					
						
							|  |  |  |             iter = mNet->oplists.insert(iter, std::move(gatherIndex)); | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto& op         = *iter; | 
					
						
							|  |  |  |         auto currentType = opType.find(op->name)->second; | 
					
						
							|  |  |  |         std::vector<MNN::OpT*> transformOps; | 
					
						
							|  |  |  |         auto currentName         = op->name; | 
					
						
							|  |  |  |         const bool useAutoFormat = NC4HW4_OPs.find(op->type) != NC4HW4_OPs.end(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (int i = 0; i < op->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |             auto inputIndex = op->inputIndexes[i]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             MNN::OpT* inputOp = this->_findOpByOutputIndex(inputIndex); | 
					
						
							|  |  |  |             if (inputOp && inputOp->type == MNN::OpType_Input && useAutoFormat) { | 
					
						
							|  |  |  |                 auto inputOpParam      = inputOp->main.AsInput(); | 
					
						
							|  |  |  |                 inputOpParam->dformat  = MNN::MNN_DATA_FORMAT_NC4HW4; | 
					
						
							|  |  |  |                 tensorType[inputIndex] = MNN::MNN_DATA_FORMAT_NC4HW4; | 
					
						
							|  |  |  |                 opType[inputOp->name]  = MNN::MNN_DATA_FORMAT_NC4HW4; | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto type = tensorType[inputIndex]; | 
					
						
							|  |  |  |             if (type == currentType) { | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |             if (!_OpNeedConvertContent(op->type, i)) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // Insert Transform op
 | 
					
						
							|  |  |  |             MNN::OpT* transformOp = new MNN::OpT; | 
					
						
							|  |  |  |             transformOps.push_back(transformOp); | 
					
						
							|  |  |  |             MNN::TensorConvertInfoT* tc = new MNN::TensorConvertInfoT; | 
					
						
							|  |  |  |             tc->source                  = type; | 
					
						
							|  |  |  |             tc->dest                    = currentType; | 
					
						
							|  |  |  |             transformOp->main.type      = MNN::OpParameter_TensorConvertInfo; | 
					
						
							|  |  |  |             transformOp->main.value     = tc; | 
					
						
							|  |  |  |             transformOp->name           = mNet->tensorName[inputIndex] + "___tr4" + op->name; | 
					
						
							|  |  |  |             // printf("Insert convert for %s, %s 's input %d\n", net->tensorName[inputIndex].c_str(), op->name.c_str(),
 | 
					
						
							|  |  |  |             // i);
 | 
					
						
							|  |  |  |             transformOp->inputIndexes.push_back(inputIndex); | 
					
						
							|  |  |  |             transformOp->outputIndexes.push_back(mNet->tensorName.size()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             mNet->tensorName.push_back(transformOp->name); | 
					
						
							|  |  |  |             op->inputIndexes[i] = transformOp->outputIndexes[0]; | 
					
						
							|  |  |  |             transformOp->type   = MNN::OpType_ConvertTensor; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = transformOps.size() - 1; i >= 0; i--) { | 
					
						
							|  |  |  |             iter = mNet->oplists.insert(iter, std::unique_ptr<MNN::OpT>(transformOps[i])); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (; (*iter)->name != currentName; iter++) { | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         iter++; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |     if (mNet->sourceType == MNN::NetSource_ONNX) { | 
					
						
							| 
									
										
										
										
											2019-07-19 17:09:09 +08:00
										 |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     // Reset axis map
 | 
					
						
							|  |  |  |     const int axisMap[4] = {0, 2, 3, 1}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (auto& op : mNet->oplists) { | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |         if (opType.find(op->name) == opType.end()) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         if (opType.find(op->name)->second == MNN::MNN_DATA_FORMAT_NHWC) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (MNN::OpType_Input == op->type) { | 
					
						
							|  |  |  |             auto input = op->main.AsInput(); | 
					
						
							|  |  |  |             if (4 == input->dims.size()) { | 
					
						
							|  |  |  |                 int h          = input->dims[1]; | 
					
						
							|  |  |  |                 int c          = input->dims[3]; | 
					
						
							|  |  |  |                 int w          = input->dims[2]; | 
					
						
							|  |  |  |                 input->dims[1] = c; | 
					
						
							|  |  |  |                 input->dims[2] = h; | 
					
						
							|  |  |  |                 input->dims[3] = w; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (MNN::OpType_Concat == op->type) { | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |             auto axis       = op->main.AsAxis(); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |             auto concatAxis = axis->axis; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |             if (concatAxis < 0) { | 
					
						
							|  |  |  |                 concatAxis = 4 + concatAxis; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |             DCHECK(concatAxis >= 0 && concatAxis <= 3) << "Concat axis ERROR!"; | 
					
						
							|  |  |  |             axis->axis = axisMap[concatAxis]; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         if (MNN::OpType_Permute == op->type) { | 
					
						
							|  |  |  |             auto permuteT = op->main.AsPermute(); | 
					
						
							|  |  |  |             for (int i = 0; i < permuteT->dims.size(); ++i) { | 
					
						
							|  |  |  |                 DCHECK(permuteT->dims[i] >= 0 && permuteT->dims[i] <= 3) << "Dim Error ==> " << op->name; | 
					
						
							|  |  |  |                 permuteT->dims[i] = axisMap[permuteT->dims[i]]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (MNN::OpType_Slice == op->type) { | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |             auto slice      = op->main.AsSlice(); | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |             auto concatAxis = slice->axis; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |             if (concatAxis < 0) { | 
					
						
							|  |  |  |                 concatAxis = 4 + concatAxis; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-06-24 11:32:41 +08:00
										 |  |  |             DCHECK(concatAxis >= 0 && concatAxis <= 3) << "Slice axis ERROR!"; | 
					
						
							|  |  |  |             slice->axis = axisMap[concatAxis]; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         if (MNN::OpType_Reshape == op->type) { | 
					
						
							|  |  |  |             auto reshape   = op->main.AsReshape(); | 
					
						
							|  |  |  |             auto originDim = reshape->dims; | 
					
						
							|  |  |  |             for (int i = 0; i < reshape->dims.size(); ++i) { | 
					
						
							|  |  |  |                 CHECK(i >= 0 && i <= 3) << "Error"; | 
					
						
							|  |  |  |                 reshape->dims[axisMap[i]] = originDim[i]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |     std::set<int> tensorTypeSet; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     for (auto& iter : mNet->extraTensorDescribe) { | 
					
						
							|  |  |  |         auto index             = iter->index; | 
					
						
							|  |  |  |         iter->blob->dataFormat = tensorType[index]; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |         tensorTypeSet.insert(index); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |     for (auto iter : tensorType) { | 
					
						
							|  |  |  |         if (tensorTypeSet.find(iter.first) != tensorTypeSet.end()) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto describe              = new MNN::TensorDescribeT; | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |         describe->index            = iter.first; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         describe->blob             = std::unique_ptr<MNN::BlobT>(new MNN::BlobT); | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  |         describe->blob->dataFormat = iter.second; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         describe->blob->dataType   = MNN::DataType_DT_FLOAT; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         mNet->extraTensorDescribe.push_back(std::unique_ptr<MNN::TensorDescribeT>(describe)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | MNN::OpT* ensureOpInNet(std::unique_ptr<MNN::NetT>& net, MNN::OpT* op) { | 
					
						
							|  |  |  |     for (auto& _op : net->oplists) { | 
					
						
							|  |  |  |         if (_op.get() == op) { | 
					
						
							|  |  |  |             return op; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | MNN::OpT* PostTreatUtils::_findOpByOutputIndex(int outputIndex) { | 
					
						
							|  |  |  |     for (auto& op : mNet->oplists) { | 
					
						
							|  |  |  |         if (inVector(op->outputIndexes, outputIndex)) { | 
					
						
							|  |  |  |             return op.get(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | std::vector<MNN::OpT*> PostTreatUtils::_findOpByInputIndex(int inputIndex) { | 
					
						
							|  |  |  |     std::vector<MNN::OpT*> ops; | 
					
						
							|  |  |  |     for (auto& op : mNet->oplists) { | 
					
						
							|  |  |  |         if (inVector(op->inputIndexes, inputIndex)) { | 
					
						
							|  |  |  |             ops.push_back(op.get()); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // check whether the next op is in_place op
 | 
					
						
							|  |  |  |     const int opsSize = ops.size(); | 
					
						
							|  |  |  |     if (opsSize > 1) { | 
					
						
							|  |  |  |         auto realNextOp = ops[0]; | 
					
						
							|  |  |  |         if (inVector(realNextOp->outputIndexes, inputIndex)) { | 
					
						
							|  |  |  |             ops.clear(); | 
					
						
							|  |  |  |             ops.push_back(realNextOp); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return ops; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | int PostTreatUtils::_getOpDecestorCount(MNN::OpT* op) { | 
					
						
							|  |  |  |     int decestorCount = 0; | 
					
						
							|  |  |  |     for (auto& otherOp : mNet->oplists) { | 
					
						
							|  |  |  |         if (otherOp.get() != op) { | 
					
						
							|  |  |  |             for (auto inputIndex : otherOp->inputIndexes) { | 
					
						
							|  |  |  |                 if (inVector(op->outputIndexes, inputIndex)) { | 
					
						
							|  |  |  |                     decestorCount++; | 
					
						
							|  |  |  |                     break; // one decestor just count one.
 | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return decestorCount; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::_removeOpInNet(MNN::OpT* op) { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { | 
					
						
							|  |  |  |         if (iter->get() == op) { | 
					
						
							|  |  |  |             // LOG(INFO) << "remove op: " << op->name;
 | 
					
						
							|  |  |  |             mNet->oplists.erase(iter); | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::_removeOnlyOneDecestorOps(MNN::OpT* op) { | 
					
						
							|  |  |  |     std::vector<MNN::OpT*> opsToBeChecked; | 
					
						
							|  |  |  |     opsToBeChecked.push_back(op); | 
					
						
							|  |  |  |     while (!opsToBeChecked.empty()) { | 
					
						
							|  |  |  |         bool hasRemoved = false; | 
					
						
							|  |  |  |         std::vector<MNN::OpT*> addedToBeChecked; | 
					
						
							|  |  |  |         for (auto iter = opsToBeChecked.begin(); iter != opsToBeChecked.end();) { | 
					
						
							|  |  |  |             MNN::OpT* op = *iter; | 
					
						
							|  |  |  |             if (!ensureOpInNet(mNet, op)) { | 
					
						
							|  |  |  |                 hasRemoved = true; | 
					
						
							|  |  |  |                 iter       = opsToBeChecked.erase(iter); | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (this->_getOpDecestorCount(op) == 0) { | 
					
						
							|  |  |  |                 for (int inputIndex : op->inputIndexes) { | 
					
						
							|  |  |  |                     addedToBeChecked.push_back(this->_findOpByOutputIndex(inputIndex)); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 hasRemoved = true; | 
					
						
							|  |  |  |                 this->_removeOpInNet(op); | 
					
						
							|  |  |  |                 iter = opsToBeChecked.erase(iter); | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (!hasRemoved) | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         opsToBeChecked.insert(opsToBeChecked.end(), addedToBeChecked.begin(), addedToBeChecked.end()); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::removeDeconvolutionShapeInput() { | 
					
						
							|  |  |  |     std::set<MNN::OpT*> shapeOps; | 
					
						
							|  |  |  |     for (auto& op : mNet->oplists) { | 
					
						
							| 
									
										
										
										
											2019-08-12 10:27:58 +08:00
										 |  |  |         if (op->type == MNN::OpType_Deconvolution || op->type == MNN::OpType_DeconvolutionDepthwise) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             if (op->inputIndexes.size() == 1) { | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             int firstInputIndex = op->inputIndexes[0]; | 
					
						
							|  |  |  |             op->inputIndexes.erase(op->inputIndexes.begin()); | 
					
						
							|  |  |  |             MNN::OpT* shapeOp = this->_findOpByOutputIndex(firstInputIndex); | 
					
						
							|  |  |  |             if (shapeOp) { | 
					
						
							|  |  |  |                 shapeOps.insert(shapeOp); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (auto& op : shapeOps) { | 
					
						
							|  |  |  |         this->_removeOnlyOneDecestorOps(op); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::turnInnerProduct2Convolution() { | 
					
						
							|  |  |  |     std::vector<MNN::OpT*> readyToDelete; | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto& op = *iter; | 
					
						
							|  |  |  |         if (op->type != MNN::OpType_InnerProduct) { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         // ONNX Gemm will be mapped to InnerProduct, check whether is Flatten before Gemm
 | 
					
						
							|  |  |  |         // then delete Flatten(mapped to Reshape, and this Reshape will reshape tensor to be
 | 
					
						
							|  |  |  |         // two dimensions, such as [M,K], which is the input of Gemm)
 | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  |         auto inputId       = op->inputIndexes[0]; | 
					
						
							|  |  |  |         auto beforeGemm    = _findOpByOutputIndex(inputId); | 
					
						
							|  |  |  |         auto refBeforeGemm = _findOpByInputIndex(beforeGemm->outputIndexes[0]); | 
					
						
							|  |  |  |         if (beforeGemm->type == MNN::OpType_Reshape && _isSingleInputOutput(beforeGemm) && refBeforeGemm.size() == 1) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             // change the input index
 | 
					
						
							|  |  |  |             const int beforeGemmInputId = beforeGemm->inputIndexes[0]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             op->inputIndexes[0] = beforeGemmInputId; | 
					
						
							|  |  |  |             inputId             = beforeGemmInputId; | 
					
						
							|  |  |  |             readyToDelete.push_back(beforeGemm); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto paramInner = op->main.AsInnerProduct(); | 
					
						
							|  |  |  |         const auto axis = paramInner->axis; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         std::vector<MNN::OpT*> newOpPrevious; | 
					
						
							|  |  |  |         std::vector<MNN::OpT*> newOpPost; | 
					
						
							|  |  |  |         // New Reshape
 | 
					
						
							|  |  |  |         MNN::OpT* reshapeT = new MNN::OpT; | 
					
						
							|  |  |  |         newOpPrevious.push_back(reshapeT); | 
					
						
							|  |  |  |         reshapeT->name = "____reshape____" + op->name; | 
					
						
							|  |  |  |         auto reshapeP  = new MNN::ReshapeT; | 
					
						
							|  |  |  |         reshapeP->dims.resize(4); | 
					
						
							|  |  |  |         for (int i = 0; i < axis; ++i) { | 
					
						
							|  |  |  |             reshapeP->dims[i] = 0; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         reshapeP->dims[axis] = -1; | 
					
						
							|  |  |  |         for (int i = axis + 1; i < 4; ++i) { | 
					
						
							|  |  |  |             reshapeP->dims[i] = 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (mNet->sourceType == MNN::NetSource_TENSORFLOW) { | 
					
						
							|  |  |  |             reshapeP->dims[3] = -1; | 
					
						
							|  |  |  |             reshapeP->dims[1] = 1; | 
					
						
							|  |  |  |             reshapeP->dims[2] = 1; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         reshapeT->main.type  = MNN::OpParameter_Reshape; | 
					
						
							|  |  |  |         reshapeT->type       = MNN::OpType_Reshape; | 
					
						
							|  |  |  |         reshapeT->main.value = reshapeP; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Net Tensor
 | 
					
						
							|  |  |  |         mNet->tensorName.push_back(reshapeT->name); | 
					
						
							|  |  |  |         int tempId = mNet->tensorName.size() - 1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         reshapeT->inputIndexes.push_back(inputId); | 
					
						
							|  |  |  |         reshapeT->outputIndexes.push_back(tempId); | 
					
						
							|  |  |  |         auto opName      = op->name; | 
					
						
							|  |  |  |         bool needPermute = 1 != axis && mNet->sourceType == MNN::NetSource_CAFFE; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (needPermute) { | 
					
						
							|  |  |  |             // Add Permute
 | 
					
						
							|  |  |  |             auto permuteBefore       = new MNN::OpT; | 
					
						
							|  |  |  |             permuteBefore->type      = MNN::OpType_Permute; | 
					
						
							|  |  |  |             permuteBefore->main.type = MNN::OpParameter_Permute; | 
					
						
							|  |  |  |             auto permuteT            = new MNN::PermuteT; | 
					
						
							|  |  |  |             permuteBefore->name      = "___permute1__" + reshapeT->name; | 
					
						
							|  |  |  |             permuteT->dims.resize(4); | 
					
						
							|  |  |  |             for (int i = 0; i < 4; ++i) { | 
					
						
							|  |  |  |                 permuteT->dims[i] = i; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             permuteT->dims[1]         = axis; | 
					
						
							|  |  |  |             permuteT->dims[axis]      = 3; | 
					
						
							|  |  |  |             permuteT->dims[3]         = 1; | 
					
						
							|  |  |  |             permuteBefore->main.value = permuteT; | 
					
						
							|  |  |  |             permuteBefore->inputIndexes.push_back(tempId); | 
					
						
							|  |  |  |             mNet->tensorName.push_back(permuteBefore->name); | 
					
						
							|  |  |  |             tempId = mNet->tensorName.size() - 1; | 
					
						
							|  |  |  |             permuteBefore->outputIndexes.push_back(tempId); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             newOpPrevious.push_back(permuteBefore); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         op->inputIndexes[0] = tempId; | 
					
						
							|  |  |  |         op->type            = MNN::OpType_Convolution; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto convP                 = new MNN::Convolution2DT; | 
					
						
							|  |  |  |         auto originInner           = op->main.AsInnerProduct(); | 
					
						
							|  |  |  |         convP->common              = std::unique_ptr<MNN::Convolution2DCommonT>(new MNN::Convolution2DCommonT); | 
					
						
							|  |  |  |         convP->common->kernelX     = 1; | 
					
						
							|  |  |  |         convP->common->kernelY     = 1; | 
					
						
							|  |  |  |         convP->common->dilateX     = 1; | 
					
						
							|  |  |  |         convP->common->dilateY     = 1; | 
					
						
							|  |  |  |         convP->common->strideX     = 1; | 
					
						
							|  |  |  |         convP->common->strideY     = 1; | 
					
						
							|  |  |  |         convP->common->group       = 1; | 
					
						
							|  |  |  |         convP->common->outputCount = originInner->outputCount; | 
					
						
							|  |  |  |         convP->common->padX        = 0; | 
					
						
							|  |  |  |         convP->common->padY        = 0; | 
					
						
							|  |  |  |         convP->common->padMode     = MNN::PadMode_CAFFE; | 
					
						
							|  |  |  |         convP->bias                = originInner->bias; | 
					
						
							|  |  |  |         convP->weight              = originInner->weight; | 
					
						
							|  |  |  |         convP->quanParameter       = std::move(originInner->quanParameter); | 
					
						
							|  |  |  |         if (convP->quanParameter.get() != nullptr) { | 
					
						
							|  |  |  |             convP->quanParameter->has_scaleInt = false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         op->main.Reset(); | 
					
						
							|  |  |  |         op->main.type  = MNN::OpParameter_Convolution2D; | 
					
						
							|  |  |  |         op->main.value = convP; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (needPermute) { | 
					
						
							|  |  |  |             // Add Permute After
 | 
					
						
							|  |  |  |             auto permuteBefore       = new MNN::OpT; | 
					
						
							|  |  |  |             permuteBefore->type      = MNN::OpType_Permute; | 
					
						
							|  |  |  |             permuteBefore->main.type = MNN::OpParameter_Permute; | 
					
						
							|  |  |  |             auto permuteT            = new MNN::PermuteT; | 
					
						
							|  |  |  |             permuteBefore->name      = "___permute2__" + reshapeT->name; | 
					
						
							|  |  |  |             permuteT->dims.resize(4); | 
					
						
							|  |  |  |             permuteT->dims[0]         = 0; | 
					
						
							|  |  |  |             permuteT->dims[1]         = 3; | 
					
						
							|  |  |  |             permuteT->dims[2]         = 2; | 
					
						
							|  |  |  |             permuteT->dims[3]         = 2; | 
					
						
							|  |  |  |             permuteT->dims[axis]      = 1; | 
					
						
							|  |  |  |             permuteBefore->main.value = permuteT; | 
					
						
							|  |  |  |             mNet->tensorName.push_back(permuteBefore->name); | 
					
						
							|  |  |  |             tempId = mNet->tensorName.size() - 1; | 
					
						
							|  |  |  |             permuteBefore->inputIndexes.push_back(tempId); | 
					
						
							|  |  |  |             permuteBefore->outputIndexes.push_back(op->outputIndexes[0]); | 
					
						
							|  |  |  |             op->outputIndexes[0] = tempId; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             newOpPost.push_back(permuteBefore); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (int i = 0; i < newOpPrevious.size(); ++i) { | 
					
						
							|  |  |  |             iter = mNet->oplists.insert(iter, std::unique_ptr<MNN::OpT>(newOpPrevious[newOpPrevious.size() - i - 1])); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (;; iter++) { | 
					
						
							|  |  |  |             auto& op = *iter; | 
					
						
							|  |  |  |             if (op->name == opName) { | 
					
						
							|  |  |  |                 break; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (int i = 0; i < newOpPost.size(); ++i) { | 
					
						
							|  |  |  |             iter = mNet->oplists.insert(iter + 1, std::unique_ptr<MNN::OpT>(newOpPost[i])); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (auto op : readyToDelete) { | 
					
						
							|  |  |  |         _removeOpInNet(op); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::turnGroupConvolution() { | 
					
						
							|  |  |  |     // Pick DepthWise one
 | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { | 
					
						
							|  |  |  |         auto& op           = *iter; | 
					
						
							|  |  |  |         const auto op_type = op->type; | 
					
						
							|  |  |  |         auto conv2D        = op->main.AsConvolution2D(); | 
					
						
							|  |  |  |         auto& common       = conv2D->common; | 
					
						
							|  |  |  |         if (op_type == MNN::OpType_Convolution || op_type == MNN::OpType_Deconvolution) { | 
					
						
							|  |  |  |             bool turnConv2DW = false; | 
					
						
							|  |  |  |             // check whether idst quantization model
 | 
					
						
							|  |  |  |             if (nullptr != conv2D->quanParameter.get()) { | 
					
						
							|  |  |  |                 auto& quanParam          = conv2D->quanParameter; | 
					
						
							|  |  |  |                 auto quanWeightBuffer    = quanParam->buffer.data(); | 
					
						
							|  |  |  |                 const int weightShapeDim = static_cast<int>(quanWeightBuffer[0]); | 
					
						
							|  |  |  |                 if (weightShapeDim == 4) { | 
					
						
							|  |  |  |                     const auto weightShapePtr = reinterpret_cast<unsigned short*>(quanWeightBuffer + 1); | 
					
						
							|  |  |  |                     int ci                    = weightShapePtr[1]; | 
					
						
							|  |  |  |                     if (ci == 1 && common->group != 1 && mNet->sourceType == MNN::NetSource_CAFFE) { | 
					
						
							|  |  |  |                         ci = weightShapePtr[0]; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     turnConv2DW = common->outputCount == common->group && ci == common->outputCount; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 const int srcCount = | 
					
						
							|  |  |  |                     conv2D->weight.size() * common->group / common->outputCount / common->kernelX / common->kernelY; | 
					
						
							|  |  |  |                 turnConv2DW = common->outputCount == common->group && srcCount == common->outputCount; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (turnConv2DW) { | 
					
						
							|  |  |  |                 switch (op_type) { | 
					
						
							|  |  |  |                     case MNN::OpType_Convolution: | 
					
						
							|  |  |  |                         op->type = MNN::OpType_ConvolutionDepthwise; | 
					
						
							|  |  |  |                         break; | 
					
						
							|  |  |  |                     case MNN::OpType_Deconvolution: | 
					
						
							|  |  |  |                         op->type = MNN::OpType_DeconvolutionDepthwise; | 
					
						
							|  |  |  |                         break; | 
					
						
							|  |  |  |                     default: | 
					
						
							|  |  |  |                         break; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Delete Convolution With Grouop
 | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto& op = *iter; | 
					
						
							|  |  |  |         if (op->type != MNN::OpType_Convolution && op->type != MNN::OpType_Deconvolution) { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto conv2D  = op->main.AsConvolution2D(); | 
					
						
							|  |  |  |         auto& common = conv2D->common; | 
					
						
							|  |  |  |         if (common->group == 1) { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int srcCount = conv2D->weight.size() * common->group / common->outputCount / common->kernelX / common->kernelY; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         DCHECK(srcCount % common->group == 0 && common->outputCount % common->group == 0) | 
					
						
							|  |  |  |             << "split group convolution ERROR! ==> " << op->name; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         std::vector<int> newConvolutionInputIndex; | 
					
						
							|  |  |  |         std::vector<int> newConvolutionOutputIndex; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (int i = 0; i < common->group; ++i) { | 
					
						
							|  |  |  |             std::ostringstream newTensorNameOs; | 
					
						
							|  |  |  |             newTensorNameOs << op->name << "___input___" << i; | 
					
						
							|  |  |  |             newConvolutionInputIndex.push_back(mNet->tensorName.size()); | 
					
						
							|  |  |  |             mNet->tensorName.push_back(newTensorNameOs.str()); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = 0; i < common->group; ++i) { | 
					
						
							|  |  |  |             std::ostringstream newTensorNameOs; | 
					
						
							|  |  |  |             newTensorNameOs << op->name << "___output___" << i; | 
					
						
							|  |  |  |             newConvolutionOutputIndex.push_back(mNet->tensorName.size()); | 
					
						
							|  |  |  |             mNet->tensorName.push_back(newTensorNameOs.str()); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         std::vector<MNN::OpT*> newOp; | 
					
						
							|  |  |  |         // Create slice op
 | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             MNN::OpT* sliceOp      = new MNN::OpT; | 
					
						
							|  |  |  |             sliceOp->type          = MNN::OpType_Slice; | 
					
						
							|  |  |  |             sliceOp->name          = op->name + "_____slice"; | 
					
						
							|  |  |  |             sliceOp->inputIndexes  = op->inputIndexes; | 
					
						
							|  |  |  |             sliceOp->outputIndexes = newConvolutionInputIndex; | 
					
						
							|  |  |  |             auto sliceT            = new MNN::SliceT; | 
					
						
							|  |  |  |             sliceOp->main.type     = MNN::OpParameter_Slice; | 
					
						
							|  |  |  |             sliceOp->main.value    = sliceT; | 
					
						
							|  |  |  |             sliceT->axis           = 1; | 
					
						
							|  |  |  |             for (int i = 0; i < common->group - 1; ++i) { | 
					
						
							|  |  |  |                 sliceT->slicePoints.push_back(srcCount / (common->group) * (i + 1)); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             newOp.push_back(sliceOp); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int partWeightSize = conv2D->weight.size() / common->group; | 
					
						
							|  |  |  |         int partBiasSize   = conv2D->bias.size() / common->group; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Create Sub Convolution
 | 
					
						
							|  |  |  |         for (int i = 0; i < common->group; ++i) { | 
					
						
							|  |  |  |             std::ostringstream opNameOs; | 
					
						
							|  |  |  |             auto newConvOp = new MNN::OpT; | 
					
						
							|  |  |  |             opNameOs << op->name << "__group__" << i; | 
					
						
							|  |  |  |             newConvOp->type      = op->type; | 
					
						
							|  |  |  |             newConvOp->name      = opNameOs.str(); | 
					
						
							|  |  |  |             newConvOp->main.type = MNN::OpParameter_Convolution2D; | 
					
						
							|  |  |  |             newConvOp->inputIndexes.push_back(newConvolutionInputIndex[i]); | 
					
						
							|  |  |  |             newConvOp->outputIndexes.push_back(newConvolutionOutputIndex[i]); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto newConvolutionT    = new MNN::Convolution2DT; | 
					
						
							|  |  |  |             newConvOp->main.value   = newConvolutionT; | 
					
						
							|  |  |  |             newConvolutionT->common = std::unique_ptr<MNN::Convolution2DCommonT>(new MNN::Convolution2DCommonT); | 
					
						
							|  |  |  |             newConvolutionT->common->kernelX     = common->kernelX; | 
					
						
							|  |  |  |             newConvolutionT->common->kernelY     = common->kernelY; | 
					
						
							|  |  |  |             newConvolutionT->common->dilateY     = common->dilateY; | 
					
						
							|  |  |  |             newConvolutionT->common->dilateX     = common->dilateX; | 
					
						
							|  |  |  |             newConvolutionT->common->strideX     = common->strideX; | 
					
						
							|  |  |  |             newConvolutionT->common->strideY     = common->strideY; | 
					
						
							|  |  |  |             newConvolutionT->common->group       = 1; | 
					
						
							|  |  |  |             newConvolutionT->common->padMode     = common->padMode; | 
					
						
							|  |  |  |             newConvolutionT->common->outputCount = common->outputCount / common->group; | 
					
						
							|  |  |  |             newConvolutionT->common->padX        = common->padX; | 
					
						
							|  |  |  |             newConvolutionT->common->padY        = common->padY; | 
					
						
							|  |  |  |             newConvolutionT->common->relu        = common->relu; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int startWeight = partWeightSize * i; | 
					
						
							|  |  |  |             int startBias   = partBiasSize * i; | 
					
						
							|  |  |  |             for (int v = 0; v < partWeightSize; ++v) { | 
					
						
							|  |  |  |                 newConvolutionT->weight.push_back(conv2D->weight[startWeight + v]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for (int v = 0; v < partBiasSize; ++v) { | 
					
						
							|  |  |  |                 newConvolutionT->bias.push_back(conv2D->bias[startBias + v]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             newOp.push_back(newConvOp); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Set this op be Concat Op
 | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             op->type         = MNN::OpType_Concat; | 
					
						
							|  |  |  |             op->inputIndexes = newConvolutionOutputIndex; | 
					
						
							|  |  |  |             op->main.Reset(); | 
					
						
							|  |  |  |             op->main.type = MNN::OpParameter_Axis; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto axisT     = new MNN::AxisT; | 
					
						
							|  |  |  |             axisT->axis    = 1; | 
					
						
							|  |  |  |             op->main.value = axisT; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (int v = 0; v < newOp.size(); ++v) { | 
					
						
							|  |  |  |             int index = newOp.size() - v - 1; | 
					
						
							|  |  |  |             iter      = mNet->oplists.insert(iter, std::unique_ptr<MNN::OpT>(newOp[index])); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::changeBatchnNorm2Scale() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto& op                 = *iter; | 
					
						
							|  |  |  |         const MNN::OpType opType = op->type; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (MNN::OpType_BatchNorm != opType) { | 
					
						
							|  |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-07-25 13:36:35 +08:00
										 |  |  |         const int inputSize = op->inputIndexes.size(); | 
					
						
							|  |  |  |         DCHECK(inputSize == 1 || inputSize == 3) << "MNN BatchnNorm input size error!"; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         // instance norm have three input tensors(input_tensor, mean, variance)
 | 
					
						
							| 
									
										
										
										
											2019-07-25 13:36:35 +08:00
										 |  |  |         if (inputSize == 3) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             iter++; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         // DLOG(INFO) << "change BatchNorm to Scale: " << op->name;
 | 
					
						
							|  |  |  |         auto batchnormParam  = op->main.AsBatchNorm(); | 
					
						
							|  |  |  |         auto scaleParam      = new MNN::ScaleT; | 
					
						
							|  |  |  |         scaleParam->channels = batchnormParam->channels; | 
					
						
							|  |  |  |         scaleParam->scaleData.resize(batchnormParam->channels); | 
					
						
							|  |  |  |         scaleParam->biasData.resize(batchnormParam->channels); | 
					
						
							|  |  |  |         const float* slopePtr    = batchnormParam->slopeData.data(); | 
					
						
							|  |  |  |         const float* meanDataPtr = batchnormParam->meanData.data(); | 
					
						
							|  |  |  |         const float* varDataPtr  = batchnormParam->varData.data(); | 
					
						
							|  |  |  |         const float* biasDataPtr = batchnormParam->biasData.data(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (int i = 0; i < batchnormParam->channels; i++) { | 
					
						
							|  |  |  |             float sqrt_var           = sqrt(varDataPtr[i]); | 
					
						
							|  |  |  |             scaleParam->biasData[i]  = biasDataPtr[i] - slopePtr[i] * meanDataPtr[i] / sqrt_var; | 
					
						
							|  |  |  |             scaleParam->scaleData[i] = slopePtr[i] / sqrt_var; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         op->type       = MNN::OpType_Scale; | 
					
						
							|  |  |  |         op->main.type  = MNN::OpParameter_Scale; | 
					
						
							|  |  |  |         op->main.value = scaleParam; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-07-02 18:01:08 +08:00
										 |  |  | void PostTreatUtils::pluginConvert() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) { | 
					
						
							|  |  |  |         auto op = iter->get(); | 
					
						
							|  |  |  |         if (op->type == OpType_PLUGIN) { | 
					
						
							|  |  |  |             auto plugin = op->main.AsPlugin(); | 
					
						
							|  |  |  |             if (plugin->type == "ShuffleChannel") { | 
					
						
							|  |  |  |                 int currentTensorCount = (int)mNet->tensorName.size(); | 
					
						
							|  |  |  |                 std::unique_ptr<OpT> convertTo(new OpT); | 
					
						
							|  |  |  |                 convertTo->type                               = OpType_ConvertTensor; | 
					
						
							|  |  |  |                 convertTo->main.type                          = OpParameter_TensorConvertInfo; | 
					
						
							|  |  |  |                 convertTo->main.value                         = new TensorConvertInfoT; | 
					
						
							|  |  |  |                 convertTo->main.AsTensorConvertInfo()->source = MNN_DATA_FORMAT_NC4HW4; | 
					
						
							|  |  |  |                 convertTo->main.AsTensorConvertInfo()->dest   = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |                 convertTo->inputIndexes                       = op->inputIndexes; | 
					
						
							|  |  |  |                 convertTo->name                               = op->name + "_ConvertToNHWC"; | 
					
						
							|  |  |  |                 convertTo->outputIndexes                      = {currentTensorCount + 0}; | 
					
						
							|  |  |  |                 mNet->tensorName.emplace_back(convertTo->name); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 auto group = plugin->buffer[0]->int32s[0]; | 
					
						
							|  |  |  |                 std::unique_ptr<OpT> reshape(new OpT); | 
					
						
							|  |  |  |                 reshape->type                      = OpType_Reshape; | 
					
						
							|  |  |  |                 reshape->name                      = op->name + "_Reshape"; | 
					
						
							|  |  |  |                 reshape->main.value                = new ReshapeT; | 
					
						
							|  |  |  |                 reshape->main.type                 = OpParameter_Reshape; | 
					
						
							|  |  |  |                 reshape->main.AsReshape()->dimType = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |                 reshape->main.AsReshape()->dims    = {0, 0, 0, group, -1}; | 
					
						
							|  |  |  |                 reshape->inputIndexes              = {currentTensorCount + 0}; | 
					
						
							|  |  |  |                 reshape->outputIndexes             = {currentTensorCount + 1}; | 
					
						
							|  |  |  |                 mNet->tensorName.emplace_back(reshape->name); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 std::unique_ptr<OpT> constOp(new OpT); | 
					
						
							|  |  |  |                 constOp->type          = OpType_Const; | 
					
						
							|  |  |  |                 auto blob              = new BlobT; | 
					
						
							|  |  |  |                 blob->int32s           = {0, 1, 2, 4, 3}; | 
					
						
							|  |  |  |                 blob->dataFormat       = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |                 blob->dataType         = DataType_DT_INT32; | 
					
						
							|  |  |  |                 blob->dims             = {5}; | 
					
						
							|  |  |  |                 constOp->main.value    = blob; | 
					
						
							|  |  |  |                 constOp->main.type     = OpParameter_Blob; | 
					
						
							|  |  |  |                 constOp->name          = op->name + "_Const"; | 
					
						
							|  |  |  |                 constOp->outputIndexes = {currentTensorCount + 2}; | 
					
						
							|  |  |  |                 mNet->tensorName.emplace_back(constOp->name); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 std::unique_ptr<OpT> permute(new OpT); | 
					
						
							|  |  |  |                 permute->type                      = OpType_Transpose; | 
					
						
							|  |  |  |                 permute->name                      = op->name + "_Transpose"; | 
					
						
							|  |  |  |                 permute->main.value                = new TransposeT; | 
					
						
							|  |  |  |                 permute->main.type                 = OpParameter_Transpose; | 
					
						
							|  |  |  |                 permute->main.AsTranspose()->Tperm = DataType_DT_INT32; | 
					
						
							|  |  |  |                 permute->inputIndexes              = {currentTensorCount + 1, currentTensorCount + 2}; | 
					
						
							|  |  |  |                 permute->outputIndexes             = {currentTensorCount + 3}; | 
					
						
							|  |  |  |                 mNet->tensorName.emplace_back(permute->name); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 std::unique_ptr<OpT> reshapeR(new OpT); | 
					
						
							|  |  |  |                 reshapeR->type                      = OpType_Reshape; | 
					
						
							|  |  |  |                 reshapeR->name                      = op->name + "_ReshapeR"; | 
					
						
							|  |  |  |                 reshapeR->main.value                = new ReshapeT; | 
					
						
							|  |  |  |                 reshapeR->main.type                 = OpParameter_Reshape; | 
					
						
							|  |  |  |                 reshapeR->main.AsReshape()->dimType = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |                 reshapeR->main.AsReshape()->dims    = {0, 0, 0, -1}; | 
					
						
							|  |  |  |                 reshapeR->inputIndexes              = {currentTensorCount + 3}; | 
					
						
							|  |  |  |                 reshapeR->outputIndexes             = {currentTensorCount + 4}; | 
					
						
							|  |  |  |                 mNet->tensorName.emplace_back(reshapeR->name); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 std::unique_ptr<OpT> convertFrom(new OpT); | 
					
						
							|  |  |  |                 convertFrom->type                               = OpType_ConvertTensor; | 
					
						
							|  |  |  |                 convertFrom->main.type                          = OpParameter_TensorConvertInfo; | 
					
						
							|  |  |  |                 convertFrom->main.value                         = new TensorConvertInfoT; | 
					
						
							|  |  |  |                 convertFrom->main.AsTensorConvertInfo()->source = MNN_DATA_FORMAT_NHWC; | 
					
						
							|  |  |  |                 convertFrom->main.AsTensorConvertInfo()->dest   = MNN_DATA_FORMAT_NC4HW4; | 
					
						
							|  |  |  |                 convertFrom->inputIndexes                       = {currentTensorCount + 4}; | 
					
						
							|  |  |  |                 convertFrom->outputIndexes                      = op->outputIndexes; | 
					
						
							|  |  |  |                 convertFrom->name                               = op->name + "_ConvertToNC4HW4"; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 iter = mNet->oplists.erase(iter); | 
					
						
							|  |  |  |                 iter = mNet->oplists.insert(iter, std::move(convertFrom)); | 
					
						
							|  |  |  |                 iter = mNet->oplists.insert(iter, std::move(reshapeR)); | 
					
						
							|  |  |  |                 iter = mNet->oplists.insert(iter, std::move(permute)); | 
					
						
							|  |  |  |                 iter = mNet->oplists.insert(iter, std::move(constOp)); | 
					
						
							|  |  |  |                 iter = mNet->oplists.insert(iter, std::move(reshape)); | 
					
						
							|  |  |  |                 iter = mNet->oplists.insert(iter, std::move(convertTo)); | 
					
						
							|  |  |  |                 iter++; | 
					
						
							|  |  |  |                 iter++; | 
					
						
							|  |  |  |                 iter++; | 
					
						
							|  |  |  |                 iter++; | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         iter++; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | void PostTreatUtils::convertBinaryToElementwise() { | 
					
						
							|  |  |  |     for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { | 
					
						
							|  |  |  |         auto op = iter->get(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (op->type != MNN::OpType_BinaryOp) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto param = op->main.AsBinaryOp(); | 
					
						
							|  |  |  |         if (param->opType != BinaryOpOperation_MUL && param->opType != BinaryOpOperation_ADD && | 
					
						
							|  |  |  |             param->opType != BinaryOpOperation_SUB) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         const int inputNum = op->inputIndexes.size(); | 
					
						
							|  |  |  |         DCHECK(inputNum == 2) << "BinaryOp should have two inputs"; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         const int inputIndex0 = op->inputIndexes[0]; | 
					
						
							|  |  |  |         auto inputOp0         = _findOpByOutputIndex(inputIndex0); | 
					
						
							|  |  |  |         const int inputIndex1 = op->inputIndexes[1]; | 
					
						
							|  |  |  |         auto inputOp1         = _findOpByOutputIndex(inputIndex1); | 
					
						
							|  |  |  |         bool readyToChange    = (inputOp0->type == MNN::OpType_Convolution || inputOp0->type == MNN::OpType_Eltwise) && | 
					
						
							|  |  |  |                              (inputOp1->type == MNN::OpType_Convolution || inputOp1->type == MNN::OpType_Eltwise); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if (readyToChange) { | 
					
						
							|  |  |  |             // convert binary op to elementwise op
 | 
					
						
							|  |  |  |             auto elementParam = new MNN::EltwiseT; | 
					
						
							|  |  |  |             switch (param->opType) { | 
					
						
							|  |  |  |                 case BinaryOpOperation_MUL: | 
					
						
							|  |  |  |                     elementParam->type = EltwiseType_PROD; | 
					
						
							|  |  |  |                     break; | 
					
						
							|  |  |  |                 case BinaryOpOperation_ADD: | 
					
						
							|  |  |  |                     elementParam->type = EltwiseType_SUM; | 
					
						
							|  |  |  |                     break; | 
					
						
							|  |  |  |                 case BinaryOpOperation_SUB: | 
					
						
							|  |  |  |                     elementParam->type = EltwiseType_SUB; | 
					
						
							|  |  |  |                     break; | 
					
						
							|  |  |  |                 default: | 
					
						
							|  |  |  |                     break; | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-09-01 19:25:26 +08:00
										 |  |  |             op->type = MNN::OpType_Eltwise; | 
					
						
							|  |  |  |             op->main.Reset(); | 
					
						
							|  |  |  |             op->main.type  = OpParameter_Eltwise; | 
					
						
							| 
									
										
										
										
											2019-08-22 20:13:46 +08:00
										 |  |  |             op->main.value = elementParam; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } |