//
// NN.cpp
// MNN
//
// Created by MNN on 2019/11/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "NN.hpp"
#include "Distributions.hpp"
#include "module/PipelineModule.hpp"
#include "module/WhileModule.hpp"
#include "module/IfModule.hpp"
#include "module/NMSModule.hpp"
#include "Initializer.hpp"
#include "MNN_generated.h"
#include "RandomGenerator.hpp"
#include "core/Macro.h"
#include <string>

using namespace MNN::Express;
namespace MNN {
namespace Express {

static VARP _activate(VARP x, NN::ActivationFunctionType type) {
    switch (type) {
        case NN::None:
            return x;
        case NN::Relu:
            return _Relu(x);
        case NN::Relu6:
            return _Relu6(x);
        default:
            break;
    }
    return nullptr;
}
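
// DropoutModule implements inverted dropout: in training each element is zeroed
// with probability mDropRatio and the survivors are scaled by 1 / (1 - mDropRatio),
// so the module is an identity at inference time and no extra rescaling is needed.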
class DropoutModule : public Module {
public:
    DropoutModule(const float dropRatio) {
        mDropRatio = dropRatio;
        setType("Dropout");
    }
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
        Express::VARP x = inputs[0];
        if (getIsTraining()) {
            float scale  = 1. / (1. - mDropRatio);
            auto mask    = _Input(x->getInfo()->dim, x->getInfo()->order, x->getInfo()->type);
            auto maskPtr = mask->writeMap<float>();
            auto eltSize = x->getInfo()->size;
            Distributions::uniform(eltSize, 0, 1, maskPtr, RandomGenerator::generator());
            for (int i = 0; i < eltSize; i++) {
                maskPtr[i] = maskPtr[i] < mDropRatio ? 0.0f : scale;
            }
            x = x * mask;
        }
        return {x};
    }

private:
    DropoutModule() = default;
    Module* clone(CloneContext* ctx) const override {
        DropoutModule* module(new DropoutModule);
        module->mDropRatio = mDropRatio;
        return this->cloneBaseTo(ctx, module);
    }

    float mDropRatio;
};
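
// BatchNormModule has two paths. In training it normalizes with the per-batch mean
// and variance (reduced over mReductionDims) and updates the running statistics with
// momentum; in inference it folds the running statistics, scale and bias into a
// single per-channel _Scale op.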
class BatchNormModule : public Module {
public:
    BatchNormModule(EXPRP expr, const float m = 0.99) {
        MNN_ASSERT(expr->get() != nullptr);
        MNN_ASSERT(expr->get()->type() == OpType_BatchNorm);
        auto bnPa = expr->get()->main_as_BatchNorm();
        auto& inputs = expr->inputs();
        int dims = 4;
        if (!inputs.empty()) {
            auto info = inputs[0]->getInfo();
            if (nullptr != info) {
                dims = info->dim.size();
            }
        }
        mEps      = bnPa->epsilon();
        mMomentum = m;
        mChannels = bnPa->channels();

        std::vector<int> statShape;
        std::vector<int> reductionDims;
        int channels = mChannels;
        if (dims == 2) {
            statShape      = {1, channels};
            mReductionDims = {0};
        }
        if (dims == 3) {
            statShape      = {1, channels, 1};
            mReductionDims = {0, 2};
        }
        if (dims == 4) {
            statShape      = {1, channels, 1, 1};
            mReductionDims = {0, 2, 3};
        }

        MNN_ASSERT(bnPa->biasData()->size() == mChannels);
        mBias = _TrainableParam(bnPa->biasData()->data(), statShape, NCHW);
        MNN_ASSERT(bnPa->slopeData()->size() == mChannels);
        mScale = _TrainableParam(bnPa->slopeData()->data(), statShape, NCHW);
        MNN_ASSERT(bnPa->meanData()->size() == mChannels);
        mRunningMean = _Const(bnPa->meanData()->data(), statShape, NCHW);
        MNN_ASSERT(bnPa->varData()->size() == mChannels);
        mRunningVariance = _Const(bnPa->varData()->data(), statShape, NCHW);

        addParameter(mScale);
        addParameter(mBias);
        mRunningVariancePos = addParameter(mRunningVariance);
        mRunningMeanPos     = addParameter(mRunningMean);
        setType("BatchNorm");
    }
    BatchNormModule(const int channels, const int dims = 4, const float m = 0.99, const float e = 1e-5) {
        mMomentum = m;
        mEps      = e;
        mChannels = channels;
        std::vector<int> statShape;
        std::vector<int> reductionDims;
        if (dims == 2) {
            statShape      = {1, channels};
            mReductionDims = {0};
        }
        if (dims == 3) {
            statShape      = {1, channels, 1};
            mReductionDims = {0, 2};
        }
        if (dims == 4) {
            statShape      = {1, channels, 1, 1};
            mReductionDims = {0, 2, 3};
        }
        mScale           = _TrainableParam(1.0f, statShape, NCHW);
        mBias            = _TrainableParam(0.0f, statShape, NCHW);
        mRunningMean     = _Const(0.0f, statShape, NCHW);
        mRunningVariance = _Const(0.0f, statShape, NCHW);
        addParameter(mScale);
        addParameter(mBias);
        mRunningVariancePos = addParameter(mRunningVariance);
        mRunningMeanPos     = addParameter(mRunningMean);
        setType("BatchNorm");
    }

    VARP runningMean() {
        return mRunningMean;
    }

    VARP runningVariance() {
        return mRunningVariance;
    }

    VARP scale() {
        return mScale;
    }

    VARP bias() {
        return mBias;
    }

    float eps() {
        return mEps;
    }
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
        Express::VARP x = inputs[0];
        auto dimFormat  = x->getInfo()->order;
        VARP outputData = nullptr;
        if (getIsTraining()) {
            if (dimFormat == NC4HW4 || dimFormat == NHWC) {
                x = _Convert(x, NCHW);
            }
            MNN_ASSERT(x->getInfo()->dim[1] == mChannels);
            auto sampleMean = _ReduceMean(x, mReductionDims, true); // mean for each channel in the batch
            auto xSub       = x - sampleMean;
            auto sampleVar  = _ReduceMean(_Square(xSub), mReductionDims, true); // variance for each channel in the batch
            auto rSampleStd     = _Reciprocal(_Sqrt(sampleVar + _Const(mEps)));
            auto normalizedData = xSub * rSampleStd;
            outputData          = normalizedData * mScale + mBias;

            // Update running statistics with momentum.
            mRunningMean     = _Const(mMomentum) * mRunningMean + _Const(1 - mMomentum) * sampleMean;
            mRunningVariance = _Const(mMomentum) * mRunningVariance + _Const(1 - mMomentum) * sampleVar;
            outputData->setName(name());
            outputData = _Convert(outputData, dimFormat);
            setParameter(mRunningMean, mRunningMeanPos);
            setParameter(mRunningVariance, mRunningVariancePos);
            return {outputData};
        }

        // Inference: fold running statistics, scale and bias into a per-channel Scale op.
        auto rStd  = _Const(1.0f) / _Sqrt(mRunningVariance + _Const(mEps));
        auto alpha = rStd * mScale;
        auto beta  = mBias - mRunningMean * rStd * mScale;
        //outputData = (_Convert(x, NCHW) * alpha) + beta;
        alpha.fix(VARP::CONSTANT);
        beta.fix(VARP::CONSTANT);
        //FUNC_PRINT_ALL(alpha->readMap<float>()[0], f);
        x = _Convert(x, NC4HW4);
        std::vector<float> scale(alpha->getInfo()->size);
        std::vector<float> bias(beta->getInfo()->size);
        ::memcpy(scale.data(), alpha->readMap<float>(), scale.size() * sizeof(float));
        ::memcpy(bias.data(), beta->readMap<float>(), bias.size() * sizeof(float));
        outputData = _Scale(x, mChannels, std::move(scale), std::move(bias));
        outputData->setName(name());
        outputData = _Convert(outputData, dimFormat);
        return {outputData};
    }
private:
    BatchNormModule() = default;
    Module* clone(CloneContext* ctx) const override {
        BatchNormModule* module(new BatchNormModule);
        module->mMomentum           = mMomentum;
        module->mEps                = mEps;
        module->mScale              = ctx->getOrClone(mScale);
        module->mBias               = ctx->getOrClone(mBias);
        module->mRunningMean        = ctx->getOrClone(mRunningMean);
        module->mRunningVariance    = ctx->getOrClone(mRunningVariance);
        module->mRunningMeanPos     = mRunningMeanPos;
        module->mRunningVariancePos = mRunningVariancePos;
        module->mChannels           = mChannels;
        module->mReductionDims      = mReductionDims;
        return this->cloneBaseTo(ctx, module);
    }

    float mMomentum         = 0.99;
    float mEps              = 1e-5;
    VARP mScale             = nullptr;
    VARP mBias              = nullptr;
    VARP mRunningMean       = nullptr;
    VARP mRunningVariance   = nullptr;
    int mRunningMeanPos     = -1;
    int mRunningVariancePos = -1;
    int mChannels;
    std::vector<int> mReductionDims;
};
void NN::ConvOption::reset(int size) {
    stride     = std::vector<int>(size, 1);
    channel    = std::vector<int>(size, 0);
    kernelSize = std::vector<int>(size, 1);
    dilate     = std::vector<int>(size, 1);
    padMode    = VALID;
    pads       = std::vector<int>(size, 0);
    depthwise  = false;
    fusedActivationFunction = None;
}
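
// ConvModule wraps NN::ConvParameters. In training it builds a float _Conv from the
// trainable weight/bias variables and applies the fused activation separately; in
// inference it snapshots weight and bias into constant buffers and emits a fused
// _Conv with relu/relu6 folded in.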
class ConvModule : public Module {
public:
    ConvModule(const NN::ConvParameters& parameters) {
        mParameter = parameters;
        if (nullptr != mParameter.bias) {
            addParameter(mParameter.bias);
        }
        if (nullptr != mParameter.weight) {
            addParameter(mParameter.weight);
        }
        setName(parameters.name);
        setType("Conv");
    }
    NN::ConvParameters& convParameters() {
        return mParameter;
    }
    virtual std::vector<VARP> onForward(const std::vector<VARP>& inputs) override {
        auto input   = inputs[0];
        auto& option = mParameter.option;
        if (getIsTraining()) {
            auto tempOutput = _Conv(mParameter.weight, mParameter.bias, _Convert(input, NC4HW4), option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads);
            tempOutput->setName(name());
            tempOutput = _activate(tempOutput, option.fusedActivationFunction);
            return {tempOutput};
        }
        bool relu  = option.fusedActivationFunction == NN::Relu;
        bool relu6 = option.fusedActivationFunction == NN::Relu6;
        std::vector<float> weight;
        std::vector<float> bias;
        {
            auto weightInfo = mParameter.weight->getInfo();
            weight.resize(weightInfo->size);
            ::memcpy(weight.data(), mParameter.weight->readMap<float>(), weight.size() * sizeof(float));
        }
        {
            bias.resize(mParameter.option.channel[1]);
            if (nullptr != mParameter.bias) {
                ::memcpy(bias.data(), mParameter.bias->readMap<float>(), bias.size() * sizeof(float));
            } else {
                ::memset(bias.data(), 0, bias.size() * sizeof(float));
            }
        }
        auto tempOutput = _Conv(std::move(weight), std::move(bias), _Convert(input, NC4HW4), option.channel, option.kernelSize, option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads, relu, relu6);
        tempOutput->setName(name());
        return {tempOutput};
    }

private:
    ConvModule() = default;
    Module* clone(CloneContext* ctx) const override {
        ConvModule* module(new ConvModule);
        module->mParameter        = mParameter;
        module->mParameter.weight = ctx->getOrClone(mParameter.weight);
        module->mParameter.bias   = ctx->getOrClone(mParameter.bias);
        return this->cloneBaseTo(ctx, module);
    }

    NN::ConvParameters mParameter;
};
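
// _initParameters creates the weight and bias variables shared by Conv and
// ConvTranspose. Defaults are Xavier initialization for weights and zero for bias.
// Depthwise convolution requires matching input/output channel counts and uses a
// group count equal to the channel count. Returns {weight, bias, group}.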
static std::tuple<VARP, VARP, int> _initParameters(const NN::ConvOption& option, bool hasBias,
                                                   std::shared_ptr<Initializer> weightInit,
                                                   std::shared_ptr<Initializer> biasInit) {
    std::tuple<VARP, VARP, int> defaultRes;
    if (nullptr == weightInit) {
        weightInit.reset(Initializer::xavier());
    }
    if (nullptr == biasInit) {
        biasInit.reset(Initializer::constValue(0.0f));
    }
    VARP weight;
    int group = 1;
    if (option.depthwise) {
        if (option.channel[1] != option.channel[0]) {
            MNN_ERROR("Depthwise convolution requires the same input and output channel count\n");
            return defaultRes;
        }
        weight = weightInit->createConstVar({option.channel[0], 1, option.kernelSize[1], option.kernelSize[0]}, NCHW);
        weight.fix(VARP::TRAINABLE);
        group = option.channel[0];
    } else {
        weight = weightInit->createConstVar(
            {option.channel[1], option.channel[0], option.kernelSize[1], option.kernelSize[0]}, NCHW);
        weight.fix(VARP::TRAINABLE);
    }
    VARP bias;
    if (hasBias) {
        bias = biasInit->createConstVar({option.channel[1]}, NCHW);
        bias.fix(VARP::TRAINABLE);
    }
    return std::make_tuple(weight, bias, group);
}
Module* NN::ConvTranspose(const ConvOption& option, bool hasBias,
                          std::shared_ptr<Initializer> weightInit,
                          std::shared_ptr<Initializer> biasInit) {
    VARP input  = _Input({1, option.channel[0], 1, 1}, NC4HW4);
    auto tuple  = _initParameters(option, hasBias, weightInit, biasInit);
    auto weight = std::get<0>(tuple);
    if (nullptr == weight) {
        return nullptr;
    }
    if (!option.depthwise) {
        weight = _Transpose(weight, {1, 0, 2, 3});
        weight.fix(VARP::TRAINABLE);
    }
    auto bias  = std::get<1>(tuple);
    auto group = std::get<2>(tuple);
    if (nullptr != bias) {
        auto tempOutput = _Deconv(weight, bias, input, option.padMode, option.stride, option.dilate, group);
        tempOutput = _activate(tempOutput, option.fusedActivationFunction);
        return NN::extract({input}, {tempOutput}, true);
    }
    auto tempOutput = _Deconv(weight, nullptr, input, option.padMode, option.stride, option.dilate, group);
    tempOutput = _activate(tempOutput, option.fusedActivationFunction);
    return NN::extract({input}, {tempOutput}, true);
}
Module* NN::Conv(const ConvOption& option, bool hasBias, std::shared_ptr<Initializer> weightInit,
                 std::shared_ptr<Initializer> biasInit) {
    auto tuple = _initParameters(option, hasBias, weightInit, biasInit);
    ConvParameters parameters;
    parameters.weight = std::get<0>(tuple);
    if (nullptr == parameters.weight) {
        return nullptr;
    }
    parameters.bias   = std::get<1>(tuple);
    parameters.group  = std::get<2>(tuple);
    parameters.option = option;
    return new ConvModule(parameters);
}
Module* NN::Linear(int l, int t, bool hasBias, std::shared_ptr<Initializer> weightInit,
                   std::shared_ptr<Initializer> biasInit) {
    if (nullptr == weightInit) {
        weightInit.reset(Initializer::xavier());
    }
    if (nullptr == biasInit) {
        biasInit.reset(Initializer::constValue(0.0f));
    }
    auto weight = weightInit->createConstVar({t, l}, NCHW);
    weight.fix(VARP::TRAINABLE);
    auto input  = _Input({l}, NCHW);
    auto output = _MatMul(input, weight, false, true);
    if (!hasBias) {
        return NN::extract({input}, {output}, true);
    }
    auto bias = biasInit->createConstVar({1, t}, NCHW);
    bias.fix(VARP::TRAINABLE);
    output = _Add(output, bias);
    auto module = NN::extract({input}, {output}, true);
    module->setType("Linear");
    return module;
}
Module* NN::Dropout(const float dropRatio) {
    return new DropoutModule(dropRatio);
}
Module* NN::BatchNorm(const int channels, const int dims, const float m, const float e) {
    return new BatchNormModule(channels, dims, m, e);
}
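
// ExtractConvolution rebuilds NN::ConvParameters from a serialized Convolution2D
// expression: kernel/stride/pads/dilation and channel counts come from the op's
// Common table, the input count is inferred from the input shape or the weight size
// when it is stored as zero, and weight/bias are taken from the expr inputs when
// present or wrapped as trainable parameters otherwise. The relu/relu6 flags become
// the fused activation.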
NN::ConvParameters NN::Utils::ExtractConvolution(EXPRP source) {
    ConvParameters _default;
    if (source->get() == nullptr) {
        return _default;
    }
    if (source->get()->type() != OpType_Convolution && source->get()->type() != OpType_ConvolutionDepthwise) {
        return _default;
    }
    auto conv2D = source->get()->main_as_Convolution2D();
    NN::ConvOption option;
    option.kernelSize = {conv2D->common()->kernelX(), conv2D->common()->kernelY()};
    option.stride     = {conv2D->common()->strideX(), conv2D->common()->strideY()};
    if (nullptr != conv2D->common()->pads()) {
        option.pads.resize(conv2D->common()->pads()->size());
        for (int i = 0; i < option.pads.size(); ++i) {
            option.pads[i] = conv2D->common()->pads()->data()[i];
        }
    } else {
        option.pads = {conv2D->common()->padX(), conv2D->common()->padY()};
    }
    switch (conv2D->common()->padMode()) {
        case MNN::PadMode_SAME:
            option.padMode = SAME;
            break;
        case MNN::PadMode_VALID:
            option.padMode = VALID;
            break;
        case MNN::PadMode_CAFFE:
            option.padMode = CAFFE;
            break;
        default:
            break;
    }
    option.dilate    = {conv2D->common()->dilateX(), conv2D->common()->dilateY()};
    option.depthwise = source->get()->type() == OpType_ConvolutionDepthwise;
    auto inputCount  = conv2D->common()->inputCount();
    if (0 == inputCount) {
        auto inputInfo = source->inputs()[0]->getInfo();
        if (nullptr != inputInfo) {
            if (NHWC == inputInfo->order) {
                inputCount = source->inputs()[0]->getInfo()->dim[3];
            } else {
                inputCount = source->inputs()[0]->getInfo()->dim[1];
            }
        } else {
            if (nullptr == conv2D->weight()) {
                MNN_ERROR("Can't extract convolution\n");
                return _default;
            }
            auto weightCount = conv2D->weight()->size();
            if (option.depthwise) {
                inputCount = conv2D->common()->outputCount();
            } else {
                inputCount = weightCount / conv2D->common()->kernelX() / conv2D->common()->kernelY() / conv2D->common()->outputCount();
            }
        }
    }
    option.channel = {inputCount, conv2D->common()->outputCount()};
    int group = 1;
    if (option.depthwise) {
        group = conv2D->common()->outputCount();
    }
    VARP weight;
    auto inputs = source->inputs();
    if (inputs.size() > 1) {
        weight = inputs[1];
    }
    VARP bias;
    if (inputs.size() > 2) {
        bias = inputs[2];
    }
    if (inputs.size() < 2) {
        // Extract weight and bias from the Conv2D op itself
        if (conv2D->weight() == nullptr || conv2D->bias() == nullptr) {
            return _default;
        }
        bias   = _TrainableParam(conv2D->bias()->data(), {option.channel[1]}, NCHW);
        weight = _TrainableParam(conv2D->weight()->data(), {option.channel[1], option.channel[0] / group, option.kernelSize[1], option.kernelSize[0]}, NCHW);
    }
    _default.option = std::move(option);
    _default.weight = std::move(weight);
    _default.bias   = std::move(bias);
    _default.group  = group;
    if (conv2D->common()->relu()) {
        _default.option.fusedActivationFunction = NN::Relu;
    }
    if (conv2D->common()->relu6()) {
        _default.option.fusedActivationFunction = NN::Relu6;
    }
    _default.name = source->name();
    return _default;
}
Module* NN::Conv(const ConvParameters& parameter) {
    return new ConvModule(parameter);
}
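
// ExtractNotRunableOp wraps ops that cannot be trained directly as expressions into
// dedicated modules: BatchNorm, Dropout (with a fixed 0.3 drop ratio), While, If and
// NonMaxSuppressionV2. Returns nullptr for every other op type.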
Module* NN::Utils::ExtractNotRunableOp(Express::EXPRP expr, const std::map<std::string, SubGraph>& subgraphs) {
    if (nullptr == expr->get()) {
        return nullptr;
    }
    if (expr->get()->type() == OpType_BatchNorm) {
        return new BatchNormModule(expr);
    }
    if (expr->get()->type() == OpType_Dropout) {
        return new DropoutModule(0.3f);
    }
    if (expr->get()->type() == OpType_While) {
        return WhileModule::create(expr->get(), subgraphs, nullptr);
    }
    if (expr->get()->type() == OpType_If) {
        return IfModule::create(expr->get(), subgraphs, nullptr);
    }
    if (expr->get()->type() == OpType_NonMaxSuppressionV2) {
        return NMSModule::create(expr->get(), nullptr);
    }
    return nullptr;
}
class ConvBNReluFusedModule : public Module {
public:
    ConvBNReluFusedModule(std::vector<std::shared_ptr<Module>> modules,
                          NN::FeatureScaleStatMethod featureScaleStatMethod,
                          NN::ScaleUpdateMethod scaleUpdateMethod, const int bits) {
        MNN_ASSERT(modules.size() >= 1);
        MNN_ASSERT(modules[0]->type() == "Conv");
        if (modules.size() == 3) {
            MNN_ASSERT(modules[1]->type() == "BatchNorm");
            MNN_ASSERT(modules[2]->type() == "ReLU" || modules[2]->type() == "ReLU6");
        }
        for (int i = 0; i < modules.size(); i++) {
            auto type = modules[i]->type();
            if (type == "Conv") {
                mConvParameter = std::static_pointer_cast<ConvModule>(modules[i])->convParameters();
                mOption = mConvParameter.option;
                mGroup  = mConvParameter.group;
                mWeight = mConvParameter.weight;
                mBias   = mConvParameter.bias;
                if (nullptr != mWeight) {
                    addParameter(mWeight);
                }
                if (nullptr != mBias) {
                    addParameter(mBias);
                }
                setName(mConvParameter.name);
                modules[i] = nullptr;
            } else if (type == "BatchNorm") {
                mBatchNorm = modules[i];
                registerModel({mBatchNorm});
            } else if (type == "ReLU") {
                mActivation = NN::Relu;
                modules[i] = nullptr;
            } else if (type == "ReLU6") {
                mActivation = NN::Relu6;
                modules[i] = nullptr;
            } else {
                MNN_ASSERT(false);
            }
        }
        if (mOption.fusedActivationFunction == NN::Relu || mOption.fusedActivationFunction == NN::Relu6) {
            mActivation = mOption.fusedActivationFunction;
        }
        // featureScaleStatMethod is currently ignored: feature scales are always tracked per-tensor
        mFeatureScaleStatMethod = NN::PerTensor;
        mScaleUpdateMethod = scaleUpdateMethod;
        mBits  = bits;
        mLimit = (float)(1 << (bits - 1)) - 1.0f;
        mLimitScale       = _Scalar<float>(1.0f / mLimit);
        mWeightClampValue = _Scalar<float>(mLimit);
        // mInputClampValue = _Scalar<float>(mLimit);
        // mOutputClampValue = _Scalar<float>(mLimit);
        // lower bits only apply to weights
        mInputClampValue  = _Scalar<float>((float)(1 << (8 - 1)) - 1.0f);
        mOutputClampValue = _Scalar<float>((float)(1 << (8 - 1)) - 1.0f);

        mInputMinPos  = addParameter(mInputMin);
        mInputMaxPos  = addParameter(mInputMax);
        mOutputMinPos = addParameter(mOutputMin);
        mOutputMaxPos = addParameter(mOutputMax);
        setType("ConvBNReluFused");
    }
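
    // Asymmetric quantization helpers. For a clamp value Q (127 for 8 bits):
    //   scale     = (max - min) / (2 * Q)
    //   zeroPoint = round((0 - min) / scale - Q)
    // Illustrative example (not taken from the code): with min = 0, max = 2 and
    // Q = 127, scale = 2 / 254 and zeroPoint = -127, so a real value x maps to
    // round(x / scale + zeroPoint), i.e. 0 -> -127 and 2 -> +127.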
    std::pair<VARP, VARP> computeScaleAndZeroPoint(VARP min, VARP max, VARP clampVar) {
        MNN_ASSERT(!(min == nullptr));
        MNN_ASSERT(!(max == nullptr));
        min = _Minimum(_Scalar<float>(0.0f), min);
        max = _Maximum(_Scalar<float>(0.0f), max);
        auto scale     = (max - min) / (_Scalar(2.0f) * clampVar);
        auto zeroPoint = _Round((_Scalar(0.0f) - min) / scale - clampVar);
        return std::make_pair(scale, zeroPoint);
    }
    std::vector<VARP> fakeQuantFeatureWithMinMax(VARP x, VARP useMin, VARP useMax, VARP clampVar) {
        auto originFormat = x->getInfo()->order;
        auto tempX        = x;
        if (originFormat == NC4HW4) {
            tempX = _Convert(tempX, NCHW);
        }
        auto originX = tempX;
        VARP min, max;
        // always PerTensor
        min = _ReduceMin(tempX);
        max = _ReduceMax(tempX);
        VARP scale, zeroPoint;
        VARP nudgeMin, nudgeMax;
        if (!(useMin == nullptr)) {
            MNN_ASSERT(!(useMax == nullptr));
            auto scaleAndZeroPoint = computeScaleAndZeroPoint(useMin, useMax, clampVar);
            scale     = scaleAndZeroPoint.first;
            zeroPoint = scaleAndZeroPoint.second;
        } else {
            auto scaleAndZeroPoint = computeScaleAndZeroPoint(min, max, clampVar);
            scale     = scaleAndZeroPoint.first;
            zeroPoint = scaleAndZeroPoint.second;
        }
        float limit = clampVar->readMap<float>()[0];
        nudgeMin = (_Scalar<float>(-limit) - zeroPoint) * scale;
        nudgeMax = (_Scalar<float>(limit) - zeroPoint) * scale;
        nudgeMin = _Minimum(_Scalar<float>(0.0f), nudgeMin);
        nudgeMax = _Maximum(_Scalar<float>(0.0f), nudgeMax);
        auto quantX = clamp(_Round(tempX / scale + zeroPoint), clampVar);
        tempX = scale * (quantX - zeroPoint);
        // Break the grad by using a cast
        tempX = _Cast<float>(tempX);
        // Move grad from tempX to originX
        tempX = _Convert(tempX + _ZeroGrad(originX), originFormat);
        return {tempX, nudgeMin, nudgeMax};
    }

    VARP clamp(VARP x, VARP clampVar) {
        return _Maximum(_Minimum(x, clampVar), _Negative(clampVar));
    }
    VARP updateParameter(VARP originValue, VARP newValue) const {
        if (nullptr == originValue) {
            return newValue;
        }
        switch (mScaleUpdateMethod) {
            case NN::MovingAverage:
                return originValue * _Scalar<float>(mMomentum) + newValue * _Scalar<float>(1.0f - mMomentum);
            case NN::Maximum:
                return _Maximum(originValue, newValue);
            default:
                break;
        }
        MNN_ASSERT(false);
        return nullptr;
    }
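
    // Forward pass of the fused module. Training: fake-quantize the weights
    // (per-output-channel, symmetric) and the input/output features (per-tensor,
    // asymmetric, min/max tracked via updateParameter), then run a float convolution
    // so gradients still flow. Inference: fold BatchNorm into weight/bias, quantize
    // the weights to int8 and emit an int8 convolution wrapped in _FloatToInt8 /
    // _Int8ToFloat using the learned input/output scales and zero points.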
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
        VARP res;
        if (getIsTraining()) {
            auto x = _Convert(inputs[0], NCHW);
            // simulate weight quant
            auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
            auto weightTemp  = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
            weightTemp = weightTemp + _ZeroGrad(mWeight);
            // simulate input quant to get original input scale
            auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
            mInputMin = updateParameter(mInputMin, inputPair[1]);
            mInputMax = updateParameter(mInputMax, inputPair[2]);
            setParameter(mInputMin, mInputMinPos);
            setParameter(mInputMax, mInputMaxPos);
            // simulate output quant to get original output scale
            res = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
                        mOption.dilate, mGroup, mOption.pads);
            res->setName(name());
            if (mBatchNorm) {
                res = mBatchNorm->forward(res);
            }
            res = _activate(res, mActivation);
            auto outputPair = fakeQuantFeatureWithMinMax(res, nullptr, nullptr, mOutputClampValue);
            mOutputMin = updateParameter(mOutputMin, outputPair[1]);
            mOutputMax = updateParameter(mOutputMax, outputPair[2]);
            setParameter(mOutputMin, mOutputMinPos);
            setParameter(mOutputMax, mOutputMaxPos);
            res = outputPair[0];
        } else {
            if (nullptr == mInputMin) {
                // Initial for test
                // simulate weight quant
                auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
                auto weightTemp  = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
                auto x = _Convert(inputs[0], NCHW);
                auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
                mInputMin = updateParameter(mInputMin, inputPair[1]);
                mInputMax = updateParameter(mInputMax, inputPair[2]);
                setParameter(mInputMin, mInputMinPos);
                setParameter(mInputMax, mInputMaxPos);
                auto simuRes = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
                                     mOption.dilate, mGroup, mOption.pads);
                if (mBatchNorm) {
                    simuRes = mBatchNorm->forward(simuRes);
                }
                simuRes = _activate(simuRes, mActivation);
                Variable::prepareCompute({simuRes});
                auto outputPair = fakeQuantFeatureWithMinMax(simuRes, nullptr, nullptr, mOutputClampValue);
                mOutputMin = updateParameter(mOutputMin, outputPair[1]);
                mOutputMax = updateParameter(mOutputMax, outputPair[2]);
                setParameter(mOutputMin, mOutputMinPos);
                setParameter(mOutputMax, mOutputMaxPos);
            }
            // fold bn to conv weights and bias
            VARP fusedWeights = mWeight;
            VARP fusedBias    = mBias;
            fusedBias = _Reshape(fusedBias, {fusedBias->getInfo()->size, 1, 1, 1});
            if (mBatchNorm) {
                auto bn      = std::static_pointer_cast<BatchNormModule>(mBatchNorm);
                auto bnMean  = bn->runningMean();
                auto bnVar   = bn->runningVariance();
                auto bnScale = bn->scale();
                auto bnBias  = bn->bias();
                auto bnEps   = bn->eps();
                MNN_ASSERT(bnMean->getInfo()->dim.size() == 4);
                auto rStd  = _Const(1.0f) / _Sqrt(bnVar + _Const(bnEps));
                auto alpha = rStd * bnScale;
                auto beta  = bnBias - bnMean * rStd * bnScale;
                alpha = _Reshape(alpha, {alpha->getInfo()->size, 1, 1, 1});
                beta  = _Reshape(beta, {beta->getInfo()->size, 1, 1, 1});
                fusedWeights = alpha * fusedWeights;
                fusedBias    = alpha * fusedBias + beta;
            }
            auto x = _Convert(inputs[0], NC4HW4);
            int8_t inputZeroPoint, outputZeroPoint;
            {
                VARP channelScale, zeroPoint;
                auto scaleAndZeroPoint = computeScaleAndZeroPoint(mInputMin, mInputMax, mInputClampValue);
                mInputScale     = scaleAndZeroPoint.first;
                mInputZeroPoint = scaleAndZeroPoint.second;
                // always PerTensor
                channelScale   = _Reciprocal(mInputScale);
                zeroPoint      = _Cast<int8_t>(mInputZeroPoint);
                inputZeroPoint = zeroPoint->readMap<int8_t>()[0];
                x = _FloatToInt8(x, channelScale, -int8_t(mInputClampValue->readMap<float>()[0]), int8_t(mInputClampValue->readMap<float>()[0]), inputZeroPoint);
            }
            {
                VARP channelScale, zeroPoint;
                auto scaleAndZeroPoint = computeScaleAndZeroPoint(mOutputMin, mOutputMax, mOutputClampValue);
                mOutputScale     = scaleAndZeroPoint.first;
                mOutputZeroPoint = scaleAndZeroPoint.second;
                // always PerTensor
                channelScale    = mOutputScale;
                zeroPoint       = _Cast<int8_t>(mOutputZeroPoint);
                outputZeroPoint = zeroPoint->readMap<int8_t>()[0];
            }
            std::vector<int8_t> weight;
            std::vector<float> bias;
            std::vector<float> weightScaleVector;
            {
                VARP weightScale, quanWeight, convScale;
                // auto newWeight = fusedWeights * mInputScale;
                weightScale = _Maximum(_ReduceMax(_Abs(fusedWeights), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
                quanWeight  = _Cast<int8_t>(clamp(_Round(fusedWeights * _Reciprocal(weightScale)), mWeightClampValue));
                convScale   = _Reciprocal(mOutputScale) * weightScale * mInputScale;
                Variable::prepareCompute({quanWeight, convScale});
                // // reference for how to get quantized bias
                // auto remains = _ReduceSum(_Cast<int32_t>(mInputZeroPoint) * _Cast<int32_t>(quanWeight), {1, 2, 3}, true);
                // MNN_ASSERT((mOutputZeroPoint->getInfo()->dim.size() == 0) && (mOutputZeroPoint->getInfo()->size == 1)); // only support per-tensor, per-channel is removed.
                // auto outputZeroPointFused = _Cast<int32_t>(_Cast<float>(mOutputZeroPoint) * _Reciprocal(convScale));
                // auto quanBias = _Cast<int32_t>(fusedBias * _Reciprocal(weightScale * mInputScale)) - remains + outputZeroPointFused;
                {
                    auto info = quanWeight->getInfo();
                    weight.resize(info->size);
                    auto ptr = quanWeight->readMap<int8_t>();
                    ::memcpy(weight.data(), ptr, weight.size() * sizeof(int8_t));
                }
                {
                    auto biasinfo = fusedBias->getInfo();
                    bias.resize(biasinfo->size);
                    auto ptr = fusedBias->readMap<float>();
                    ::memcpy(bias.data(), ptr, bias.size() * sizeof(float));
                    auto info = weightScale->getInfo();
                    weightScaleVector.resize(info->size);
                    MNN_ASSERT(weightScaleVector.size() == bias.size());
                    auto ptrScale = weightScale->readMap<float>();
                    ::memcpy(weightScaleVector.data(), ptrScale, weightScaleVector.size() * sizeof(float));
                }
            }
            bool relu = mActivation == NN::None ? false : true;
            res = _Conv(std::move(weight), std::move(bias), std::move(weightScaleVector), _Convert(x, NC4HW4), mOption.channel,
                        mOption.kernelSize, mOption.padMode, mOption.stride, mOption.dilate, mGroup, mOption.pads, relu,
                        mInputScale->readMap<float>()[0], mOutputScale->readMap<float>()[0],
                        inputZeroPoint, outputZeroPoint,
                        -int8_t(mOutputClampValue->readMap<float>()[0]), int8_t(mOutputClampValue->readMap<float>()[0]), mWeightClampValue->readMap<float>()[0], mAccumulateToInt16);
            res->setName(name());
            // always PerTensor
            res = _Int8ToFloat(res, mOutputScale, outputZeroPoint);
        }
        return {res};
    }
private:
    ConvBNReluFusedModule() = default;
    Module* clone(CloneContext* ctx) const override {
        ConvBNReluFusedModule* module(new ConvBNReluFusedModule);
        module->mConvParameter        = mConvParameter;
        module->mConvParameter.weight = ctx->getOrClone(mConvParameter.weight);
        module->mConvParameter.bias   = ctx->getOrClone(mConvParameter.bias);
        module->mOption     = mOption;
        module->mGroup      = mGroup;
        module->mWeight     = ctx->getOrClone(mWeight);
        module->mBias       = ctx->getOrClone(mBias);
        module->mActivation = mActivation;
        module->mBits       = mBits;
        module->mLimit      = mLimit;
        module->mLimitScale = ctx->getOrClone(mLimitScale);
        module->mWeightClampValue = ctx->getOrClone(mWeightClampValue);
        module->mInputScale  = ctx->getOrClone(mInputScale);
        module->mOutputScale = ctx->getOrClone(mOutputScale);
        module->mInputMin    = ctx->getOrClone(mInputMin);
        module->mInputMax    = ctx->getOrClone(mInputMax);
        module->mOutputMin   = ctx->getOrClone(mOutputMin);
        module->mOutputMax   = ctx->getOrClone(mOutputMax);
        module->mInputZeroPoint  = ctx->getOrClone(mInputZeroPoint);
        module->mOutputZeroPoint = ctx->getOrClone(mOutputZeroPoint);
        module->mInputMinPos  = mInputMinPos;
        module->mInputMaxPos  = mInputMaxPos;
        module->mOutputMinPos = mOutputMinPos;
        module->mOutputMaxPos = mOutputMaxPos;
        module->mInputClampValue  = ctx->getOrClone(mInputClampValue);
        module->mOutputClampValue = ctx->getOrClone(mOutputClampValue);
        module->mMomentum = mMomentum;
        module->mFeatureScaleStatMethod = mFeatureScaleStatMethod;
        module->mScaleUpdateMethod      = mScaleUpdateMethod;
        if (mBatchNorm) {
            module->mBatchNorm.reset(mBatchNorm->clone(ctx));
            module->registerModel({module->mBatchNorm});
        }
        return this->cloneBaseTo(ctx, module);
    }
    NN::ConvParameters mConvParameter;
    NN::ConvOption mOption;
    int mGroup;
    VARP mWeight;
    VARP mBias;
    NN::ActivationFunctionType mActivation = NN::ActivationFunctionType::None;
    std::shared_ptr<Module> mBatchNorm = nullptr;
    int mBits;
    float mLimit;
    VARP mLimitScale;
    Express::VARP mWeightClampValue;
    VARP mInputScale  = nullptr;
    VARP mOutputScale = nullptr;
    VARP mInputMin  = nullptr;
    VARP mInputMax  = nullptr;
    VARP mOutputMin = nullptr;
    VARP mOutputMax = nullptr;
    VARP mInputZeroPoint  = nullptr;
    VARP mOutputZeroPoint = nullptr;
    int mInputMinPos  = -1;
    int mInputMaxPos  = -1;
    int mOutputMinPos = -1;
    int mOutputMaxPos = -1;
    VARP mInputClampValue;
    VARP mOutputClampValue;
    float mMomentum = 0.99f;
    NN::FeatureScaleStatMethod mFeatureScaleStatMethod;
    NN::ScaleUpdateMethod mScaleUpdateMethod;
    bool mAccumulateToInt16 = false;
};
Module* NN::ConvBNReluFused(std::vector<std::shared_ptr<Module>> modules,
                            NN::FeatureScaleStatMethod featureScaleStatMethod,
                            NN::ScaleUpdateMethod scaleUpdateMethod, const int bits) {
    return new ConvBNReluFusedModule(modules, featureScaleStatMethod, scaleUpdateMethod, bits);
}
Module* NN::ConvInt8(const ConvOption& option, int bits, bool hasBias,
                     std::shared_ptr<Initializer> weightInit, std::shared_ptr<Initializer> biasInit, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
    std::shared_ptr<Module> conv(NN::Conv(option));
    return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
Module* NN::ConvInt8(const ConvParameters& para, int bits, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
    std::shared_ptr<Module> conv(NN::Conv(para));
    return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
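
// turnQuantize scans the submodules of a PipelineModule and replaces each Conv,
// Conv+BN, Conv+ReLU/ReLU6 or Conv+BN+ReLU/ReLU6 chain with a ConvBNReluFusedModule,
// provided the intermediate outputs are referenced exactly once and the ReLU is not
// a PReLU with non-zero slope. The absorbed BatchNorm/ReLU submodules are erased
// afterwards.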
bool NN::turnQuantize(Module* module, const int bits, NN::FeatureScaleStatMethod featureScaleStatMethod, NN::ScaleUpdateMethod scaleUpdateMethod) {
    if (nullptr == module || module->type() != PIPELINE_MODULE) {
        MNN_ERROR("Invalid module for quantization\n");
        return false;
    }
    auto pipModule = static_cast<PipelineModule*>(module);
    std::vector<int> needEraseIndices;
    for (int i = 0; i < pipModule->mSubModules.size(); i++) {
        auto& m = pipModule->mSubModules[i];
        auto& theModule = std::get<0>(m);
        auto moduleType = theModule->type();
        //auto& inputIndices = std::get<1>(m);
        auto& outputIndices = std::get<2>(m);
        if (moduleType == "Conv" && i < pipModule->mSubModules.size() - 1) {
            auto& p1 = pipModule->mSubModules[i + 1];
            auto p1Module = std::get<0>(p1);
            auto& p1ModuleType = p1Module->type();
            auto& p1InputIndices = std::get<1>(p1);
            auto& p1OutputIndices = std::get<2>(p1);
            auto convOutputCount = pipModule->countOutputReference(outputIndices);
            bool convSingleOutputReference = ((outputIndices.size() == 1) && (convOutputCount[0] == 1));
            // only conv
            if ((!convSingleOutputReference) || (p1ModuleType == "Conv") ||
                (p1ModuleType != "BatchNorm" && p1ModuleType != "ReLU" && p1ModuleType != "ReLU6")) {
                theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
                pipModule->registerModel({theModule});
                continue;
            }
            // conv + bn + ?
            if (p1ModuleType == "BatchNorm") {
                bool convBnConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
                if (!convBnConnected) {
                    theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
                    pipModule->registerModel({theModule});
                    continue;
                }
                // last conv + bn
                if (i == pipModule->mSubModules.size() - 2) {
                    theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
                    pipModule->registerModel({theModule});
                    outputIndices = p1OutputIndices;
                    needEraseIndices.emplace_back(i + 1);
                    continue;
                }
                // maybe there is a relu or relu6 after conv + bn
                auto& p2 = pipModule->mSubModules[i + 2];
                auto& p2Module = std::get<0>(p2);
                auto p2ModuleType = p2Module->type();
                auto& p2InputIndices = std::get<1>(p2);
                auto& p2OutputIndices = std::get<2>(p2);
                auto bnOutputCount = pipModule->countOutputReference(p1OutputIndices);
                bool bnSingleOutputReference = ((p1OutputIndices.size() == 1) && (bnOutputCount[0] == 1));
                // only conv + bn
                if ((!bnSingleOutputReference) || (p2ModuleType != "ReLU" && p2ModuleType != "ReLU6")) {
                    theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
                    pipModule->registerModel({theModule});
                    outputIndices = p1OutputIndices;
                    needEraseIndices.emplace_back(i + 1);
                    continue;
                } else { // conv + bn + relu or conv + bn + relu6
                    bool convBnReluConnected = ((bnSingleOutputReference) && (p2InputIndices.size() == 1) && (p2InputIndices[0] == p1OutputIndices[0]));
                    bool isPrelu = false;
                    if (p2ModuleType == "ReLU") {
                        auto p2Op = ((ExprModule*)p2Module.get())->getExpr()->get();
                        float slope = p2Op->main_as_Relu()->slope();
                        isPrelu = std::abs(slope) > 1e-6;
                    }
                    if (!convBnReluConnected || isPrelu) {
                        theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
                        pipModule->registerModel({theModule});
                        outputIndices = p1OutputIndices;
                        needEraseIndices.emplace_back(i + 1);
                        continue;
                    }
                    theModule.reset(NN::ConvBNReluFused({theModule, p1Module, p2Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
                    pipModule->registerModel({theModule});
                    outputIndices = p2OutputIndices;
                    needEraseIndices.emplace_back(i + 1);
                    needEraseIndices.emplace_back(i + 2);
                    continue;
                }
            }
            // conv + relu or conv + relu6
            if (p1ModuleType == "ReLU" || p1ModuleType == "ReLU6") {
                bool convReluConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
                bool isPrelu = false;
                if (p1ModuleType == "ReLU") {
                    auto p1Op = ((ExprModule*)p1Module.get())->getExpr()->get();
                    float slope = p1Op->main_as_Relu()->slope();
                    isPrelu = std::abs(slope) > 1e-6;
                }
                if (!convReluConnected || isPrelu) {
                    theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
                    pipModule->registerModel({theModule});
                    continue;
                }
                theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
                pipModule->registerModel({theModule});
                outputIndices = p1OutputIndices;
                needEraseIndices.emplace_back(i + 1);
                continue;
            }
        }
        if (i == pipModule->mSubModules.size() - 1 && moduleType == "Conv") {
            theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
            pipModule->registerModel({theModule});
        }
    }
    // erase useless submodules
    const int eraseSize = needEraseIndices.size();
    int alreadyErasedCount = 0;
    for (int i = 0; i < eraseSize; i++) {
        auto position = needEraseIndices[i] - alreadyErasedCount;
        auto type = std::get<0>(pipModule->mSubModules[position])->type();
        MNN_ASSERT(type == "BatchNorm" || type == "ReLU" || type == "ReLU6");
        pipModule->mSubModules.erase(pipModule->mSubModules.begin() + position);
        alreadyErasedCount++;
    }
    return true;
}
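
// NN::extract builds a PipelineModule from a computation graph. When fortrain is
// true, convolutions are converted into trainable ConvModules and non-runnable ops
// are wrapped via ExtractNotRunableOp; otherwise only the non-runnable ops are
// wrapped and the rest of the graph is kept as-is.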
Module* NN::extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph) {
    std::function<std::pair<std::vector<int>, std::shared_ptr<Module>>(EXPRP)> transformFunction;
    if (fortrain) {
        transformFunction =
            [&subGraph](EXPRP source) {
                if (source->get() == nullptr) {
                    return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
                }
                std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
                if (nullptr != m) {
                    m->setName(source->name());
                    return std::make_pair(std::vector<int>{}, m);
                }
                auto convExtracted = NN::Utils::ExtractConvolution(source);
                if (convExtracted.weight == nullptr) {
                    return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
                }
                std::shared_ptr<Module> module(NN::Conv(convExtracted));
                module->setName(source->name());
                return std::make_pair(std::vector<int>{0}, module);
            };
    } else {
        transformFunction = [&subGraph](EXPRP source) {
            if (source->get() == nullptr) {
                return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
            }
            std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
            if (nullptr != m) {
                m->setName(source->name());
                return std::make_pair(std::vector<int>{}, m);
            }
            return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
        };
    }
    return new PipelineModule(inputs, outputs, transformFunction);
}
} // namespace Express
} // namespace MNN