//
//  NN.cpp
//  MNN
//
//  Created by MNN on 2019/11/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "NN.hpp"
#include "Distributions.hpp"
#include "module/PipelineModule.hpp"
#include "module/WhileModule.hpp"
#include "module/IfModule.hpp"
#include "module/NMSModule.hpp"
#include "Initializer.hpp"
#include "MNN_generated.h"
#include "RandomGenerator.hpp"
#include "core/Macro.h"
#include "math/WingoradGenerater.hpp"
#include "common/WinogradInt8Attr.hpp"
#include <string>

using namespace MNN::Express;
namespace MNN {
namespace Express {
static VARP _activate(VARP x, NN::ActivationFunctionType type) {
    switch (type) {
        case NN::None:
            return x;
        case NN::Relu:
            return _Relu(x);
        case NN::Relu6:
            return _Relu6(x);
        default:
            break;
    }
    return nullptr;
}
class DropoutModule : public Module {
public:
    DropoutModule(const float dropRatio) {
        mDropRatio = dropRatio;
        setType("Dropout");
    }
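    // Inverted dropout: during training each element is zeroed with probability
    // mDropRatio and the survivors are scaled by 1 / (1 - mDropRatio), so the expected
    // activation matches inference, where this module is an identity.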
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
        Express::VARP x = inputs[0];
        if (getIsTraining()) {
            float scale  = 1. / (1. - mDropRatio);
            auto mask    = _Input(x->getInfo()->dim, x->getInfo()->order, x->getInfo()->type);
            auto maskPtr = mask->writeMap<float>();
            auto eltSize = x->getInfo()->size;
            Distributions::uniform(eltSize, 0, 1, maskPtr, RandomGenerator::generator());
            for (int i = 0; i < eltSize; i++) {
                maskPtr[i] = maskPtr[i] < mDropRatio ? 0.0f : scale;
            }
            x = x * mask;
        }
        return {x};
    }

private:
    DropoutModule() = default;
    Module* clone(CloneContext* ctx) const override {
        DropoutModule* module(new DropoutModule);
        module->mDropRatio = mDropRatio;
        return this->cloneBaseTo(ctx, module);
    }

    float mDropRatio;
};
class BatchNormModule : public Module {
public:
    BatchNormModule(EXPRP expr, const float m = 0.99) {
        MNN_ASSERT(expr->get() != nullptr);
        MNN_ASSERT(expr->get()->type() == OpType_BatchNorm);
        auto bnPa = expr->get()->main_as_BatchNorm();
        auto& inputs = expr->inputs();
        int dims = 4;
        if (!inputs.empty()) {
            auto info = inputs[0]->getInfo();
            if (nullptr != info) {
                dims = info->dim.size();
            }
        }
        mEps      = bnPa->epsilon();
        mMomentum = m;
        mChannels = bnPa->channels();
        std::vector<int> statShape;
        std::vector<int> reductionDims;
        int channels = mChannels;
        if (dims == 2) {
            statShape      = {1, channels};
            mReductionDims = {0};
        }
        if (dims == 3) {
            statShape      = {1, channels, 1};
            mReductionDims = {0, 2};
        }
        if (dims == 4) {
            statShape      = {1, channels, 1, 1};
            mReductionDims = {0, 2, 3};
        }
        MNN_ASSERT(bnPa->biasData()->size() == mChannels);
        mBias = _TrainableParam(bnPa->biasData()->data(), statShape, NCHW);
        MNN_ASSERT(bnPa->slopeData()->size() == mChannels);
        mScale = _TrainableParam(bnPa->slopeData()->data(), statShape, NCHW);
        MNN_ASSERT(bnPa->meanData()->size() == mChannels);
        mRunningMean = _Const(bnPa->meanData()->data(), statShape, NCHW);
        MNN_ASSERT(bnPa->varData()->size() == mChannels);
        mRunningVariance = _Const(bnPa->varData()->data(), statShape, NCHW);
        addParameter(mScale);
        addParameter(mBias);
        mRunningVariancePos = addParameter(mRunningVariance);
        mRunningMeanPos     = addParameter(mRunningMean);
        setType("BatchNorm");
    }
    BatchNormModule(const int channels, const int dims = 4, const float m = 0.99, const float e = 1e-5) {
        mMomentum = m;
        mEps      = e;
        mChannels = channels;
        std::vector<int> statShape;
        std::vector<int> reductionDims;
        if (dims == 2) {
            statShape      = {1, channels};
            mReductionDims = {0};
        }
        if (dims == 3) {
            statShape      = {1, channels, 1};
            mReductionDims = {0, 2};
        }
        if (dims == 4) {
            statShape      = {1, channels, 1, 1};
            mReductionDims = {0, 2, 3};
        }
        mScale           = _TrainableParam(1.0f, statShape, NCHW);
        mBias            = _TrainableParam(0.0f, statShape, NCHW);
        mRunningMean     = _Const(0.0f, statShape, NCHW);
        mRunningVariance = _Const(0.0f, statShape, NCHW);
        addParameter(mScale);
        addParameter(mBias);
        mRunningVariancePos = addParameter(mRunningVariance);
        mRunningMeanPos     = addParameter(mRunningMean);
        setType("BatchNorm");
    }
    VARP runningMean() {
        return mRunningMean;
    }

    VARP runningVariance() {
        return mRunningVariance;
    }

    VARP scale() {
        return mScale;
    }

    VARP bias() {
        return mBias;
    }

    float eps() {
        return mEps;
    }
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
        Express::VARP x = inputs[0];
        auto dimFormat  = x->getInfo()->order;
        VARP outputData = nullptr;
        if (getIsTraining()) {
            if (dimFormat == NC4HW4 || dimFormat == NHWC) {
                x = _Convert(x, NCHW);
            }
            MNN_ASSERT(x->getInfo()->dim[1] == mChannels);
            auto sampleMean = _ReduceMean(x, mReductionDims, true); // mean for each channel in the batch
            auto xSub       = x - sampleMean;
            auto sampleVar  = _ReduceMean(_Square(xSub), mReductionDims,
                                          true); // variance for each channel in the batch
            auto rSampleStd     = _Reciprocal(_Sqrt(sampleVar + _Const(mEps)));
            auto normalizedData = xSub * rSampleStd;
            outputData          = normalizedData * mScale + mBias;

            mRunningMean     = _Const(mMomentum) * mRunningMean + _Const(1 - mMomentum) * sampleMean;
            mRunningVariance = _Const(mMomentum) * mRunningVariance + _Const(1 - mMomentum) * sampleVar;
            outputData->setName(name());
            outputData = _Convert(outputData, dimFormat);
            setParameter(mRunningMean, mRunningMeanPos);
            setParameter(mRunningVariance, mRunningVariancePos);
            return {outputData};
        }
        auto rStd  = _Const(1.0f) / _Sqrt(mRunningVariance + _Const(mEps));
        auto alpha = rStd * mScale;
        auto beta  = mBias - mRunningMean * rStd * mScale;
        //outputData = (_Convert(x, NCHW) * alpha) + beta;
        alpha.fix(VARP::CONSTANT);
        beta.fix(VARP::CONSTANT);
        //FUNC_PRINT_ALL(alpha->readMap<float>()[0], f);
        x = _Convert(x, NC4HW4);
        std::vector<float> scale(alpha->getInfo()->size);
        std::vector<float> bias(beta->getInfo()->size);
        ::memcpy(scale.data(), alpha->readMap<float>(), scale.size() * sizeof(float));
        ::memcpy(bias.data(), beta->readMap<float>(), bias.size() * sizeof(float));
        outputData = _Scale(x, mChannels, std::move(scale), std::move(bias));
        outputData->setName(name());
        outputData = _Convert(outputData, dimFormat);
        return {outputData};
    }
private:
    BatchNormModule() = default;
    Module* clone(CloneContext* ctx) const override {
        BatchNormModule* module(new BatchNormModule);
        module->mMomentum           = mMomentum;
        module->mEps                = mEps;
        module->mScale              = ctx->getOrClone(mScale);
        module->mBias               = ctx->getOrClone(mBias);
        module->mRunningMean        = ctx->getOrClone(mRunningMean);
        module->mRunningVariance    = ctx->getOrClone(mRunningVariance);
        module->mRunningMeanPos     = mRunningMeanPos;
        module->mRunningVariancePos = mRunningVariancePos;
        module->mChannels           = mChannels;
        module->mReductionDims      = mReductionDims;
        return this->cloneBaseTo(ctx, module);
    }

    float mMomentum       = 0.99;
    float mEps            = 1e-5;
    VARP mScale           = nullptr;
    VARP mBias            = nullptr;
    VARP mRunningMean     = nullptr;
    VARP mRunningVariance = nullptr;
    int mRunningMeanPos     = -1;
    int mRunningVariancePos = -1;
    int mChannels;
    std::vector<int> mReductionDims;
};
void NN::ConvOption::reset(int size) {
    stride     = std::vector<int>(size, 1);
    channel    = std::vector<int>(size, 0);
    kernelSize = std::vector<int>(size, 1);
    dilate     = std::vector<int>(size, 1);
    padMode    = VALID;
    pads       = std::vector<int>(size, 0);
    depthwise  = false;
    fusedActivationFunction = None;
}
class ConvModule : public Module {
public:
    ConvModule(const NN::ConvParameters& parameters) {
        mParameter = parameters;
        if (nullptr != mParameter.bias) {
            addParameter(mParameter.bias);
        }
        if (nullptr != mParameter.weight) {
            addParameter(mParameter.weight);
        }
        setName(parameters.name);
        setType("Conv");
    }
    NN::ConvParameters& convParameters() {
        return mParameter;
    }
    virtual std::vector<VARP> onForward(const std::vector<VARP>& inputs) override {
        auto input   = inputs[0];
        auto& option = mParameter.option;
        if (getIsTraining()) {
            auto tempOutput = _Conv(mParameter.weight, mParameter.bias, _Convert(input, NC4HW4), option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads);
            tempOutput->setName(name());
            tempOutput = _activate(tempOutput, option.fusedActivationFunction);
            return {tempOutput};
        }
        bool relu  = option.fusedActivationFunction == NN::Relu;
        bool relu6 = option.fusedActivationFunction == NN::Relu6;
        std::vector<float> weight;
        std::vector<float> bias;
        {
            auto weightInfo = mParameter.weight->getInfo();
            weight.resize(weightInfo->size);
            ::memcpy(weight.data(), mParameter.weight->readMap<float>(), weight.size() * sizeof(float));
        }
        {
            bias.resize(mParameter.option.channel[1]);
            if (nullptr != mParameter.bias) {
                ::memcpy(bias.data(), mParameter.bias->readMap<float>(), bias.size() * sizeof(float));
            } else {
                ::memset(bias.data(), 0, bias.size() * sizeof(float));
            }
        }
        auto tempOutput = _Conv(std::move(weight), std::move(bias), _Convert(input, NC4HW4), option.channel, option.kernelSize, option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads, relu, relu6);
        tempOutput->setName(name());
        return {tempOutput};
    }

private:
    ConvModule() = default;
    Module* clone(CloneContext* ctx) const override {
        ConvModule* module(new ConvModule);
        module->mParameter        = mParameter;
        module->mParameter.weight = ctx->getOrClone(mParameter.weight);
        module->mParameter.bias   = ctx->getOrClone(mParameter.bias);
        return this->cloneBaseTo(ctx, module);
    }

    NN::ConvParameters mParameter;
};
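// Create the trainable weight/bias for a convolution. Regular convolutions use an
// [outputChannel, inputChannel, kernelH, kernelW] weight; depthwise convolutions use
// [channel, 1, kernelH, kernelW] and set group = channel.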
static std::tuple<VARP, VARP, int> _initParameters(const NN::ConvOption& option, bool hasBias,
                                                   std::shared_ptr<Initializer> weightInit,
                                                   std::shared_ptr<Initializer> biasInit) {
    std::tuple<VARP, VARP, int> defaultRes;
    if (nullptr == weightInit) {
        weightInit.reset(Initializer::xavier());
    }
    if (nullptr == biasInit) {
        biasInit.reset(Initializer::constValue(0.0f));
    }
    VARP weight;
    int group = 1;
    if (option.depthwise) {
        if (option.channel[1] != option.channel[0]) {
            MNN_ERROR("Depthwise convolution requires the same input and output channel count\n");
            return defaultRes;
        }
        weight = weightInit->createConstVar({option.channel[0], 1, option.kernelSize[1], option.kernelSize[0]}, NCHW);
        weight.fix(VARP::TRAINABLE);
        group = option.channel[0];
    } else {
        weight = weightInit->createConstVar(
            {option.channel[1], option.channel[0], option.kernelSize[1], option.kernelSize[0]}, NCHW);
        weight.fix(VARP::TRAINABLE);
    }
    VARP bias;
    if (hasBias) {
        bias = biasInit->createConstVar({option.channel[1]}, NCHW);
        bias.fix(VARP::TRAINABLE);
    }
    return std::make_tuple(weight, bias, group);
}
Module* NN::ConvTranspose(const ConvOption& option, bool hasBias,
                          std::shared_ptr<Initializer> weightInit,
                          std::shared_ptr<Initializer> biasInit) {
    VARP input = _Input({1, option.channel[0], -1, -1}, NC4HW4);
    auto tuple  = _initParameters(option, hasBias, weightInit, biasInit);
    auto weight = std::get<0>(tuple);
    if (nullptr == weight) {
        return nullptr;
    }
    if (!option.depthwise) {
        weight = _Transpose(weight, {1, 0, 2, 3});
        weight.fix(VARP::TRAINABLE);
    }
    auto bias  = std::get<1>(tuple);
    auto group = std::get<2>(tuple);
    if (nullptr != bias) {
        auto tempOutput = _Deconv(weight, bias, input, option.padMode, option.stride, option.dilate, group);
        tempOutput = _activate(tempOutput, option.fusedActivationFunction);
        return NN::extract({input}, {tempOutput}, true);
    }
    auto tempOutput = _Deconv(weight, nullptr, input, option.padMode, option.stride, option.dilate, group);
    tempOutput = _activate(tempOutput, option.fusedActivationFunction);
    return NN::extract({input}, {tempOutput}, true);
}
Module* NN::Conv(const ConvOption& option, bool hasBias, std::shared_ptr<Initializer> weightInit,
                 std::shared_ptr<Initializer> biasInit) {
    auto tuple = _initParameters(option, hasBias, weightInit, biasInit);
    ConvParameters parameters;
    parameters.weight = std::get<0>(tuple);
    if (nullptr == parameters.weight) {
        return nullptr;
    }
    parameters.bias   = std::get<1>(tuple);
    parameters.group  = std::get<2>(tuple);
    parameters.option = option;
    return new ConvModule(parameters);
}
Module* NN::Linear(int l, int t, bool hasBias, std::shared_ptr<Initializer> weightInit,
                   std::shared_ptr<Initializer> biasInit) {
    if (nullptr == weightInit) {
        weightInit.reset(Initializer::xavier());
    }
    if (nullptr == biasInit) {
        biasInit.reset(Initializer::constValue(0.0f));
    }
    auto weight = weightInit->createConstVar({t, l}, NCHW);
    weight.fix(VARP::TRAINABLE);
    auto input  = _Input({l}, NCHW);
    auto output = _MatMul(input, weight, false, true);
    if (!hasBias) {
        return NN::extract({input}, {output}, true);
    }
    auto bias = biasInit->createConstVar({1, t}, NCHW);
    bias.fix(VARP::TRAINABLE);
    output = _Add(output, bias);
    auto module = NN::extract({input}, {output}, true);
    module->setType("Linear");
    return module;
}
Module* NN::Dropout(const float dropRatio) {
    return new DropoutModule(dropRatio);
}

Module* NN::BatchNorm(const int channels, const int dims, const float m, const float e) {
    return new BatchNormModule(channels, dims, m, e);
}
NN::ConvParameters NN::Utils::ExtractConvolution(EXPRP source) {
    ConvParameters _default;
    if (source->get() == nullptr) {
        return _default;
    }
    if (source->get()->type() != OpType_Convolution && source->get()->type() != OpType_ConvolutionDepthwise) {
        return _default;
    }
    auto conv2D = source->get()->main_as_Convolution2D();
    NN::ConvOption option;
    option.kernelSize = {conv2D->common()->kernelX(), conv2D->common()->kernelY()};
    option.stride     = {conv2D->common()->strideX(), conv2D->common()->strideY()};
    if (nullptr != conv2D->common()->pads()) {
        option.pads.resize(conv2D->common()->pads()->size());
        for (int i = 0; i < option.pads.size(); ++i) {
            option.pads[i] = conv2D->common()->pads()->data()[i];
        }
    } else {
        option.pads = {conv2D->common()->padX(), conv2D->common()->padY()};
    }
    switch (conv2D->common()->padMode()) {
        case MNN::PadMode_SAME:
            option.padMode = SAME;
            break;
        case MNN::PadMode_VALID:
            option.padMode = VALID;
            break;
        case MNN::PadMode_CAFFE:
            option.padMode = CAFFE;
            break;
        default:
            break;
    }
    option.dilate    = {conv2D->common()->dilateX(), conv2D->common()->dilateY()};
    option.depthwise = source->get()->type() == OpType_ConvolutionDepthwise;
    auto inputCount  = conv2D->common()->inputCount();
    if (0 == inputCount) {
        auto inputInfo = source->inputs()[0]->getInfo();
        if (nullptr != inputInfo) {
            if (NHWC == inputInfo->order) {
                inputCount = source->inputs()[0]->getInfo()->dim[3];
            } else {
                inputCount = source->inputs()[0]->getInfo()->dim[1];
            }
        } else {
            if (nullptr == conv2D->weight()) {
                MNN_ERROR("Can't extract convolution\n");
                return _default;
            }
            auto weightCount = conv2D->weight()->size();
            if (option.depthwise) {
                inputCount = conv2D->common()->outputCount();
            } else {
                inputCount = weightCount / conv2D->common()->kernelX() / conv2D->common()->kernelY() / conv2D->common()->outputCount();
            }
        }
    }
    option.channel = {inputCount, conv2D->common()->outputCount()};
    int group = 1;
    if (option.depthwise) {
        group = conv2D->common()->outputCount();
    }
    VARP weight;
    auto inputs = source->inputs();
    if (inputs.size() > 1) {
        weight = inputs[1];
    }
    VARP bias;
    if (inputs.size() > 2) {
        bias = inputs[2];
    }
    if (inputs.size() < 2) {
        // Extract Weight And Bias from Conv2D
        if (conv2D->weight() == nullptr || conv2D->bias() == nullptr) {
            return _default;
        }
        bias   = _TrainableParam(conv2D->bias()->data(), {option.channel[1]}, NCHW);
        weight = _TrainableParam(conv2D->weight()->data(), {option.channel[1], option.channel[0] / group, option.kernelSize[1], option.kernelSize[0]}, NCHW);
    }
    _default.option = std::move(option);
    _default.weight = std::move(weight);
    _default.bias   = std::move(bias);
    _default.group  = group;
    if (conv2D->common()->relu()) {
        _default.option.fusedActivationFunction = NN::Relu;
    }
    if (conv2D->common()->relu6()) {
        _default.option.fusedActivationFunction = NN::Relu6;
    }
    _default.name = source->name();
    return _default;
}
Module* NN::Conv(const ConvParameters& parameter) {
    return new ConvModule(parameter);
}

Module* NN::Utils::ExtractNotRunableOp(Express::EXPRP expr, const std::map<std::string, SubGraph>& subgraphs) {
    if (nullptr == expr->get()) {
        return nullptr;
    }
    if (expr->get()->type() == OpType_BatchNorm) {
        return new BatchNormModule(expr);
    }
    if (expr->get()->type() == OpType_Dropout) {
        return new DropoutModule(0.3f);
    }
    return nullptr;
}
class ConvBNReluFusedModule : public Module {
public:
    ConvBNReluFusedModule(std::vector<std::shared_ptr<Module>> modules,
                          NN::FeatureScaleStatMethod featureScaleStatMethod,
                          NN::ScaleUpdateMethod scaleUpdateMethod, const int bits, bool winograd = false) {
        MNN_ASSERT(modules.size() >= 1);
        MNN_ASSERT(modules[0]->type() == "Conv");
        if (modules.size() == 3) {
            MNN_ASSERT(modules[1]->type() == "BatchNorm");
            MNN_ASSERT(modules[2]->type() == "ReLU" || modules[2]->type() == "ReLU6");
        }
        for (int i = 0; i < modules.size(); i++) {
            auto type = modules[i]->type();
            if (type == "Conv") {
                mConvParameter = std::static_pointer_cast<ConvModule>(modules[i])->convParameters();
                mOption        = mConvParameter.option;
                mGroup         = mConvParameter.group;
                mWeight        = mConvParameter.weight;
                mBias          = mConvParameter.bias;
                if (nullptr != mWeight) {
                    addParameter(mWeight);
                }
                if (nullptr != mBias) {
                    addParameter(mBias);
                }
                if (winograd && mOption.kernelSize[0] > 1 && mOption.kernelSize[1] > 1
                    && mOption.stride[0] == 1 && mOption.stride[1] == 1
                    && mOption.dilate[0] == 1 && mOption.dilate[1] == 1 && mGroup == 1) {
                    mWinogradAttr.reset(new WinogradInt8Attr);
                    mWinogradTransInputMaxPos = addParameter(mWinogradTransInputMax);
                    mWinogradTransInputMinPos = addParameter(mWinogradTransInputMin);
                    // Reserve a parameter slot; the actual transformed-weight scale is
                    // written into this position by _winogradConv via setParameter.
                    mWinogradTransWeightScalePos = addParameter(mWinogradTransInputMax);
                }
                setName(mConvParameter.name);
                modules[i] = nullptr;
            } else if (type == "BatchNorm") {
                mBatchNorm = modules[i];
                registerModel({mBatchNorm});
            } else if (type == "ReLU") {
                mActivation = NN::Relu;
                modules[i] = nullptr;
            } else if (type == "ReLU6") {
                mActivation = NN::Relu6;
                modules[i] = nullptr;
            } else {
                MNN_ASSERT(false);
            }
        }
        if (mOption.fusedActivationFunction == NN::Relu || mOption.fusedActivationFunction == NN::Relu6) {
            mActivation = mOption.fusedActivationFunction;
        }
        mFeatureScaleStatMethod = NN::PerTensor;
        mScaleUpdateMethod      = scaleUpdateMethod;
        mBits = bits;
        mLimit            = (float)(1 << (bits - 1)) - 1.0f;
        mLimitScale       = _Scalar<float>(1.0f / mLimit);
        mWeightClampValue = _Scalar<float>(mLimit);
        // mInputClampValue = _Scalar<float>(mLimit);
        // mOutputClampValue = _Scalar<float>(mLimit);
        // lower bits only apply to weights
        mInputClampValue  = _Scalar<float>((float)(1 << (8 - 1)) - 1.0f);
        mOutputClampValue = _Scalar<float>((float)(1 << (8 - 1)) - 1.0f);

        mInputMinPos  = addParameter(mInputMin);
        mInputMaxPos  = addParameter(mInputMax);
        mOutputMinPos = addParameter(mOutputMin);
        mOutputMaxPos = addParameter(mOutputMax);
        setType("ConvBNReluFused");
    }
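    // Asymmetric quantization parameters from an observed [min, max] range:
    // scale = (max - min) / (2 * clamp) and zeroPoint = round(-min / scale - clamp),
    // with the range first widened to include zero.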
    std::pair<VARP, VARP> computeScaleAndZeroPoint(VARP min, VARP max, VARP clampVar) {
        MNN_ASSERT((!(min == nullptr)));
        MNN_ASSERT((!(max == nullptr)));
        min = _Minimum(_Scalar<float>(0.0f), min);
        max = _Maximum(_Scalar<float>(0.0f), max);
        auto scale     = (max - min) / (_Scalar(2.0f) * clampVar);
        auto zeroPoint = _Round((_Scalar(0.0f) - min) / scale - clampVar);
        return std::make_pair(scale, zeroPoint);
    }
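    // Simulate int8 quantization in float: derive (or reuse) min/max, quantize with
    // round/clamp, dequantize, and route the gradient straight through to the original
    // tensor via _Cast + _ZeroGrad. Returns {fake-quantized x, nudged min, nudged max}.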
    std::vector<VARP> fakeQuantFeatureWithMinMax(VARP x, VARP useMin, VARP useMax, VARP clampVar, INTS axis = {}) {
        auto originFormat = x->getInfo()->order;
        auto tempX        = x;
        if (originFormat == NC4HW4) {
            tempX = _Convert(tempX, NCHW);
        }
        auto originX = tempX;
        VARP min, max;
        bool keepDims = false;
        if (axis.size() > 0) {
            // PerChannel for winograd
            keepDims = true;
        }
        min = _ReduceMin(tempX, axis, keepDims);
        max = _ReduceMax(tempX, axis, keepDims);
        VARP scale, zeroPoint;
        VARP nudgeMin, nudgeMax;
        if (!(useMin == nullptr)) {
            MNN_ASSERT(!(useMax == nullptr));
            auto scaleAndZeroPoint = computeScaleAndZeroPoint(useMin, useMax, clampVar);
            scale     = scaleAndZeroPoint.first;
            zeroPoint = scaleAndZeroPoint.second;
        } else {
            auto scaleAndZeroPoint = computeScaleAndZeroPoint(min, max, clampVar);
            scale     = scaleAndZeroPoint.first;
            zeroPoint = scaleAndZeroPoint.second;
        }
        float limit = clampVar->readMap<float>()[0];
        nudgeMin = (_Scalar<float>(-limit) - zeroPoint) * scale;
        nudgeMax = (_Scalar<float>(limit) - zeroPoint) * scale;
        nudgeMin = _Minimum(_Scalar<float>(0.0f), nudgeMin);
        nudgeMax = _Maximum(_Scalar<float>(0.0f), nudgeMax);
        auto quantX = clamp(_Round(tempX / scale + zeroPoint), clampVar);
        tempX = scale * (quantX - zeroPoint);
        // Break the grad by use cast
        tempX = _Cast<float>(tempX);
        // Move grad from tempX to originX
        tempX = _Convert(tempX + _ZeroGrad(originX), originFormat);
        return {tempX, nudgeMin, nudgeMax};
    }
    VARP clamp(VARP x, VARP clampVar) {
        return _Maximum(_Minimum(x, clampVar), _Negative(clampVar));
    }
    VARP updateParameter(VARP originValue, VARP newValue) const {
        if (nullptr == originValue) {
            return newValue;
        }
        auto ptr = originValue->readMap<float>();
        if (ptr[0] == -100.0f) {
            // -100 is the uninitialized sentinel used for the Winograd min/max parameters
            return newValue;
        }
        switch (mScaleUpdateMethod) {
            case NN::MovingAverage:
                return originValue * _Scalar<float>(mMomentum) + newValue * _Scalar<float>(1.0f - mMomentum);
            case NN::Maximum:
                return _Maximum(originValue, newValue);
            default:
                break;
        }
        MNN_ASSERT(false);
        return nullptr;
    }
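    // Heuristic unit selection for Winograd F(unit, kernel): compare the cost of the
    // source/weight/destination transforms plus the per-tile multiplications against a
    // direct convolution, and only report a unit size when the speed-up exceeds 2x and
    // the transform size (unit + kernel - 1) is one of the supported alphas {4, 6}.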
    bool bestWinogradUnit(const VARP x, int* unitH = nullptr, int* unitW = nullptr) {
        if (x->getInfo() == nullptr) {
            return false;
        }
        int kernelW = mOption.kernelSize[0], kernelH = mOption.kernelSize[1], padW = mOption.pads[0], padH = mOption.pads[1];
        int outH = x->getInfo()->dim[2] + 2 * padH - kernelH + 1, outW = x->getInfo()->dim[3] + 2 * padW - kernelW + 1;
        int inChannel = mOption.channel[0], outChannel = mOption.channel[1];
        int threadNumber = 1, ePack = 12;
        int unit2   = UP_DIV(outH * outW, ePack * threadNumber);
        int maxUnit = (int)::sqrtf((float)unit2);
        const int MAX_UNIT = 4, MIN_UNIT = 2;
        maxUnit = std::max(std::min(maxUnit, MAX_UNIT), MIN_UNIT);
        auto units = std::pair<int, int>({0, 0});
        float maxRate = 2.0f, originCost = outH * outW * inChannel * outChannel * kernelH * kernelW;
        std::set<int> supportSu{4, 6};
        for (int uh = MIN_UNIT; uh <= maxUnit; ++uh) {
            for (int uw = MIN_UNIT; uw <= maxUnit; ++uw) {
                auto alphaH = uh + kernelH - 1, alphaW = uw + kernelW - 1;
                if (supportSu.find(alphaH) == supportSu.end() || supportSu.find(alphaW) == supportSu.end()) {
                    continue;
                }
                float winogradCost =
                    (2 * alphaH * alphaW * inChannel + alphaH * alphaW * inChannel * outChannel + (alphaH * alphaW + uh * alphaW) * outChannel) * (UP_DIV(outW, uw) * UP_DIV(outH, uh));
                float reduceRate = originCost / winogradCost;
                if (reduceRate > maxRate) {
                    maxRate = reduceRate;
                    units   = std::pair<int, int>({uh, uw});
                }
            }
        }
        if (units.first == 0 || units.second == 0) {
            return false;
        }
        if (unitH != nullptr && unitW != nullptr) {
            *unitH = units.first;
            *unitW = units.second;
        }
        return true;
    }
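    // Float simulation of the int8 Winograd convolution: pad and split the input into
    // tiles (_Im2Col), apply the source transform B, multiply with the G-transformed
    // weights, apply the destination transform A, then reassemble and crop the output.
    // The quantization parameters of the transformed input/weights are recorded in mWinogradAttr.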
    VARP _winogradConv(const VARP x, const VARP weight) {
        auto inDims = x->getInfo()->dim;
        int batch = inDims[0], inH = inDims[2], inW = inDims[3];
        int inChannel = mOption.channel[0], outChannel = mOption.channel[1];
        int kernelW = mOption.kernelSize[0], kernelH = mOption.kernelSize[1], padW = mOption.pads[0], padH = mOption.pads[1];
        int outH = inH + 2 * padH - kernelH + 1, outW = inW + 2 * padW - kernelW + 1;
        int unitH, unitW;
        bestWinogradUnit(x, &unitH, &unitW);
        if (mWinogradAttr->attrs.empty()) {
            mWinogradAttr->add(0, 0, kernelH, kernelW, unitH, unitW);
        }
        if (unitH != mWinogradAttr->attrs[0].unitY || unitW != mWinogradAttr->attrs[0].unitX) {
            MNN_ERROR("Winograd Conv not support variable input shape\n");
            return nullptr;
        }
        int alphaH = unitH + kernelH - 1, alphaW = unitW + kernelW - 1;
        int unitNumH = UP_DIV(outH, unitH), unitNumW = UP_DIV(outW, unitW);
        int needH = unitNumH * unitH + kernelH - 1, needW = unitNumW * unitW + kernelW - 1;
        int paddings[] = {0, 0, 0, 0, padH, needH - inH - padH, padW, needW - inW - padW};
        auto xx = _Pad(x, _Const(paddings, {8}, NCHW, halide_type_of<int32_t>()));
        // [ic * alphaH * alphaW, N * h_unit_num * w_unit_num]
        xx = _Im2Col(xx, {alphaW, alphaH}, {1, 1}, {0, 0}, {unitW, unitH});
        // [N * h_unit_num * w_unit_num, ic, alphaH, alphaW]
        xx = _Transpose(_Reshape(xx, {inChannel, alphaH, alphaW, -1}), {3, 0, 1, 2});
        // Must be the same as ConvInt8Winograd.cpp
        Math::WinogradGenerater genH(unitH, kernelH, 1, true), genW(unitW, kernelW, 1, true);
        auto srcTransH = _Const(genH.B()->host<void>(), {alphaH, alphaH}, NCHW);
        auto srcTransW = _Const(genW.B()->host<void>(), {alphaW, alphaW}, NCHW);
        xx = _MatMul(_MatMul(_Transpose(srcTransH, {1, 0}), xx), srcTransW);
        // [alphaH * alphaW, ic, N * h_unit_num * w_unit_num]
        xx = _Reshape(_Transpose(xx, {2, 3, 1, 0}), {alphaH * alphaW, inChannel, -1});
        auto inputPair = fakeQuantFeatureWithMinMax(xx, nullptr, nullptr, mInputClampValue, {1, 2, 3});
        mWinogradTransInputMin = updateParameter(mWinogradTransInputMin, inputPair[1]);
        mWinogradTransInputMax = updateParameter(mWinogradTransInputMax, inputPair[2]);
        setParameter(mWinogradTransInputMin, mWinogradTransInputMinPos);
        setParameter(mWinogradTransInputMax, mWinogradTransInputMaxPos);
        auto wTransH = _Const(genH.G()->host<void>(), {alphaH, kernelH}, NCHW);
        auto wTransW = _Const(genW.G()->host<void>(), {alphaW, kernelW}, NCHW);
        // [oc, ic, alphaH, alphaW]
        auto ww = _MatMul(_MatMul(wTransH, weight), _Transpose(wTransW, {1, 0}));
        // [alphaH * alphaW, oc, ic]
        ww = _Transpose(_Reshape(ww, {outChannel, inChannel, -1}), {2, 0, 1});
        auto wwInfo = ww->getInfo();
        // simulate weight quant
        auto weightScale = _Maximum(_ReduceMax(_Abs(ww), {2}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
        // ww = clamp(_Round(ww * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
        setParameter(weightScale, mWinogradTransWeightScalePos);
        // [alphaH * alphaW, oc, N * h_unit_num * w_unit_num]
        auto yy = _MatMul(ww, xx);
        // [oc, N * h_unit_num * w_unit_num, alphaH, alphaW]
        yy = _Reshape(_Transpose(yy, {1, 2, 0}), {outChannel, -1, alphaH, alphaW});
        auto dstTransH = _Const(genH.A()->host<void>(), {alphaH, unitH}, NCHW);
        auto dstTransW = _Const(genW.A()->host<void>(), {alphaW, unitW}, NCHW);
        // [oc, N * h_unit_num * w_unit_num, unitH, unitW]
        yy = _MatMul(_MatMul(_Transpose(dstTransH, {1, 0}), yy), dstTransW);
        // [N, oc, h_unit_num * unitH, w_unit_num * unitW]
        yy = _Reshape(_Transpose(_Reshape(yy, {outChannel, batch, unitNumH, unitNumW, unitH, unitW}), {1, 0, 2, 4, 3, 5}), {batch, outChannel, unitNumH * unitH, unitNumW * unitW});
        int sliceStartData[] = {0, 0, 0, 0}, sliceEndData[] = {-1, -1, outH, outW};
        yy = _Slice(yy, _Const(sliceStartData, {4}, NCHW), _Const(sliceEndData, {4}, NCHW));
        // TODO: add operator!= to VARP
        if (!(mBias == nullptr)) {
            yy = yy + _Reshape(mBias, {1, -1, 1, 1});
        }
        return yy;
    }
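    // Quantization-aware forward pass. Training: fake-quantize the weights and the
    // input/output features (tracking min/max), optionally through the Winograd path,
    // keeping everything in float so gradients still flow. Inference: fold the BatchNorm
    // into the weights/bias, quantize them once, and emit a real int8 convolution
    // wrapped by _FloatToInt8 / _Int8ToFloat.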
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
        VARP res;
        if (getIsTraining()) {
            auto x = _Convert(inputs[0], NCHW);
            // simulate weight quant
            auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
            auto weightTemp  = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
            weightTemp = weightTemp + _ZeroGrad(mWeight);
            // simulate input quant to get original input scale
            auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
            mInputMin = updateParameter(mInputMin, inputPair[1]);
            mInputMax = updateParameter(mInputMax, inputPair[2]);
            setParameter(mInputMin, mInputMinPos);
            setParameter(mInputMax, mInputMaxPos);
            // simulate output quant to get original output scale
            if (mWinogradAttr != nullptr && bestWinogradUnit(x)) {
                res = _winogradConv(x, weightTemp);
#ifdef MNN_WINOGRAD_DEBUG
                VARP res2 = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
                                  mOption.dilate, mGroup, mOption.pads);
                auto diff = res2 - res;
                diff = diff * diff;
                FUNC_PRINT_ALL(_ReduceMax(diff)->readMap<float>()[0], f);
#endif
            } else {
                res = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
                            mOption.dilate, mGroup, mOption.pads);
            }
            res->setName(name());
            if (mBatchNorm) {
                res = mBatchNorm->forward(res);
            }
            res = _activate(res, mActivation);
            auto outputPair = fakeQuantFeatureWithMinMax(res, nullptr, nullptr, mOutputClampValue);
            mOutputMin = updateParameter(mOutputMin, outputPair[1]);
            mOutputMax = updateParameter(mOutputMax, outputPair[2]);
            setParameter(mOutputMin, mOutputMinPos);
            setParameter(mOutputMax, mOutputMaxPos);
            res = outputPair[0];
        } else {
            if (nullptr == mInputMin) {
                // Initial for test
                // simulate weight quant
                auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
                auto weightTemp  = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
                auto x = _Convert(inputs[0], NCHW);
                auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
                mInputMin = updateParameter(mInputMin, inputPair[1]);
                mInputMax = updateParameter(mInputMax, inputPair[2]);
                setParameter(mInputMin, mInputMinPos);
                setParameter(mInputMax, mInputMaxPos);
                VARP simuRes;
                if (mWinogradAttr != nullptr && bestWinogradUnit(x)) {
                    simuRes = _winogradConv(x, weightTemp);
                } else {
                    simuRes = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
                                    mOption.dilate, mGroup, mOption.pads);
                }
                if (mBatchNorm) {
                    simuRes = mBatchNorm->forward(simuRes);
                }
                simuRes = _activate(simuRes, mActivation);
                Variable::prepareCompute({simuRes});
                auto outputPair = fakeQuantFeatureWithMinMax(simuRes, nullptr, nullptr, mOutputClampValue);
                mOutputMin = updateParameter(mOutputMin, outputPair[1]);
                mOutputMax = updateParameter(mOutputMax, outputPair[2]);
                setParameter(mOutputMin, mOutputMinPos);
                setParameter(mOutputMax, mOutputMaxPos);
            }
            // fold bn to conv weights and bias
            VARP fusedWeights = mWeight;
            VARP fusedBias    = mBias;
            fusedBias = _Reshape(fusedBias, {fusedBias->getInfo()->size, 1, 1, 1});
            if (mBatchNorm) {
                auto bn      = std::static_pointer_cast<BatchNormModule>(mBatchNorm);
                auto bnMean  = bn->runningMean();
                auto bnVar   = bn->runningVariance();
                auto bnScale = bn->scale();
                auto bnBias  = bn->bias();
                auto bnEps   = bn->eps();
                MNN_ASSERT(bnMean->getInfo()->dim.size() == 4);
                auto rStd  = _Const(1.0f) / _Sqrt(bnVar + _Const(bnEps));
                auto alpha = rStd * bnScale;
                auto beta  = bnBias - bnMean * rStd * bnScale;
                alpha = _Reshape(alpha, {alpha->getInfo()->size, 1, 1, 1});
                beta  = _Reshape(beta, {beta->getInfo()->size, 1, 1, 1});
                fusedWeights = alpha * fusedWeights;
                fusedBias    = alpha * fusedBias + beta;
            }
            auto x = _Convert(inputs[0], NC4HW4);
            int8_t inputZeroPoint, outputZeroPoint;
            {
                VARP channelScale, zeroPoint;
                auto scaleAndZeroPoint = computeScaleAndZeroPoint(mInputMin, mInputMax, mInputClampValue);
                mInputScale     = scaleAndZeroPoint.first;
                mInputZeroPoint = scaleAndZeroPoint.second;
                // always PerTensor
                channelScale   = _Reciprocal(mInputScale);
                zeroPoint      = _Cast<int8_t>(mInputZeroPoint);
                inputZeroPoint = zeroPoint->readMap<int8_t>()[0];
                x = _FloatToInt8(x, channelScale, -int8_t(mInputClampValue->readMap<float>()[0]), int8_t(mInputClampValue->readMap<float>()[0]), inputZeroPoint);
            }
            {
                VARP channelScale, zeroPoint;
                auto scaleAndZeroPoint = computeScaleAndZeroPoint(mOutputMin, mOutputMax, mOutputClampValue);
                mOutputScale     = scaleAndZeroPoint.first;
                mOutputZeroPoint = scaleAndZeroPoint.second;
                // always PerTensor
                channelScale    = mOutputScale;
                zeroPoint       = _Cast<int8_t>(mOutputZeroPoint);
                outputZeroPoint = zeroPoint->readMap<int8_t>()[0];
            }
            std::vector<int8_t> weight;
            std::vector<float> bias;
            std::vector<float> weightScaleVector;
            {
                VARP weightScale, quanWeight, convScale;
                // auto newWeight = fusedWeights * mInputScale;
                weightScale = _Maximum(_ReduceMax(_Abs(fusedWeights), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
                quanWeight  = _Cast<int8_t>(clamp(_Round(fusedWeights * _Reciprocal(weightScale)), mWeightClampValue));
                convScale   = _Reciprocal(mOutputScale) * weightScale * mInputScale;
                Variable::prepareCompute({quanWeight, convScale});
                // // reference for how to get quantized bias
                // auto remains = _ReduceSum(_Cast<int32_t>(mInputZeroPoint) * _Cast<int32_t>(quanWeight), {1, 2, 3}, true);
                // MNN_ASSERT((mOutputZeroPoint->getInfo()->dim.size() == 0) && (mOutputZeroPoint->getInfo()->size == 1)); // only support per-tensor, per-channel is removed.
                // auto outputZeroPointFused = _Cast<int32_t>(_Cast<float>(mOutputZeroPoint) * _Reciprocal(convScale));
                // auto quanBias = _Cast<int32_t>(fusedBias * _Reciprocal(weightScale * mInputScale)) - remains + outputZeroPointFused;

                {
                    auto info = quanWeight->getInfo();
                    weight.resize(info->size);
                    auto ptr = quanWeight->readMap<int8_t>();
                    ::memcpy(weight.data(), ptr, weight.size() * sizeof(int8_t));
                }
                {
                    auto biasinfo = fusedBias->getInfo();
                    bias.resize(biasinfo->size);
                    auto ptr = fusedBias->readMap<float>();
                    ::memcpy(bias.data(), ptr, bias.size() * sizeof(float));
                    auto info = weightScale->getInfo();
                    weightScaleVector.resize(info->size);
                    MNN_ASSERT(weightScaleVector.size() == bias.size());
                    auto ptrScale = weightScale->readMap<float>();
                    ::memcpy(weightScaleVector.data(), ptrScale, weightScaleVector.size() * sizeof(float));
                }
            }
            bool relu = mActivation == NN::None ? false : true;
            res = _Conv(std::move(weight), std::move(bias), std::move(weightScaleVector), _Convert(x, NC4HW4), mOption.channel,
                        mOption.kernelSize, mOption.padMode, mOption.stride, mOption.dilate, mGroup, mOption.pads, relu,
                        mInputScale->readMap<float>()[0], mOutputScale->readMap<float>()[0],
                        inputZeroPoint, outputZeroPoint,
                        -int8_t(mOutputClampValue->readMap<float>()[0]), int8_t(mOutputClampValue->readMap<float>()[0]), mWeightClampValue->readMap<float>()[0], mAccumulateToInt16);
            if (mWinogradAttr != nullptr && !mWinogradAttr->attrs.empty()) {
                auto scaleAndZeroPoint = computeScaleAndZeroPoint(mWinogradTransInputMin, mWinogradTransInputMax, mInputClampValue);
                auto inputScaleVar     = scaleAndZeroPoint.first;
                auto inputZeroPointVar = scaleAndZeroPoint.second;
                auto weightScaleVar    = parameters()[mWinogradTransWeightScalePos];
                // Winograd Transformed input scale
                auto inputScaleInfo = inputScaleVar->getInfo();
                auto inputScaleData = inputScaleVar->readMap<float>();
                if (inputScaleInfo == nullptr || inputScaleData == nullptr) {
                    MNN_ERROR("Error for WinogradConvModule, trans input scale not ready\n");
                    return {};
                }
                std::vector<float> inputScales(inputScaleData, inputScaleData + inputScaleInfo->size);
                // Winograd Transformed input zero point
                inputZeroPointVar = _Cast<int32_t>(inputZeroPointVar);
                auto inputZeroPointInfo = inputZeroPointVar->getInfo();
                auto inputZeroPointData = inputZeroPointVar->readMap<int32_t>();
                if (inputZeroPointInfo == nullptr || inputZeroPointData == nullptr) {
                    MNN_ERROR("Error for WinogradConvModule, trans input zero point not ready\n");
                    return {};
                }
                std::vector<int32_t> inputZeroPoints(inputZeroPointData, inputZeroPointData + inputZeroPointInfo->size);
                // Winograd Transformed weight scale
                auto weightScaleInfo = weightScaleVar->getInfo();
                auto weightScaleData = weightScaleVar->readMap<float>();
                if (weightScaleInfo == nullptr || weightScaleData == nullptr) {
                    MNN_ERROR("Error for WinogradConvModule, trans weight scale not ready\n");
                    return {};
                }
                std::vector<float> weightScales(weightScaleData, weightScaleData + weightScaleInfo->size);
                mWinogradAttr->attrs[0].inputScales     = inputScales;
                mWinogradAttr->attrs[0].inputZeroPoints = inputZeroPoints;
                mWinogradAttr->attrs[0].weightScales    = weightScales;
                res = mWinogradAttr->turnToWinogradConv(res);
            }
            res->setName(name());
            // always PerTensor
            res = _Int8ToFloat(res, mOutputScale, outputZeroPoint);
        }
        return {res};
    }
private:
    ConvBNReluFusedModule() = default;
    Module* clone(CloneContext* ctx) const override {
        ConvBNReluFusedModule* module(new ConvBNReluFusedModule);
        module->mConvParameter        = mConvParameter;
        module->mConvParameter.weight = ctx->getOrClone(mConvParameter.weight);
        module->mConvParameter.bias   = ctx->getOrClone(mConvParameter.bias);
        module->mOption     = mOption;
        module->mGroup      = mGroup;
        module->mWeight     = ctx->getOrClone(mWeight);
        module->mBias       = ctx->getOrClone(mBias);
        module->mActivation = mActivation;
        module->mBits       = mBits;
        module->mLimit      = mLimit;
        module->mLimitScale       = ctx->getOrClone(mLimitScale);
        module->mWeightClampValue = ctx->getOrClone(mWeightClampValue);
        module->mInputScale       = ctx->getOrClone(mInputScale);
        module->mOutputScale      = ctx->getOrClone(mOutputScale);
        module->mInputMin         = ctx->getOrClone(mInputMin);
        module->mInputMax         = ctx->getOrClone(mInputMax);
        module->mOutputMin        = ctx->getOrClone(mOutputMin);
        module->mOutputMax        = ctx->getOrClone(mOutputMax);
        module->mInputZeroPoint   = ctx->getOrClone(mInputZeroPoint);
        module->mOutputZeroPoint  = ctx->getOrClone(mOutputZeroPoint);
        module->mInputMinPos      = mInputMinPos;
        module->mInputMaxPos      = mInputMaxPos;
        module->mOutputMinPos     = mOutputMinPos;
        module->mOutputMaxPos     = mOutputMaxPos;
        module->mInputClampValue  = ctx->getOrClone(mInputClampValue);
        module->mOutputClampValue = ctx->getOrClone(mOutputClampValue);
        module->mMomentum               = mMomentum;
        module->mFeatureScaleStatMethod = mFeatureScaleStatMethod;
        module->mScaleUpdateMethod      = mScaleUpdateMethod;
        if (mBatchNorm) {
            module->mBatchNorm.reset(mBatchNorm->clone(ctx));
            module->registerModel({module->mBatchNorm});
        }
        module->mWinogradAttr                = mWinogradAttr;
        module->mWinogradTransInputMin       = ctx->getOrClone(mWinogradTransInputMin);
        module->mWinogradTransInputMax       = ctx->getOrClone(mWinogradTransInputMax);
        module->mWinogradTransInputMinPos    = mWinogradTransInputMinPos;
        module->mWinogradTransInputMaxPos    = mWinogradTransInputMaxPos;
        module->mWinogradTransWeightScalePos = mWinogradTransWeightScalePos;
        return this->cloneBaseTo(ctx, module);
    }
    NN::ConvParameters mConvParameter;
    NN::ConvOption mOption;
    int mGroup;
    VARP mWeight;
    VARP mBias;
    NN::ActivationFunctionType mActivation = NN::ActivationFunctionType::None;
    std::shared_ptr<Module> mBatchNorm = nullptr;
    int mBits;
    float mLimit;
    VARP mLimitScale;
    Express::VARP mWeightClampValue;
    VARP mInputScale  = nullptr;
    VARP mOutputScale = nullptr;
    VARP mInputMin    = nullptr;
    VARP mInputMax    = nullptr;
    VARP mOutputMin   = nullptr;
    VARP mOutputMax   = nullptr;
    VARP mInputZeroPoint  = nullptr;
    VARP mOutputZeroPoint = nullptr;
    int mInputMinPos  = -1;
    int mInputMaxPos  = -1;
    int mOutputMinPos = -1;
    int mOutputMaxPos = -1;
    VARP mInputClampValue;
    VARP mOutputClampValue;
    float mMomentum = 0.99f;
    NN::FeatureScaleStatMethod mFeatureScaleStatMethod;
    NN::ScaleUpdateMethod mScaleUpdateMethod;
    bool mAccumulateToInt16 = false;
    std::shared_ptr<WinogradInt8Attr> mWinogradAttr;
    VARP mWinogradTransInputMin = _Const(-100.f);
    VARP mWinogradTransInputMax = _Const(-100.f);
    int mWinogradTransInputMinPos    = -1;
    int mWinogradTransInputMaxPos    = -1;
    int mWinogradTransWeightScalePos = -1;
};
Module* NN::ConvBNReluFused(std::vector<std::shared_ptr<Module>> modules,
                            NN::FeatureScaleStatMethod featureScaleStatMethod,
                            NN::ScaleUpdateMethod scaleUpdateMethod, const int bits, bool winograd) {
    return new ConvBNReluFusedModule(modules, featureScaleStatMethod, scaleUpdateMethod, bits, winograd);
}
Module* NN::ConvInt8(const ConvOption& option, int bits, bool hasBias,
                     std::shared_ptr<Initializer> weightInit, std::shared_ptr<Initializer> biasInit, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
    std::shared_ptr<Module> conv(NN::Conv(option));
    return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
Module* NN::ConvInt8(const ConvParameters& para, int bits, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
    std::shared_ptr<Module> conv(NN::Conv(para));
    return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
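// Scan a PipelineModule and replace each Conv, Conv+BN, Conv+ReLU(6) and Conv+BN+ReLU(6)
// pattern with a ConvBNReluFusedModule so the whole group is trained and quantized as one
// unit; the absorbed BatchNorm/ReLU submodules are erased afterwards.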
bool NN::turnQuantize(Module* module, const int bits, NN::FeatureScaleStatMethod featureScaleStatMethod, NN::ScaleUpdateMethod scaleUpdateMethod, bool winogradOpt) {
    if (nullptr == module || module->type() != PIPELINE_MODULE) {
        MNN_ERROR("Invalid module for quantization\n");
        return false;
    }
    auto pipModule = static_cast<PipelineModule*>(module);
    std::vector<int> needEraseIndices;
    for (int i = 0; i < pipModule->mSubModules.size(); i++) {
        auto& m = pipModule->mSubModules[i];
        auto& theModule = std::get<0>(m);
        auto moduleType = theModule->type();
        //auto& inputIndices = std::get<1>(m);
        auto& outputIndices = std::get<2>(m);
        if (moduleType == "Conv" && i < pipModule->mSubModules.size() - 1) {
            auto& p1 = pipModule->mSubModules[i + 1];
            auto p1Module = std::get<0>(p1);
            auto& p1ModuleType = p1Module->type();
            auto& p1InputIndices = std::get<1>(p1);
            auto& p1OutputIndices = std::get<2>(p1);
            auto convOutputCount = pipModule->countOutputReference(outputIndices);
            bool convSingleOutputReference = ((outputIndices.size() == 1) && (convOutputCount[0] == 1));
            // only conv
            if ((!convSingleOutputReference) || (p1ModuleType == "Conv") ||
                (p1ModuleType != "BatchNorm" && p1ModuleType != "ReLU" && p1ModuleType != "ReLU6")) {
                theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
                pipModule->registerModel({theModule});
                continue;
            }
            // conv + bn + ?
            if (p1ModuleType == "BatchNorm") {
                bool convBnConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
                if (!convBnConnected) {
                    theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
                    pipModule->registerModel({theModule});
                    continue;
                }
                // last conv + bn
                if (i == pipModule->mSubModules.size() - 2) {
                    theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
                    pipModule->registerModel({theModule});
                    outputIndices = p1OutputIndices;
                    needEraseIndices.emplace_back(i + 1);
                    continue;
                }
                // maybe there is a relu or relu6 after conv + bn
                auto& p2 = pipModule->mSubModules[i + 2];
                auto& p2Module = std::get<0>(p2);
                auto p2ModuleType = p2Module->type();
                auto& p2InputIndices = std::get<1>(p2);
                auto& p2OutputIndices = std::get<2>(p2);
                auto bnOutputCount = pipModule->countOutputReference(p1OutputIndices);
                bool bnSingleOutputReference = ((p1OutputIndices.size() == 1) && (bnOutputCount[0] == 1));
                // only conv + bn
                if ((!bnSingleOutputReference) || (p2ModuleType != "ReLU" && p2ModuleType != "ReLU6")) {
                    theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
                    pipModule->registerModel({theModule});
                    outputIndices = p1OutputIndices;
                    needEraseIndices.emplace_back(i + 1);
                    continue;
                } else { // conv + bn + relu or conv + bn + relu6
                    bool convBnReluConnected = ((bnSingleOutputReference) && (p2InputIndices.size() == 1) && (p2InputIndices[0] == p1OutputIndices[0]));
                    bool isPrelu = false;
                    if (p2ModuleType == "ReLU") {
                        auto p2Op = ((ExprModule*)p2Module.get())->getExpr()->get();
                        float slope = p2Op->main_as_Relu()->slope();
                        isPrelu = std::abs(slope) > 1e-6;
                    }
                    if (!convBnReluConnected || isPrelu) {
                        theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
                        pipModule->registerModel({theModule});
                        outputIndices = p1OutputIndices;
                        needEraseIndices.emplace_back(i + 1);
                        continue;
                    }
                    theModule.reset(NN::ConvBNReluFused({theModule, p1Module, p2Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
                    pipModule->registerModel({theModule});
                    outputIndices = p2OutputIndices;
                    needEraseIndices.emplace_back(i + 1);
                    needEraseIndices.emplace_back(i + 2);
                    continue;
                }
            }
            // conv + relu or conv + relu6
            if (p1ModuleType == "ReLU" || p1ModuleType == "ReLU6") {
                bool convReluConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
                bool isPrelu = false;
                if (p1ModuleType == "ReLU") {
                    auto p1Op = ((ExprModule*)p1Module.get())->getExpr()->get();
                    float slope = p1Op->main_as_Relu()->slope();
                    isPrelu = std::abs(slope) > 1e-6;
                }
                if (!convReluConnected || isPrelu) {
                    theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
                    pipModule->registerModel({theModule});
                    continue;
                }
                theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
                pipModule->registerModel({theModule});
                outputIndices = p1OutputIndices;
                needEraseIndices.emplace_back(i + 1);
                continue;
            }
        }
        if (i == pipModule->mSubModules.size() - 1 && moduleType == "Conv") {
            theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
            pipModule->registerModel({theModule});
        }
    }
    // erase useless submodules
    const int eraseSize = needEraseIndices.size();
    int alreadyErasedCount = 0;
    for (int i = 0; i < eraseSize; i++) {
        auto position = needEraseIndices[i] - alreadyErasedCount;
        auto type = std::get<0>(pipModule->mSubModules[position])->type();
        MNN_ASSERT(type == "BatchNorm" || type == "ReLU" || type == "ReLU6");
        pipModule->mSubModules.erase(pipModule->mSubModules.begin() + position);
        alreadyErasedCount++;
    }
    return true;
}
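// Build a Module from a computed graph. When fortrain is true, Convolution, BatchNorm and
// Dropout expressions are converted into their trainable module counterparts; otherwise
// only ops that cannot be executed directly are wrapped.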
Module* NN::extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph) {
    std::function<std::pair<std::vector<int>, std::shared_ptr<Module>>(EXPRP)> transformFunction;
    if (fortrain) {
        transformFunction =
            [&subGraph](EXPRP source) {
                if (source->get() == nullptr) {
                    return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
                }
                std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
                if (nullptr != m) {
                    m->setName(source->name());
                    return std::make_pair(std::vector<int>{}, m);
                }
                auto convExtracted = NN::Utils::ExtractConvolution(source);
                if (convExtracted.weight == nullptr) {
                    return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
                }
                std::shared_ptr<Module> module(NN::Conv(convExtracted));
                module->setName(source->name());
                return std::make_pair(std::vector<int>{0}, module);
            };
    } else {
        transformFunction = [&subGraph](EXPRP source) {
            if (source->get() == nullptr) {
                return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
            }
            std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
            if (nullptr != m) {
                m->setName(source->name());
                return std::make_pair(std::vector<int>{}, m);
            }
            return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
        };
    }
    return new PipelineModule(inputs, outputs, transformFunction);
}

} // namespace Express
} // namespace MNN