// MNN/tools/train/source/nn/NN.cpp
//
// NN.cpp
// MNN
//
// Created by MNN on 2019/11/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "NN.hpp"
#include "Distributions.hpp"
#include "module/PipelineModule.hpp"
#include "module/WhileModule.hpp"
#include "module/IfModule.hpp"
#include "Initializer.hpp"
#include "MNN_generated.h"
#include "RandomGenerator.hpp"
#include "core/Macro.h"
#include <string>
using namespace MNN::Express;
namespace MNN {
namespace Express {
static VARP _activate(VARP x, NN::ActivationFunctionType type) {
switch (type) {
case NN::None:
return x;
case NN::Relu:
return _Relu(x);
case NN::Relu6:
return _Relu6(x);
default:
break;
}
return nullptr;
}
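// Dropout as a trainable-graph module: during training it builds the mask on
// the fly from a uniform distribution; during inference it is an identity map.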
class DropoutModule : public Module {
public:
DropoutModule(const float dropRatio) {
mDropRatio = dropRatio;
setType("Dropout");
}
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
Express::VARP x = inputs[0];
if (getIsTraining()) {
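// Inverted dropout: zero each element with probability mDropRatio and scale
// the survivors by 1/(1 - mDropRatio), so the expected activation is unchanged
// and no rescaling is needed at inference time.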
float scale = 1.0f / (1.0f - mDropRatio);
auto mask = _Input(x->getInfo()->dim, x->getInfo()->order, x->getInfo()->type);
auto maskPtr = mask->writeMap<float>();
auto eltSize = x->getInfo()->size;
Distributions::uniform(eltSize, 0, 1, maskPtr, RandomGenerator::generator());
for (int i = 0; i < eltSize; i++) {
maskPtr[i] = maskPtr[i] < mDropRatio ? 0.0f : scale;
}
x = x * mask;
}
return {x};
}
private:
DropoutModule() = default;
Module* clone(CloneContext* ctx) const override {
DropoutModule* module(new DropoutModule);
module->mDropRatio = mDropRatio;
return this->cloneBaseTo(ctx, module);
}
float mDropRatio;
};
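// Batch normalization with trainable scale/bias and running mean/variance.
// Training normalizes with batch statistics and updates the running statistics
// by exponential moving average; inference folds the frozen statistics into a
// fused per-channel _Scale op.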
class BatchNormModule : public Module {
public:
BatchNormModule(EXPRP expr, const float m = 0.99) {
MNN_ASSERT(expr->get() != nullptr);
MNN_ASSERT(expr->get()->type() == OpType_BatchNorm);
auto bnPa = expr->get()->main_as_BatchNorm();
auto& inputs = expr->inputs();
int dims = 4;
if (!inputs.empty()) {
auto info = inputs[0]->getInfo();
if (nullptr != info) {
dims = info->dim.size();
}
}
mEps = bnPa->epsilon();
mMomentum = m;
mChannels = bnPa->channels();
std::vector<int> statShape;
int channels = mChannels;
if (dims == 2) {
statShape = {1, channels};
mReductionDims = {0};
}
if (dims == 3) {
statShape = {1, channels, 1};
mReductionDims = {0, 2};
}
if (dims == 4) {
statShape = {1, channels, 1, 1};
mReductionDims = {0, 2, 3};
}
MNN_ASSERT(bnPa->biasData()->size() == mChannels);
mBias = _TrainableParam(bnPa->biasData()->data(), statShape, NCHW);
MNN_ASSERT(bnPa->slopeData()->size() == mChannels);
mScale = _TrainableParam(bnPa->slopeData()->data(), statShape, NCHW);
MNN_ASSERT(bnPa->meanData()->size() == mChannels);
mRunningMean = _Const(bnPa->meanData()->data(), statShape, NCHW);
MNN_ASSERT(bnPa->varData()->size() == mChannels);
mRunningVariance = _Const(bnPa->varData()->data(), statShape, NCHW);
addParameter(mScale);
addParameter(mBias);
mRunningVariancePos = addParameter(mRunningVariance);
mRunningMeanPos = addParameter(mRunningMean);
setType("BatchNorm");
}
BatchNormModule(const int channels, const int dims = 4, const float m = 0.99, const float e = 1e-5) {
mMomentum = m;
mEps = e;
mChannels = channels;
std::vector<int> statShape;
if (dims == 2) {
statShape = {1, channels};
mReductionDims = {0};
}
if (dims == 3) {
statShape = {1, channels, 1};
mReductionDims = {0, 2};
}
if (dims == 4) {
statShape = {1, channels, 1, 1};
mReductionDims = {0, 2, 3};
}
mScale = _TrainableParam(1.0f, statShape, NCHW);
mBias = _TrainableParam(0.0f, statShape, NCHW);
mRunningMean = _Const(0.0f, statShape, NCHW);
mRunningVariance = _Const(0.0f, statShape, NCHW);
addParameter(mScale);
addParameter(mBias);
mRunningVariancePos = addParameter(mRunningVariance);
mRunningMeanPos = addParameter(mRunningMean);
setType("BatchNorm");
}
VARP runningMean() {
return mRunningMean;
}
VARP runningVariance() {
return mRunningVariance;
}
VARP scale() {
return mScale;
}
VARP bias() {
return mBias;
}
float eps() {
return mEps;
}
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
Express::VARP x = inputs[0];
auto dimFormat = x->getInfo()->order;
VARP outputData = nullptr;
if (getIsTraining()) {
if (dimFormat == NC4HW4 || dimFormat == NHWC) {
x = _Convert(x, NCHW);
}
MNN_ASSERT(x->getInfo()->dim[1] == mChannels);
auto sampleMean = _ReduceMean(x, mReductionDims, true); // mean for each channel in the batch
auto xSub = x - sampleMean;
auto sampleVar = _ReduceMean(_Square(xSub), mReductionDims, true); // variance for each channel in the batch
auto rSampleStd = _Reciprocal(_Sqrt(sampleVar + _Const(mEps)));
auto normalizedData = xSub * rSampleStd;
outputData = normalizedData * mScale + mBias;
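// Exponential moving average of the running statistics:
// running = momentum * running + (1 - momentum) * batchStat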
mRunningMean = _Const(mMomentum) * mRunningMean + _Const(1 - mMomentum) * sampleMean;
mRunningVariance = _Const(mMomentum) * mRunningVariance + _Const(1 - mMomentum) * sampleVar;
outputData->setName(name());
outputData = _Convert(outputData, dimFormat);
setParameter(mRunningMean, mRunningMeanPos);
setParameter(mRunningVariance, mRunningVariancePos);
return {outputData};
}
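// Inference path: fold the frozen statistics into per-channel affine terms,
// y = alpha * x + beta with alpha = scale / sqrt(var + eps) and
// beta = bias - mean * alpha, and execute them as a single fused _Scale op.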
auto rStd = _Const(1.0f) / _Sqrt(mRunningVariance + _Const(mEps));
auto alpha = rStd * mScale;
auto beta = mBias - mRunningMean * rStd * mScale;
//outputData = (_Convert(x, NCHW) * alpha) + beta;
alpha.fix(VARP::CONSTANT);
beta.fix(VARP::CONSTANT);
//FUNC_PRINT_ALL(alpha->readMap<float>()[0], f);
x = _Convert(x, NC4HW4);
std::vector<float> scale(alpha->getInfo()->size);
std::vector<float> bias(beta->getInfo()->size);
::memcpy(scale.data(), alpha->readMap<float>(), scale.size() * sizeof(float));
::memcpy(bias.data(), beta->readMap<float>(), bias.size() * sizeof(float));
outputData = _Scale(x, mChannels, std::move(scale), std::move(bias));
outputData->setName(name());
outputData = _Convert(outputData, dimFormat);
return {outputData};
}
private:
BatchNormModule() = default;
Module* clone(CloneContext* ctx) const override {
BatchNormModule* module(new BatchNormModule);
module->mMomentum = mMomentum;
module->mEps = mEps;
module->mScale = ctx->getOrClone(mScale);
module->mBias = ctx->getOrClone(mBias);
module->mRunningMean = ctx->getOrClone(mRunningMean);
module->mRunningVariance = ctx->getOrClone(mRunningVariance);
module->mRunningMeanPos = mRunningMeanPos;
module->mRunningVariancePos = mRunningVariancePos;
module->mChannels = mChannels;
module->mReductionDims = mReductionDims;
return this->cloneBaseTo(ctx, module);
}
float mMomentum = 0.99;
float mEps = 1e-5;
VARP mScale = nullptr;
VARP mBias = nullptr;
VARP mRunningMean = nullptr;
VARP mRunningVariance = nullptr;
int mRunningMeanPos = -1;
int mRunningVariancePos = -1;
int mChannels;
std::vector<int> mReductionDims;
};
void NN::ConvOption::reset(int size) {
stride = std::vector<int>(size, 1);
channel = std::vector<int>(size, 0);
kernelSize = std::vector<int>(size, 1);
dilate = std::vector<int>(size, 1);
padMode = VALID;
pads = std::vector<int>(size, 0);
depthwise = false;
fusedActivationFunction = None;
}
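// Convolution module holding trainable weight/bias. The training path keeps a
// differentiable _Conv expression over the parameter VARPs; the inference path
// copies the current values out and builds a constant conv with the activation
// fused in.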
class ConvModule : public Module {
public:
ConvModule(const NN::ConvParameters& parameters) {
mParameter = parameters;
if (nullptr != mParameter.bias) {
addParameter(mParameter.bias);
}
if (nullptr != mParameter.weight) {
addParameter(mParameter.weight);
}
setName(parameters.name);
setType("Conv");
}
NN::ConvParameters& convParameters() {
return mParameter;
}
virtual std::vector<VARP> onForward(const std::vector<VARP>& inputs) override {
auto input = inputs[0];
auto& option = mParameter.option;
if (getIsTraining()) {
auto tempOutput = _Conv(mParameter.weight, mParameter.bias, _Convert(input, NC4HW4), option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads);
tempOutput->setName(name());
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
return {tempOutput};
}
bool relu = option.fusedActivationFunction == NN::Relu;
bool relu6 = option.fusedActivationFunction == NN::Relu6;
std::vector<float> weight;
std::vector<float> bias;
{
auto weightInfo = mParameter.weight->getInfo();
weight.resize(weightInfo->size);
::memcpy(weight.data(), mParameter.weight->readMap<float>(), weight.size() * sizeof(float));
}
{
bias.resize(mParameter.option.channel[1]);
if (nullptr != mParameter.bias) {
::memcpy(bias.data(), mParameter.bias->readMap<float>(), bias.size() * sizeof(float));
} else {
::memset(bias.data(), 0, bias.size() * sizeof(float));
}
}
auto tempOutput = _Conv(std::move(weight), std::move(bias), _Convert(input, NC4HW4), option.channel, option.kernelSize, option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads, relu, relu6);
tempOutput->setName(name());
return {tempOutput};
}
private:
ConvModule() = default;
Module* clone(CloneContext* ctx) const override {
ConvModule* module(new ConvModule);
module->mParameter = mParameter;
module->mParameter.weight = ctx->getOrClone(mParameter.weight);
module->mParameter.bias = ctx->getOrClone(mParameter.bias);
return this->cloneBaseTo(ctx, module);
}
NN::ConvParameters mParameter;
};
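// Creates the trainable weight/bias described by a ConvOption. Defaults to
// Xavier initialization for the weight and zeros for the bias; a depthwise
// convolution uses group == channels and a weight of shape {C, 1, kh, kw}.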
static std::tuple<VARP, VARP, int> _initParameters(const NN::ConvOption& option, bool hasBias,
std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
std::tuple<VARP, VARP, int> defaultRes;
if (nullptr == weightInit) {
weightInit.reset(Initializer::xavier());
}
if (nullptr == biasInit) {
biasInit.reset(Initializer::constValue(0.0f));
}
VARP weight;
int group = 1;
if (option.depthwise) {
if (option.channel[1] != option.channel[0]) {
MNN_ERROR("Depthwise convolution requires equal input and output channel counts\n");
return defaultRes;
}
weight = weightInit->createConstVar({option.channel[0], 1, option.kernelSize[1], option.kernelSize[0]}, NCHW);
weight.fix(VARP::TRAINABLE);
group = option.channel[0];
} else {
weight = weightInit->createConstVar(
{option.channel[1], option.channel[0], option.kernelSize[1], option.kernelSize[0]}, NCHW);
weight.fix(VARP::TRAINABLE);
}
VARP bias;
if (hasBias) {
bias = biasInit->createConstVar({option.channel[1]}, NCHW);
bias.fix(VARP::TRAINABLE);
}
return std::make_tuple(weight, bias, group);
}
Module* NN::ConvTranspose(const ConvOption& option, bool hasBias,
std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
VARP input = _Input({1, option.channel[0], 1, 1}, NC4HW4);
auto tuple = _initParameters(option, hasBias, weightInit, biasInit);
auto weight = std::get<0>(tuple);
if (nullptr == weight) {
return nullptr;
}
if (!option.depthwise) {
weight = _Transpose(weight, {1, 0, 2, 3});
weight.fix(VARP::TRAINABLE);
}
auto bias = std::get<1>(tuple);
auto group = std::get<2>(tuple);
if (nullptr != bias) {
auto tempOutput = _Deconv(weight, bias, input, option.padMode, option.stride, option.dilate, group);
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
return NN::extract({input}, {tempOutput}, true);
}
auto tempOutput = _Deconv(weight, nullptr, input, option.padMode, option.stride, option.dilate, group);
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
return NN::extract({input}, {tempOutput}, true);
}
Module* NN::Conv(const ConvOption& option, bool hasBias, std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
auto tuple = _initParameters(option, hasBias, weightInit, biasInit);
ConvParameters parameters;
parameters.weight = std::get<0>(tuple);
if (nullptr == parameters.weight) {
return nullptr;
}
parameters.bias = std::get<1>(tuple);
parameters.group = std::get<2>(tuple);
parameters.option = option;
return new ConvModule(parameters);
}
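// A minimal usage sketch for NN::Conv (the option values and variable names
// are illustrative, and the one-argument call assumes the default arguments
// declared in NN.hpp):
//
//   NN::ConvOption option;
//   option.channel = {3, 16};    // 3 input channels -> 16 output channels
//   option.kernelSize = {3, 3};
//   std::shared_ptr<Module> conv(NN::Conv(option));
//   VARP y = conv->forward(x);   // x is converted to NC4HW4 internally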
Module* NN::Linear(int l, int t, bool hasBias, std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
if (nullptr == weightInit) {
weightInit.reset(Initializer::xavier());
}
if (nullptr == biasInit) {
biasInit.reset(Initializer::constValue(0.0f));
}
auto weight = weightInit->createConstVar({t, l}, NCHW);
weight.fix(VARP::TRAINABLE);
auto input = _Input({l}, NCHW);
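// y = x * W^T (+ b); transposeB is set because the weight is stored as {t, l}.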
auto output = _MatMul(input, weight, false, true);
if (!hasBias) {
return NN::extract({input}, {output}, true);
}
auto bias = biasInit->createConstVar({1, t}, NCHW);
bias.fix(VARP::TRAINABLE);
output = _Add(output, bias);
auto module = NN::extract({input}, {output}, true);
module->setType("Linear");
return module;
}
Module* NN::Dropout(const float dropRatio) {
return new DropoutModule(dropRatio);
}
Module* NN::BatchNorm(const int channels, const int dims, const float m, const float e) {
return new BatchNormModule(channels, dims, m, e);
}
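// Rebuilds a ConvOption plus trainable weight/bias from a serialized
// Convolution2D expression, so a conv baked into a frozen graph can be turned
// back into a trainable ConvModule.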
NN::ConvParameters NN::Utils::ExtractConvolution(EXPRP source) {
ConvParameters _default;
if (source->get() == nullptr) {
return _default;
}
if (source->get()->type() != OpType_Convolution && source->get()->type() != OpType_ConvolutionDepthwise) {
return _default;
}
auto conv2D = source->get()->main_as_Convolution2D();
NN::ConvOption option;
option.kernelSize = {conv2D->common()->kernelX(), conv2D->common()->kernelY()};
option.stride = {conv2D->common()->strideX(), conv2D->common()->strideY()};
if (nullptr != conv2D->common()->pads()) {
option.pads.resize(conv2D->common()->pads()->size());
for (int i=0; i<option.pads.size(); ++i) {
option.pads[i] = conv2D->common()->pads()->data()[i];
}
} else {
option.pads = {conv2D->common()->padX(), conv2D->common()->padY()};
}
switch (conv2D->common()->padMode()) {
case MNN::PadMode_SAME:
option.padMode = SAME;
break;
case MNN::PadMode_VALID:
option.padMode = VALID;
break;
case MNN::PadMode_CAFFE:
option.padMode = CAFFE;
break;
default:
break;
}
option.dilate = {conv2D->common()->dilateX(), conv2D->common()->dilateY()};
option.depthwise = source->get()->type() == OpType_ConvolutionDepthwise;
auto inputCount = conv2D->common()->inputCount();
if (0 == inputCount) {
auto inputInfo = source->inputs()[0]->getInfo();
if (nullptr != inputInfo) {
if (NHWC == inputInfo->order) {
inputCount = source->inputs()[0]->getInfo()->dim[3];
} else {
inputCount = source->inputs()[0]->getInfo()->dim[1];
}
} else {
if (nullptr == conv2D->weight()) {
MNN_ERROR("Can't extract convolution\n");
return _default;
}
auto weightCount = conv2D->weight()->size();
if (option.depthwise) {
inputCount = conv2D->common()->outputCount();
} else {
inputCount = weightCount / conv2D->common()->kernelX() / conv2D->common()->kernelY() / conv2D->common()->outputCount();
}
}
}
option.channel = {inputCount, conv2D->common()->outputCount()};
int group = 1;
if (option.depthwise) {
group = conv2D->common()->outputCount();
}
VARP weight;
auto inputs = source->inputs();
if (inputs.size() > 1) {
weight = inputs[1];
}
VARP bias;
if (inputs.size() > 2) {
bias = inputs[2];
}
if (inputs.size() < 2) {
// Extract Weight And Bias from Conv2D
if (conv2D->weight() == nullptr || conv2D->bias() == nullptr) {
return _default;
}
bias = _TrainableParam(conv2D->bias()->data(), {option.channel[1]}, NCHW);
weight = _TrainableParam(conv2D->weight()->data(), {option.channel[1], option.channel[0] / group, option.kernelSize[1], option.kernelSize[0]}, NCHW);
}
_default.option = std::move(option);
_default.weight = std::move(weight);
_default.bias = std::move(bias);
_default.group = group;
if (conv2D->common()->relu()) {
_default.option.fusedActivationFunction = NN::Relu;
}
if (conv2D->common()->relu6()) {
_default.option.fusedActivationFunction = NN::Relu6;
}
_default.name = source->name();
return _default;
}
Module* NN::Conv(const ConvParameters& parameter) {
return new ConvModule(parameter);
}
Module* NN::Utils::ExtractNotRunableOp(Express::EXPRP expr, const std::map<std::string, SubGraph>& subgraphs) {
if (nullptr == expr->get()) {
return nullptr;
}
if (expr->get()->type() == OpType_BatchNorm) {
return new BatchNormModule(expr);
}
if (expr->get()->type() == OpType_Dropout) {
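// The drop ratio stored in the op is not recovered here; a fixed 0.3 is used.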
return new DropoutModule(0.3f);
}
if (expr->get()->type() == OpType_While) {
return WhileModule::create(expr->get(), subgraphs);
}
if (expr->get()->type() == OpType_If) {
return IfModule::create(expr->get(), subgraphs);
}
return nullptr;
}
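// Quantization-aware training module for a fused Conv (+ BatchNorm)(+ ReLU /
// ReLU6) chain: training simulates quantization with fake-quant on weights and
// features while tracking feature min/max; inference emits a real int8 conv.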
class ConvBNReluFusedModule : public Module {
public:
ConvBNReluFusedModule(std::vector<std::shared_ptr<Module> > modules,
NN::FeatureScaleStatMethod featureScaleStatMethod,
NN::ScaleUpdateMethod scaleUpdateMethod, const int bits) {
MNN_ASSERT(modules.size() >= 1);
MNN_ASSERT(modules[0]->type() == "Conv");
if (modules.size() == 3) {
MNN_ASSERT(modules[1]->type() == "BatchNorm");
MNN_ASSERT(modules[2]->type() == "ReLU" || modules[2]->type() == "ReLU6");
}
for (int i = 0; i < modules.size(); i++) {
auto type = modules[i]->type();
if (type == "Conv") {
mConvParameter = std::static_pointer_cast<ConvModule>(modules[i])->convParameters();
mOption = mConvParameter.option;
mGroup = mConvParameter.group;
mWeight = mConvParameter.weight;
mBias = mConvParameter.bias;
if (nullptr != mWeight) {
addParameter(mWeight);
}
if (nullptr != mBias) {
addParameter(mBias);
}
setName(mConvParameter.name);
modules[i] = nullptr;
} else if (type == "BatchNorm") {
mBatchNorm = modules[i];
registerModel({mBatchNorm});
} else if (type == "ReLU") {
mActivation = NN::Relu;
modules[i] = nullptr;
} else if (type == "ReLU6") {
mActivation = NN::Relu6;
modules[i] = nullptr;
} else {
MNN_ASSERT(false);
}
}
if (mOption.fusedActivationFunction == NN::Relu || mOption.fusedActivationFunction == NN::Relu6) {
mActivation = mOption.fusedActivationFunction;
}
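// Note: the featureScaleStatMethod argument is ignored below; feature scales
// are always tracked per tensor.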
mFeatureScaleStatMethod = NN::PerTensor;
mScaleUpdateMethod = scaleUpdateMethod;
mBits = bits;
mLimit = (float)(1 << (bits - 1)) - 1.0f;
mLimitScale = _Scalar<float>(1.0f / mLimit);
mWeightClampValue = _Scalar<float>(mLimit);
mInputClampValue = _Scalar<float>(mLimit);
mOutputClampValue = _Scalar<float>(mLimit);
mInputMinPos = addParameter(mInputMin);
mInputMaxPos = addParameter(mInputMax);
mOutputMinPos = addParameter(mOutputMin);
mOutputMaxPos = addParameter(mOutputMax);
setType("ConvBNReluFused");
}
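// Affine quantization parameters for an observed range [min, max] (forced to
// include zero): scale = (max - min) / (2 * limit) and
// zeroPoint = round(-min / scale - limit), mapping the range onto
// [-limit, limit] with limit = 2^(bits-1) - 1.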
std::pair<VARP, VARP> computeScaleAndZeroPoint(VARP min, VARP max, VARP clampVar) {
MNN_ASSERT((!(min == nullptr)));
MNN_ASSERT((!(max == nullptr)));
min = _Minimum(_Scalar<float>(0.0f), min);
max = _Maximum(_Scalar<float>(0.0f), max);
auto scale = (max - min) / (_Scalar(2.0f) * clampVar);
auto zeroPoint = _Round((_Scalar(0.0f) - min) / scale - clampVar);
return std::make_pair(scale, zeroPoint);
}
std::vector<VARP> fakeQuantFeatureWithMinMax(VARP x, VARP useMin, VARP useMax, VARP clampVar) {
auto originFormat = x->getInfo()->order;
auto tempX = x;
if (originFormat == NC4HW4) {
tempX = _Convert(tempX, NCHW);
}
auto originX = tempX;
VARP min, max;
// always PerTensor
min = _ReduceMin(tempX);
max = _ReduceMax(tempX);
VARP scale, zeroPoint;
VARP nudgeMin, nudgeMax;
if (!(useMin == nullptr)) {
MNN_ASSERT(!(useMax == nullptr));
auto scaleAndZeroPoint = computeScaleAndZeroPoint(useMin, useMax, clampVar);
scale = scaleAndZeroPoint.first;
zeroPoint = scaleAndZeroPoint.second;
} else {
auto scaleAndZeroPoint = computeScaleAndZeroPoint(min, max, clampVar);
scale = scaleAndZeroPoint.first;
zeroPoint = scaleAndZeroPoint.second;
}
float limit = clampVar->readMap<float>()[0];
nudgeMin = (_Scalar<float>(-limit) - zeroPoint) * scale;
nudgeMax = (_Scalar<float>(limit) - zeroPoint) * scale;
nudgeMin = _Minimum(_Scalar<float>(0.0f), nudgeMin);
nudgeMax = _Maximum(_Scalar<float>(0.0f), nudgeMax);
auto quantX = clamp(_Round(tempX / scale + zeroPoint), clampVar);
tempX = scale * (quantX - zeroPoint);
// Break the gradient path through the quantized value by casting
tempX = _Cast<float>(tempX);
// Route the gradient to originX instead (straight-through estimator)
tempX = _Convert(tempX + _ZeroGrad(originX), originFormat);
return {tempX, nudgeMin, nudgeMax};
}
VARP clamp(VARP x, VARP clampVar) {
return _Maximum(_Minimum(x, clampVar), _Negative(clampVar));
}
VARP updateParameter(VARP originValue, VARP newValue) const {
if (nullptr == originValue) {
return newValue;
}
switch (mScaleUpdateMethod) {
case NN::MovingAverage:
return originValue * _Scalar<float>(mMomentum) + newValue * _Scalar<float>(1.0f-mMomentum);
case NN::Maximum:
return _Maximum(originValue, newValue);
default:
break;
}
MNN_ASSERT(false);
return nullptr;
}
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
VARP res;
if (getIsTraining()) {
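// Training path: fake-quantize the weights per output channel and the features
// per tensor, update the observed input/output ranges, and run a float conv on
// the dequantized values. Gradients flow through via the straight-through
// estimator set up in fakeQuantFeatureWithMinMax.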
auto x = _Convert(inputs[0], NCHW);
// simulate weight quant
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
auto weightTemp = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
weightTemp = weightTemp + _ZeroGrad(mWeight);
// simulate input quant to get original input scale
auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
mInputMin = updateParameter(mInputMin, inputPair[1]);
mInputMax = updateParameter(mInputMax, inputPair[2]);
setParameter(mInputMin, mInputMinPos);
setParameter(mInputMax, mInputMaxPos);
// simulate output quant to get original output scale
res = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
mOption.dilate, mGroup, mOption.pads);
res->setName(name());
if (mBatchNorm) {
res = mBatchNorm->forward(res);
}
res = _activate(res, mActivation);
auto outputPair = fakeQuantFeatureWithMinMax(res, nullptr, nullptr, mOutputClampValue);
mOutputMin = updateParameter(mOutputMin, outputPair[1]);
mOutputMax = updateParameter(mOutputMax, outputPair[2]);
setParameter(mOutputMin, mOutputMinPos);
setParameter(mOutputMax, mOutputMaxPos);
res = outputPair[0];
} else {
if (nullptr == mInputMin) {
// No calibrated ranges yet (first inference without prior training):
// initialize them from this input
// simulate weight quant
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
auto weightTemp = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
auto x = _Convert(inputs[0], NCHW);
auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
mInputMin = updateParameter(mInputMin, inputPair[1]);
mInputMax = updateParameter(mInputMax, inputPair[2]);
setParameter(mInputMin, mInputMinPos);
setParameter(mInputMax, mInputMaxPos);
auto simuRes = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
mOption.dilate, mGroup, mOption.pads);
if (mBatchNorm) {
simuRes = mBatchNorm->forward(simuRes);
}
simuRes = _activate(simuRes, mActivation);
Variable::prepareCompute({simuRes});
auto outputPair = fakeQuantFeatureWithMinMax(simuRes, nullptr, nullptr, mOutputClampValue);
mOutputMin = updateParameter(mOutputMin, outputPair[1]);
mOutputMax = updateParameter(mOutputMax, outputPair[2]);
setParameter(mOutputMin, mOutputMinPos);
setParameter(mOutputMax, mOutputMaxPos);
}
// Fold BN into the conv weights and bias: w' = alpha * w, b' = alpha * b + beta,
// with alpha = scale / sqrt(var + eps) and beta = bias - mean * alpha
VARP fusedWeights = mWeight;
VARP fusedBias = mBias;
fusedBias = _Reshape(fusedBias, {fusedBias->getInfo()->size, 1, 1, 1});
if (mBatchNorm) {
auto bn = std::static_pointer_cast<BatchNormModule>(mBatchNorm);
auto bnMean = bn->runningMean();
auto bnVar = bn->runningVariance();
auto bnScale = bn->scale();
auto bnBias = bn->bias();
auto bnEps = bn->eps();
MNN_ASSERT(bnMean->getInfo()->dim.size() == 4);
auto rStd = _Const(1.0f) / _Sqrt(bnVar + _Const(bnEps));
auto alpha = rStd * bnScale;
auto beta = bnBias - bnMean * rStd * bnScale;
alpha = _Reshape(alpha, {alpha->getInfo()->size, 1, 1, 1});
beta = _Reshape(beta, {beta->getInfo()->size, 1, 1, 1});
fusedWeights = alpha * fusedWeights;
fusedBias = alpha * fusedBias + beta;
}
auto x = _Convert(inputs[0], NC4HW4);
int8_t inputZeroPoint, outputZeroPoint;
{
VARP channelScale, zeroPoint;
auto scaleAndZeroPoint = computeScaleAndZeroPoint(mInputMin, mInputMax, mInputClampValue);
mInputScale = scaleAndZeroPoint.first;
mInputZeroPoint = scaleAndZeroPoint.second;
// always PerTensor
channelScale = _Reciprocal(mInputScale);
zeroPoint = _Cast<int8_t>(mInputZeroPoint);
inputZeroPoint = zeroPoint->readMap<int8_t>()[0];
x = _FloatToInt8(x, channelScale, -int8_t(mInputClampValue->readMap<float>()[0]), int8_t(mInputClampValue->readMap<float>()[0]), inputZeroPoint);
}
{
VARP channelScale, zeroPoint;
auto scaleAndZeroPoint = computeScaleAndZeroPoint(mOutputMin, mOutputMax, mOutputClampValue);
mOutputScale = scaleAndZeroPoint.first;
mOutputZeroPoint = scaleAndZeroPoint.second;
// always PerTensor
channelScale = mOutputScale;
zeroPoint = _Cast<int8_t>(mOutputZeroPoint);
outputZeroPoint = zeroPoint->readMap<int8_t>()[0];
}
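// Quantize the folded weights per output channel. convScale rescales the int32
// accumulator into the output quantization domain:
// convScale = weightScale * inputScale / outputScale.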
std::vector<int8_t> weight;
std::vector<float> bias;
std::vector<float> weightScaleVector;
{
VARP weightScale, quanWeight, convScale;
// auto newWeight = fusedWeights * mInputScale;
weightScale = _Maximum(_ReduceMax(_Abs(fusedWeights), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
quanWeight = _Cast<int8_t>(_Round(fusedWeights * _Reciprocal(weightScale)));
convScale = _Reciprocal(mOutputScale) * weightScale * mInputScale;
Variable::prepareCompute({quanWeight, convScale});
// // reference for how to get quantized bias
// auto remains = _ReduceSum(_Cast<int32_t>(mInputZeroPoint) * _Cast<int32_t>(quanWeight), {1, 2, 3}, true);
// MNN_ASSERT((mOutputZeroPoint->getInfo()->dim.size() == 0) && (mOutputZeroPoint->getInfo()->size == 1)); // only support per-tensor, per-channel is removed.
// auto outputZeroPointFused = _Cast<int32_t>(_Cast<float>(mOutputZeroPoint) * _Reciprocal(convScale));
// auto quanBias = _Cast<int32_t>(fusedBias * _Reciprocal(weightScale * mInputScale)) - remains + outputZeroPointFused;
{
auto info = quanWeight->getInfo();
weight.resize(info->size);
auto ptr = quanWeight->readMap<int8_t>();
::memcpy(weight.data(), ptr, weight.size() * sizeof(int8_t));
}
{
auto biasinfo = fusedBias->getInfo();
bias.resize(biasinfo->size);
auto ptr = fusedBias->readMap<float>();
::memcpy(bias.data(), ptr, bias.size() * sizeof(float));
auto info = weightScale->getInfo();
weightScaleVector.resize(info->size);
MNN_ASSERT(weightScaleVector.size() == bias.size());
auto ptrScale = weightScale->readMap<float>();
::memcpy(weightScaleVector.data(), ptrScale, weightScaleVector.size() * sizeof(float));
}
}
bool relu = mActivation != NN::None;
res = _Conv(std::move(weight), std::move(bias), std::move(weightScaleVector), _Convert(x, NC4HW4), mOption.channel,
mOption.kernelSize, mOption.padMode, mOption.stride, mOption.dilate, mGroup, mOption.pads, relu,
mInputScale->readMap<float>()[0], mOutputScale->readMap<float>()[0],
inputZeroPoint, outputZeroPoint,
-int8_t(mOutputClampValue->readMap<float>()[0]), int8_t(mOutputClampValue->readMap<float>()[0]), mWeightClampValue->readMap<float>()[0], mAccumulateToInt16);
res->setName(name());
// always PerTensor
res = _Int8ToFloat(res, mOutputScale, outputZeroPoint);
}
return {res};
}
private:
ConvBNReluFusedModule() = default;
Module* clone(CloneContext* ctx) const override {
ConvBNReluFusedModule* module(new ConvBNReluFusedModule);
module->mConvParameter = mConvParameter;
module->mConvParameter.weight = ctx->getOrClone(mConvParameter.weight);
module->mConvParameter.bias = ctx->getOrClone(mConvParameter.bias);
module->mOption = mOption;
module->mGroup = mGroup;
module->mWeight = ctx->getOrClone(mWeight);
module->mBias = ctx->getOrClone(mBias);
module->mActivation = mActivation;
module->mBits = mBits;
module->mLimit = mLimit;
module->mLimitScale = ctx->getOrClone(mLimitScale);
module->mWeightClampValue = ctx->getOrClone(mWeightClampValue);
module->mInputScale = ctx->getOrClone(mInputScale);
module->mOutputScale = ctx->getOrClone(mOutputScale);
module->mInputMin = ctx->getOrClone(mInputMin);
module->mInputMax = ctx->getOrClone(mInputMax);
module->mOutputMin = ctx->getOrClone(mOutputMin);
module->mOutputMax = ctx->getOrClone(mOutputMax);
module->mInputZeroPoint = ctx->getOrClone(mInputZeroPoint);
module->mOutputZeroPoint = ctx->getOrClone(mOutputZeroPoint);
module->mInputMinPos = mInputMinPos;
module->mInputMaxPos = mInputMaxPos;
module->mOutputMinPos = mOutputMinPos;
module->mOutputMaxPos = mOutputMaxPos;
module->mInputClampValue = ctx->getOrClone(mInputClampValue);
module->mOutputClampValue = ctx->getOrClone(mOutputClampValue);
module->mMomentum = mMomentum;
module->mFeatureScaleStatMethod = mFeatureScaleStatMethod;
module->mScaleUpdateMethod = mScaleUpdateMethod;
if (mBatchNorm) {
module->mBatchNorm.reset(mBatchNorm->clone(ctx));
module->registerModel({module->mBatchNorm});
}
return this->cloneBaseTo(ctx, module);
}
NN::ConvParameters mConvParameter;
NN::ConvOption mOption;
int mGroup;
VARP mWeight;
VARP mBias;
NN::ActivationFunctionType mActivation = NN::ActivationFunctionType::None;
std::shared_ptr<Module> mBatchNorm = nullptr;
int mBits;
float mLimit;
VARP mLimitScale;
Express::VARP mWeightClampValue;
VARP mInputScale = nullptr;
VARP mOutputScale = nullptr;
VARP mInputMin = nullptr;
VARP mInputMax = nullptr;
VARP mOutputMin = nullptr;
VARP mOutputMax = nullptr;
VARP mInputZeroPoint = nullptr;
VARP mOutputZeroPoint = nullptr;
int mInputMinPos = -1;
int mInputMaxPos = -1;
int mOutputMinPos = -1;
int mOutputMaxPos = -1;
VARP mInputClampValue;
VARP mOutputClampValue;
float mMomentum = 0.99f;
NN::FeatureScaleStatMethod mFeatureScaleStatMethod;
NN::ScaleUpdateMethod mScaleUpdateMethod;
bool mAccumulateToInt16 = false;
};
Module* NN::ConvBNReluFused(std::vector<std::shared_ptr<Module> > modules,
NN::FeatureScaleStatMethod featureScaleStatMethod,
NN::ScaleUpdateMethod scaleUpdateMethod, const int bits) {
return new ConvBNReluFusedModule(modules, featureScaleStatMethod, scaleUpdateMethod, bits);
}
Module* NN::ConvInt8(const ConvOption& option, int bits, bool hasBias,
std::shared_ptr<Initializer> weightInit, std::shared_ptr<Initializer> biasInit, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
std::shared_ptr<Module> conv(NN::Conv(option));
return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
Module* NN::ConvInt8(const ConvParameters& para, int bits, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
std::shared_ptr<Module> conv(NN::Conv(para));
return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
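// Rewrites a PipelineModule in place for quantization-aware training: scan for
// Conv (+ BatchNorm)(+ ReLU/ReLU6) chains, replace the head of each chain with
// a ConvBNReluFusedModule, and erase the absorbed BatchNorm/ReLU submodules.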
bool NN::turnQuantize(Module* module, const int bits, NN::FeatureScaleStatMethod featureScaleStatMethod, NN::ScaleUpdateMethod scaleUpdateMethod) {
if (nullptr == module || module->type() != PIPELINE_MODULE) {
MNN_ERROR("Invalid module for quantization\n");
return false;
}
auto pipModule = static_cast<PipelineModule*>(module);
std::vector<int> needEraseIndices;
for (int i = 0; i < pipModule->mSubModules.size(); i++) {
auto& m = pipModule->mSubModules[i];
auto& theModule = std::get<0>(m);
auto moduleType = theModule->type();
//auto& inputIndices = std::get<1>(m);
auto& outputIndices = std::get<2>(m);
if (moduleType == "Conv" && i < pipModule->mSubModules.size() - 1) {
auto& p1 = pipModule->mSubModules[i+1];
auto p1Module = std::get<0>(p1);
auto& p1ModuleType = p1Module->type();
auto& p1InputIndices = std::get<1>(p1);
auto& p1OutputIndices = std::get<2>(p1);
auto convOutputCount = pipModule->countOutputReference(outputIndices);
bool convSingleOutputReference = ((outputIndices.size() == 1) && (convOutputCount[0] == 1));
// only conv
if ((!convSingleOutputReference) || (p1ModuleType == "Conv") ||
(p1ModuleType != "BatchNorm" && p1ModuleType != "ReLU" && p1ModuleType != "ReLU6")) {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
continue;
}
// conv + bn + ?
if (p1ModuleType == "BatchNorm") {
bool convBnConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
if (!convBnConnected) {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
continue;
}
// last conv + bn
if (i == pipModule->mSubModules.size() - 2) {
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
}
// maybe there is a relu or relu6 after conv + bn
auto& p2 = pipModule->mSubModules[i+2];
auto& p2Module = std::get<0>(p2);
auto p2ModuleType = p2Module->type();
auto& p2InputIndices = std::get<1>(p2);
auto& p2OutputIndices = std::get<2>(p2);
auto bnOutputCount = pipModule->countOutputReference(p1OutputIndices);
bool bnSingleOutputReference = ((p1OutputIndices.size() == 1) && (bnOutputCount[0] == 1));
// only conv + bn
if ((!bnSingleOutputReference) || (p2ModuleType != "ReLU" && p2ModuleType != "ReLU6")) {
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
} else { // conv + bn + relu or conv + bn + relu6
bool convBnReluConnected = ((bnSingleOutputReference) && (p2InputIndices.size() == 1) && (p2InputIndices[0] == p1OutputIndices[0]));
if (!convBnReluConnected) {
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
}
theModule.reset(NN::ConvBNReluFused({theModule, p1Module, p2Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
outputIndices = p2OutputIndices;
needEraseIndices.emplace_back(i + 1);
needEraseIndices.emplace_back(i + 2);
continue;
}
}
// conv + relu or conv + relu6
if (p1ModuleType == "ReLU" || p1ModuleType == "ReLU6") {
bool convReluConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
if (!convReluConnected) {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
continue;
}
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
}
}
if (i == pipModule->mSubModules.size() - 1 && moduleType == "Conv") {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
}
}
// erase useless submodules
const int eraseSize = needEraseIndices.size();
int alreadyErasedCount = 0;
for (int i = 0; i < eraseSize; i++) {
auto position = needEraseIndices[i] - alreadyErasedCount;
auto type = std::get<0>(pipModule->mSubModules[position])->type();
MNN_ASSERT(type == "BatchNorm" || type == "ReLU" || type == "ReLU6");
pipModule->mSubModules.erase(pipModule->mSubModules.begin() + position);
alreadyErasedCount++;
}
return true;
}
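// Builds a PipelineModule from an expression graph. When fortrain is set,
// convolutions become trainable ConvModules; ops that cannot be executed
// directly (BatchNorm, Dropout, While, If) become dedicated submodules.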
Module* NN::extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph) {
std::function<std::pair<std::vector<int>, std::shared_ptr<Module>>(EXPRP)> transformFunction;
if (fortrain) {
transformFunction =
[&subGraph](EXPRP source) {
if (source->get() == nullptr) {
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
}
std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
if (nullptr != m) {
m->setName(source->name());
return std::make_pair(std::vector<int>{}, m);
}
auto convExtracted = NN::Utils::ExtractConvolution(source);
if (convExtracted.weight == nullptr) {
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
}
std::shared_ptr<Module> module(NN::Conv(convExtracted));
module->setName(source->name());
return std::make_pair(std::vector<int>{0}, module);
};
} else {
transformFunction = [&subGraph](EXPRP source) {
if (source->get() == nullptr) {
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
}
std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
if (nullptr != m) {
m->setName(source->name());
return std::make_pair(std::vector<int>{}, m);
}
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
};
}
return new PipelineModule(inputs, outputs, transformFunction);
}
} // namespace Express
} // namespace MNN