mirror of https://github.com/alibaba/MNN.git
quant weight valid range issue
This commit is contained in:
parent
5dfe97e4c8
commit
612199d0ee
|
|
@ -306,6 +306,7 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
|
|||
std::string kernelName = "conv_2d_c4h1w4";
|
||||
mInputChannel = inputs[0]->channel();
|
||||
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
|
||||
if (inputs.size() != 1) {
|
||||
// Multi - Input
|
||||
mConv1x1Opt = false;
|
||||
|
|
@ -315,7 +316,6 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
|
|||
mRasterExe.reset(new RasterBufExecution({virtualFilter.get()}, mOpenCLBackend));
|
||||
} else {
|
||||
int weightSize = 0;
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
|
||||
ConvolutionCommon::getConvParameters(&quanCommon, conv2dParams, &mFilterDataPtr, &weightSize);
|
||||
//select opt conv method
|
||||
mConv1x1Opt = (mKernelHeight == mKernelWidth && mKernelHeight == 1 && mPaddings[0] == 0 &&
|
||||
|
|
@ -470,10 +470,11 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
|
|||
mLocalWorkSize = {retTune.first[0], retTune.first[1]};
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
|
||||
int min_index = min_cost.second;
|
||||
if(min_index >= c8_index_start) {//if best kernel is "conv_2d_1x1_c8h1w4", set weight packCout to 8
|
||||
int weightSize = 0;
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
|
||||
ConvolutionCommon::getConvParameters(&quanCommon, mConv2dParams, &mFilterDataPtr, &weightSize);
|
||||
setConv1x1WeightBuffer(8, 4, mFilterDataPtr);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue