[MNN:Sync] A few bugfixes

1. Support ONNX If with empty subgraphs (this occurs when the condition is known to be always true or always false)
2. Fix incorrect dimension computation in the Where op for zero-shape inputs
3. Fix Reduce computation on zero-shape inputs for the non-PROD cases
4. Fix compilation errors on aarch64-linux
5. Fix an incorrect comment about NNAPI in the header file
6. Fix several training-related issues
xiaying 2022-12-04 15:17:36 +08:00
parent 767086816d
commit ad5d243c9f
17 changed files with 460 additions and 122 deletions

View File

@@ -31,7 +31,7 @@ typedef enum {
MNN_FORWARD_OPENGL = 6,
MNN_FORWARD_VULKAN = 7,
/*Android 8.1's NNAPI, Not Support yet. CoreML Now*/
/*Android 8.1's NNAPI or CoreML for ios*/
MNN_FORWARD_NN = 5,
/*User can use API from Backend.hpp to add or search Backend*/

View File

@@ -45,10 +45,17 @@ struct BinaryRealDiv {
}
};
/**
Ref from onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.cc :: Modulus
*/
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryModInt {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x - (x / y) * y;
auto res = x % y;
if ((res < 0 && y > 0) || (res > 0 && y < 0)) {
res += y;
}
return (_ErrorCode)res;
}
};

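The replacement implements floored (Python-style) modulo, which ONNX's Mod uses for integers when fmod=0: the remainder takes the sign of the divisor, whereas C++'s built-in % truncates toward zero. The old expression x - (x / y) * y is exactly truncated %, which disagrees with ONNX for mixed-sign operands. A self-contained sketch of the same scheme (flooredMod is an illustrative name, not MNN API):

    #include <cstdio>

    // Floored modulo (Python/ONNX Mod with fmod=0): result takes the divisor's sign.
    // C++'s % truncates toward zero, so -4 % 3 == -1, while floored mod gives 2.
    int flooredMod(int x, int y) {
        int res = x % y;                              // truncated remainder
        if ((res < 0 && y > 0) || (res > 0 && y < 0)) {
            res += y;                                 // shift into the divisor's sign range
        }
        return res;
    }

    int main() {
        printf("%d %d\n", -4 % 3, flooredMod(-4, 3)); // -1 vs 2
        printf("%d %d\n", 7 % -3, flooredMod(7, -3)); //  1 vs -2
        return 0;
    }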
View File

@@ -54,6 +54,10 @@
#define CPUINFO_ARM_LINUX_FEATURE_FPHP UINT32_C(0x00000200)
#define CPUINFO_ARM_LINUX_FEATURE_ASIMDHP UINT32_C(0x00000400)
#define CPUINFO_ARM_LINUX_FEATURE_ASIMDDP UINT32_C(0x00100000)
#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000)
#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000)
#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002)
#endif /* __linux__ && __aarch64__ */
#ifdef __ANDROID__
@@ -1564,4 +1568,4 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) {
#endif
MNN_PRINT("The device support dot:%d, support fp16:%d, support i8mm: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm);
}
}

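The added constants match the Linux AArch64 HWCAP/HWCAP2 bit masks for these features (ASIMDDP is bit 20 of HWCAP, SVE bit 22 of HWCAP, I8MM bit 13 of HWCAP2, SVE2 bit 1 of HWCAP2); defining them locally keeps the file compiling against older kernel/libc headers that lack them, which appears to be the aarch64-linux build fix from the commit message. A minimal sketch of how such bits are read at runtime, assuming glibc's getauxval:

    #include <cstdio>
    #if defined(__linux__) && defined(__aarch64__)
    #include <sys/auxv.h>   // getauxval, AT_HWCAP, AT_HWCAP2

    int main() {
        unsigned long hwcap  = getauxval(AT_HWCAP);
        unsigned long hwcap2 = getauxval(AT_HWCAP2);
        // Bit positions follow the values added in the diff; older headers may
        // not define HWCAP_SVE / HWCAP2_I8MM, hence the local constants.
        printf("dot:%d sve:%d i8mm:%d sve2:%d\n",
               (int)((hwcap  >> 20) & 1),   // ASIMDDP (0x00100000)
               (int)((hwcap  >> 22) & 1),   // SVE     (0x00400000)
               (int)((hwcap2 >> 13) & 1),   // I8MM    (0x00002000)
               (int)((hwcap2 >>  1) & 1));  // SVE2    (0x00000002)
        return 0;
    }
    #else
    int main() { return 0; }
    #endif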
View File

@@ -19,14 +19,23 @@ public:
auto reduct = op->main_as_ReductionParam();
auto reductOp = reduct->operation();
// prod([]) = 1
if (inputs[0]->elementSize() == 0 && reductOp == ReductionType_PROD) {
if (inputs[0]->elementSize() == 0) {
if(!context.allocTensor(outputs[0])) {
return false;
}
float res;
switch (reductOp) {
case ReductionType_PROD:
res = 1.0f;
break;
default:
res = 0.0f;
break;
}
if (outputs[0]->getType() == halide_type_of<float>()) {
outputs[0]->host<float>()[0] = 1.f;
outputs[0]->host<float>()[0] = (float)res;
} else {
outputs[0]->host<int>()[0] = 1;
outputs[0]->host<int>()[0] = (int)res;
}
return true;
}

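The shortcut for empty inputs now covers every reduction type, writing the operation's identity element into the output: 1 for PROD and 0 for the remaining cases (correct for SUM; the diff applies the same default to the other modes). This mirrors how reducing an empty range with an initial value behaves:

    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        std::vector<float> empty;
        // Reducing an empty range yields the init value, i.e. the identity element.
        float prod = std::accumulate(empty.begin(), empty.end(), 1.0f,
                                     [](float a, float b) { return a * b; }); // prod([]) = 1
        float sum  = std::accumulate(empty.begin(), empty.end(), 0.0f);       // sum([])  = 0
        printf("prod=%f sum=%f\n", prod, sum);
        return 0;
    }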
View File

@@ -41,7 +41,6 @@ class WhereSizeComputer : public SizeComputer {
}
// For zeroshape input
if (nullptr == inputs[0]->host<void>()) {
ob.dimensions = 1;
ob.dim[0].extent = 0;
return true;
}

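Where's output is a coordinate table of shape [numTrue, inputRank]. The removed line forced a 1-D output for zero-shape inputs; after the fix only the leading extent is zeroed, so the output presumably keeps its normal rank and downstream shape inference sees a consistent layout. A hypothetical sketch of that shape rule (not MNN code):

    #include <cstdio>
    #include <vector>

    // Hypothetical illustration: Where/NonZero-style output shape. For an input
    // with zero elements, only the leading extent goes to 0; the rank column stays.
    std::vector<int> whereOutputShape(const std::vector<int>& inputShape, int numTrue) {
        return { numTrue, (int)inputShape.size() }; // one coordinate row per match
    }

    int main() {
        auto shape = whereOutputShape({0, 4}, 0); // empty input -> {0, 2}
        printf("{%d, %d}\n", shape[0], shape[1]);
        return 0;
    }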
View File

@@ -226,6 +226,9 @@ void setInputOutputForOps(std::vector<std::shared_ptr<Tensor>>& allTensors, cons
if (des->usage == Tensor::InsideDescribe::CONSTANT) {
continue;
}
if (des->usage == Tensor::InsideDescribe::TRAINABLE) {
continue;
}
des->usage = Tensor::InsideDescribe::INPUT;
}
for (auto index : output) {

View File

@@ -185,9 +185,26 @@ public:
{4, 2}, {2}, {4, 2});
}
};
class ModTest : public BinaryTestCommon {
class ModTestInt : public BinaryTestCommon {
public:
virtual ~ModTest() = default;
virtual ~ModTestInt() = default;
virtual bool run(int precision) {
std::vector<int> x = {
-4, 7, 5, 4, -7, 8
};
std::vector<int> y = {
2, -3, 8, -2, 3, 5
};
std::vector<int> z = {
0, -2, 5, 0, 2, 3
};
return test<int, int>(_Mod, "ModTestFloat", 0,
x,y,z, {6}, {6}, {6});
}
};
class ModTestFloat : public BinaryTestCommon {
public:
virtual ~ModTestFloat() = default;
virtual bool run(int precision) {
std::vector<float> x = {
1.1f, 2.3f, 3.5f, 4.7f, 5.9f, 6.2f, 7.4f, 8.6f
@@ -201,7 +218,7 @@ public:
z[i + j * 2] = FP32Converter[precision](fmodf(FP32Converter[precision](x[i+j*2]), FP32Converter[precision](y[i])));
}
}
return test<float, float>(_Mod, "ModTest", 0,
return test<float, float>(_Mod, "ModTestFloat", 0,
x,y,z,
{4, 2}, {2}, {4, 2});
}
@@ -368,7 +385,8 @@ MNNTestSuiteRegister(SquaredDifferenceTest, "op/binary/squareddifference");
MNNTestSuiteRegister(EqualTest, "op/binary/equal");
MNNTestSuiteRegister(LessEqualTest, "op/binary/lessequal");
MNNTestSuiteRegister(FloorModTest, "op/binary/floormod");
MNNTestSuiteRegister(ModTest, "op/binary/mod");
MNNTestSuiteRegister(ModTestFloat, "op/binary/mod_float");
MNNTestSuiteRegister(ModTestInt, "op/binary/mod_int");
MNNTestSuiteRegister(Atan2Test, "op/binary/atan2");
MNNTestSuiteRegister(LogicalOrTest, "op/binary/logicalor");
MNNTestSuiteRegister(NotEqualTest, "op/binary/notqual");

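As a cross-check of the integer test vectors above: under floored modulo the remainder follows the divisor's sign, so -4 mod 2 = 0, 7 mod -3 = -2, 5 mod 8 = 5, 4 mod -2 = 0, -7 mod 3 = 2, and 8 mod 5 = 3 — exactly the expected z vector — whereas truncated C-style % would yield 0, 1, 5, 0, -1, 3 for the same pairs.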
View File

@@ -21,15 +21,14 @@ MNN::OpParameter IfOnnx::type() {
void IfOnnx::run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode,
OnnxScope* scope) {
auto param = new MNN::IfParamT;
dstOp->name += "/If";
param->then_graph = dstOp->name + "/then";
param->else_graph = dstOp->name + "/else";
const ::onnx::GraphProto *thenG = nullptr, *elseG = nullptr;
for (const auto& attr : onnxNode->attribute()) {
if (attr.name() == "then_branch") {
thenG = &attr.g();
param->then_graph = thenG->name();
} else if (attr.name() == "else_branch") {
elseG = &attr.g();
param->else_graph = elseG->name();
}
}
if (thenG == nullptr || elseG == nullptr) {
@@ -50,8 +49,21 @@ void IfOnnx::run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode,
std::transform(graph->output().begin(), graph->output().end(), outputs.begin(), [](const ::onnx::ValueInfoProto& p) { return p.name(); });
return std::make_pair(inputs, outputs);
};
auto thenInOuts = dealWithSubGraph(thenG, param->then_graph);
auto elseInOuts = dealWithSubGraph(elseG, param->else_graph);
std::pair<std::vector<std::string>, std::vector<std::string>> thenInOuts, elseInOuts;
MNN_ASSERT(thenG->node_size() > 0 || elseG->node_size() > 0);
if (thenG->node_size() > 0) {
thenInOuts = dealWithSubGraph(thenG, param->then_graph);
}
if (elseG->node_size() > 0) {
elseInOuts = dealWithSubGraph(elseG, param->else_graph);
}
if (thenG->node_size() == 0) {
thenInOuts = elseInOuts;
param->then_graph = param->else_graph;
} else if (elseG->node_size() == 0) {
elseInOuts = thenInOuts;
param->else_graph = param->then_graph;
}
auto thenInputs = thenInOuts.first, thenOutputs = thenInOuts.second;
auto elseInputs = elseInOuts.first, elseOutputs = elseInOuts.second;

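When an exporter emits If with an empty branch (because the condition is constant), the converter now aliases the empty branch to the non-empty one, so both subgraph slots stay valid and the assert guarantees at least one branch has nodes. The fallback rule, as a standalone sketch (illustrative names, not MNN API):

    #include <cassert>
    #include <cstdio>
    #include <string>

    // Illustrative sketch: choose subgraph names for ONNX If when one branch
    // may be empty. At least one branch must contain nodes.
    struct Branches { std::string thenGraph, elseGraph; };

    Branches resolveIfBranches(int thenNodes, int elseNodes,
                               std::string thenName, std::string elseName) {
        assert(thenNodes > 0 || elseNodes > 0);
        if (thenNodes == 0) thenName = elseName; // condition is always false
        if (elseNodes == 0) elseName = thenName; // condition is always true
        return { thenName, elseName };
    }

    int main() {
        auto b = resolveIfBranches(3, 0, "g/then", "g/else");
        printf("%s %s\n", b.thenGraph.c_str(), b.elseGraph.c_str()); // g/then g/then
        return 0;
    }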
View File

@@ -7,6 +7,7 @@ add_executable(SequenceModuleTest.out ${CMAKE_CURRENT_LIST_DIR}/SequenceModuleTe
list(APPEND MNN_CPP_TOOLS SequenceModuleTest.out)
add_executable(mergeInplaceForCPU ${CMAKE_CURRENT_LIST_DIR}/mergeInplaceForCPU.cpp)
list(APPEND MNN_CPP_TOOLS mergeInplaceForCPU)
add_executable(MNNV2Basic.out ${CMAKE_CURRENT_LIST_DIR}/MNNV2Basic.cpp)
list(APPEND MNN_CPP_TOOLS MNNV2Basic.out)

View File

@@ -37,7 +37,7 @@ target_compile_definitions(MNNTrain PRIVATE STB_IMAGE_STATIC STB_IMAGE_IMPLEMENT
# executables
set(MNN_TRAIN_TOOLS "")
add_executable(transformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/transformerExecution.cpp)
add_executable(transformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/transformerExecution.cpp ${TRANSFORMER})
add_executable(train.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/train.cpp ${SCHEMA} ${BASIC_INCLUDE})
add_executable(rawDataTransform.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/rawDataTransform.cpp ${SCHEMA} ${BASIC_INCLUDE})
add_executable(dataTransformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/dataTransformer.cpp ${SCHEMA} ${BASIC_INCLUDE})

View File

@@ -22,6 +22,7 @@
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include "rapidjson/document.h"
#include <algorithm>
using namespace MNN;
using namespace MNN::Express;
@@ -82,6 +83,14 @@ int main(int argc, const char* argv[]) {
MNN_PRINT("optimizer type: %s\n", optimizerType.c_str());
}
}
auto bnMomentum = new MNN::AttributeT;
bnMomentum->f = 0.99;
if (configObject.HasMember("BatchNorm")) {
auto bnConfig = configObject["BatchNorm"].GetObject();
if (bnConfig.HasMember("momentum")) {
bnMomentum->f = bnConfig["momentum"].GetFloat();
}
}
const char* inputModeFileName = argv[1];
FUNC_PRINT_ALL(inputModeFileName, s);
std::map<std::string, VARP> inputVars;
@@ -94,7 +103,10 @@ int main(int argc, const char* argv[]) {
Transformer::TrainConfig trainConfig;
trainConfig.noUpdateOps = std::move(noUpdateOps);
trainConfig.onlyUpdateOps = std::move(onlyUpdateOps);
Transformer::turnModelToTrainable(trainConfig)->onExecute(Variable::mapToSequence(outputVars));
trainConfig.extraParams["BatchNorm"]["momentum"] = bnMomentum;
auto turnTrainable = Train::TurnTrainable(trainConfig);
turnTrainable.onExecute(Variable::mapToSequence(outputVars));
auto trainInfo = turnTrainable.trainInfo;
if (configObject.HasMember("Shape")) {
auto shapeArray = configObject["Shape"].GetObject();
for (auto shapeIter = shapeArray.begin(); shapeIter != shapeArray.end(); shapeIter++) {
@@ -189,10 +201,17 @@ int main(int argc, const char* argv[]) {
stepPlus1->setName("optimize_step+1");
varUpdateMap[step] = stepPlus1;
std::map<std::string, std::string> extraInputs;
extraInputs["LearningRate"] = "float";
extraInputs["WeightDecay"] = "float";
if (trainInfo["is_training_float"]->linkNumber() > 0) {
extraInputs["is_training"] = "int, 0 or 1";
}
if (optimizerType == "SGD") {
auto momentum = _Input();
momentum->setName("Momentum");
MNN_PRINT(">>>\nextra input tensors for SGD: LearningRate, WeightDecay, Momentum\n<<<\n");
extraInputs["Momentum"] = "float";
for (auto iter : gradMap) {
auto p = iter.first;
@@ -200,6 +219,18 @@ int main(int argc, const char* argv[]) {
auto grad = iter.second;
grad->setName(p->name()+"_grad");
if (p->name().find("_BN_RunningMean_Weight") != string::npos) {
varUpdateMap[p] = trainInfo[p->name()];
continue; // not update running stats
}
if (p->name().find("_BN_RunningVariance_Weight") != string::npos) {
varUpdateMap[p] = trainInfo[p->name()];
continue; // not update running stats
}
if (p->name().find("_BN_Eps_Weight") != string::npos) {
continue; // not update eps
}
auto pInfo = p->getInfo();
auto pDims = pInfo->dim;
@@ -216,10 +247,14 @@ int main(int argc, const char* argv[]) {
VARP history = _Const(0.0f, pDims, pInfo->order);
history->setName(p->name() + "_momentum");
history.fix(VARP::TRAINABLE);
auto newHistory = _Multiply(learningRate, gradWithDecay) + momentum * history;
auto newHistory = gradWithDecay + momentum * history;
newHistory->setName("update_" + history->name());
auto updateValue = _Subtract(p, history);
auto finalGrad = learningRate * history;
finalGrad->setName(p->name() + "_final_grad");
auto updateValue = _Subtract(p, finalGrad);
updateValue->setName("update_" + p->name());
varUpdateMap[p] = updateValue;
varUpdateMap[history] = newHistory;
@@ -231,7 +266,10 @@ int main(int argc, const char* argv[]) {
beta2->setName("Beta2");
auto eps = _Input();
eps->setName("Eps");
MNN_PRINT(">>>\nextra input tensors for ADAM: LearningRate, WeightDecay, Beta1, Beta2, Eps\n<<<\n");
extraInputs["Beta1"] = "float";
extraInputs["Beta2"] = "float";
extraInputs["Eps"] = "float";
auto correction = _Sqrt(_Const(1.0f, {}, NCHW) - _Pow(beta2, step)) / (_Const(1.0f, {}, NCHW) - _Pow(beta1, step));
correction->setName("correction");
@@ -242,6 +280,18 @@ int main(int argc, const char* argv[]) {
auto grad = iter.second;
grad->setName(p->name()+"_grad");
if (p->name().find("_BN_RunningMean_Weight") != string::npos) {
varUpdateMap[p] = trainInfo[p->name()];
continue; // not update running stats
}
if (p->name().find("_BN_RunningVariance_Weight") != string::npos) {
varUpdateMap[p] = trainInfo[p->name()];
continue; // not update running stats
}
if (p->name().find("_BN_Eps_Weight") != string::npos) {
continue; // not update eps
}
auto pInfo = p->getInfo();
auto pDims = pInfo->dim;
@@ -280,6 +330,12 @@ int main(int argc, const char* argv[]) {
MNN_ERROR("error: don't support optimizer type: %s\n", optimizerType.c_str());
}
MNN_PRINT(">>>\nextra input tensors for %s:\n\n", optimizerType.c_str());
for (auto& input : extraInputs) {
MNN_PRINT("name: %s, \ttype: %s\n", input.first.c_str(), input.second.c_str());
}
MNN_PRINT("<<<\n");
std::unique_ptr<MNN::NetT> netStruct(new MNN::NetT);
std::vector<VARP> resultOutputs;
for (auto output : outputVars) {
@@ -298,7 +354,21 @@ int main(int argc, const char* argv[]) {
for (int j = 0; j < netStruct->oplists.size(); ++j) {
auto& opSub = netStruct->oplists[j];
if (opSub->name == iter.first->name()) {
auto indexOri = op->outputIndexes;
op->outputIndexes = opSub->outputIndexes;
if ((opSub->name.find("_BN_RunningMean_Weight") != string::npos) || (opSub->name.find("_BN_RunningVariance_Weight") != string::npos)) {
for (int k = 0; k < netStruct->oplists.size(); ++k) {
auto& opSubSub = netStruct->oplists[k];
if (opSubSub->inputIndexes.size() > 0) {
for (int kk = 0; kk < opSubSub->inputIndexes.size(); kk++) {
if (opSubSub->inputIndexes[kk] == indexOri[0]) {
opSubSub->inputIndexes[kk] = opSub->outputIndexes[0];
}
}
}
}
}
}
}
}

View File

@@ -70,7 +70,7 @@ Express::VARP SGD::regularizeParameters(Express::VARP param, Express::VARP grad)
Express::VARP SGD::onComputeUpdateValue(Express::VARP param, Express::VARP grad) {
auto lr = _Const(mLearningRate, {}, NCHW);
mHistory[param] = lr * grad + _Const(mMomentum, {}, NCHW) * mHistory[param];
mHistory[param] = lr * (grad + _Const(mMomentum, {}, NCHW) * mHistory[param]);
mHistory[param].fix(Express::VARP::CONSTANT);
//FUNC_PRINT_ALL(_ReduceMax(grad)->readMap<float>()[0], f);
return mHistory[param];

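Both momentum changes relocate where the learning rate enters the accumulator recursion. The exported training graph now uses the common formulation in which the history accumulates raw (weight-decayed) gradients and the learning rate only scales the applied step:

    h_t = g_t + m\,h_{t-1}, \qquad p_t = p_{t-1} - \eta\,h_t

with gradient g_t, momentum m, and learning rate \eta — hence newHistory = gradWithDecay + momentum * history and finalGrad = learningRate * history in the transformer diff. The in-framework SGD::onComputeUpdateValue above instead folds the learning rate over the whole expression, h_t = \eta\,(g_t + m\,h_{t-1}), so its stored history is already in step units.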
View File

@@ -31,12 +31,88 @@ void OpConverter::insert(MNN::OpType type, OpConverter* converter) {
converterMap.insert(std::make_pair(type, converter));
}
EXPRP OpConverter::convert(EXPRP source) {
EXPRP OpConverter::convert(EXPRP source, std::map<std::string, MNN::Express::VARP>& helpInfo) {
auto opOrigin = source->get();
if (nullptr == opOrigin) {
return source;
}
std::unique_ptr<MNN::OpT> op(opOrigin->UnPack());
if (op->type == OpType_BatchNorm) {
printf("transform batchnorm: %s\n", source->name().c_str());
auto params = op->main.AsBatchNorm();
auto channels = params->channels;
auto input = source->inputs()[0];
if (input->getInfo()->dim.size() != 4) {
printf("only support BatchNorm with 4-D input\n");
return nullptr;
}
auto preExpr = input->expr().first;
bool cond = (preExpr->get() != nullptr) && (preExpr->get()->type() == OpType_Convolution) && (preExpr->inputs().size() == 3);
auto oriInputOrder = input->getInfo()->order;
if (oriInputOrder == NC4HW4) {
input = _Convert(input, NCHW);
if (cond) input->setName(source->name() + "_MNN_BN_after_conv_first_op");
else input->setName(source->name() + "_MNN_single_BN_first_op");
}
auto inputOrder = input->getInfo()->order;
std::vector<int> reduceDims = {0, 2, 3};
std::vector<int> statShape = {1, channels, 1, 1};
if (inputOrder == NHWC) {
reduceDims = {0, 1, 2};
statShape = {1, 1, 1, channels};
}
auto rMean = _Const((void*)params->meanData.data(), statShape, inputOrder);
rMean->setName(source->name() + "_BN_RunningMean_Weight");
auto rVar = _Const((void*)params->varData.data(), statShape, inputOrder);
rVar->setName(source->name() + "_BN_RunningVariance_Weight");
auto w = _Const((void*)params->slopeData.data(), statShape, inputOrder);
w->setName(source->name() + "_BN_Gamma_Weight");
auto b = _Const((void*)params->biasData.data(), statShape, inputOrder);
b->setName(source->name() + "_BN_Beta_Bias");
auto eps = _Scalar<float>(params->epsilon);
eps->setName(source->name() + "_BN_Eps_Weight");
auto meanX = _ReduceMean(input, reduceDims, true);
meanX->setName(source->name() + "_BN_xmean");
auto varX = _ReduceMean(_Square(input - meanX), reduceDims, true);
varX->setName(source->name() + "_BN_xvariance");
auto isTraining = helpInfo["is_training_float"];
auto one = helpInfo["one_float"];
auto momentum = helpInfo["bn_momentum"] * isTraining + (one - isTraining) * one;
auto mMean = momentum * rMean + (one - momentum) * meanX;
mMean->setName(source->name() + "_BN_momentum_mean");
helpInfo[rMean->name()] = mMean;
auto mVar = momentum * rVar + (one - momentum) * varX;
mVar->setName(source->name() + "_BN_momentum_variance");
helpInfo[rVar->name()] = mVar;
auto meanFinal = isTraining * meanX + (one - isTraining) * mMean;
meanFinal->setName(source->name() + "_BN_mean_final");
auto varFinal = isTraining * varX + (one - isTraining) * mVar;
varFinal->setName(source->name() + "_BN_variance_final");
auto stdFinal = _Sqrt(varFinal + eps);
auto subMean = input - meanFinal;
if (oriInputOrder != NC4HW4) {
if (cond) subMean->setName(source->name() + "_MNN_BN_after_conv_first_op");
else subMean->setName(source->name() + "_MNN_single_BN_first_op");
}
auto normed = subMean / stdFinal;
auto res = normed * w + b;
if (oriInputOrder == NC4HW4) {
res = _Convert(res, oriInputOrder);
}
res->setName(source->name());
return res->expr().first;
}
if (op->type != OpType_Convolution && op->type != OpType_ConvolutionDepthwise) {
return source;
}

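The expanded BatchNorm computes batch statistics, blends them into the running statistics with a momentum that is gated by the is_training input (the effective momentum becomes 1 in inference, so the running stats stay frozen), and selects batch or running statistics per mode. Schematically, with momentum m:

    \mu_B = \mathrm{mean}(x), \qquad \sigma_B^2 = \mathrm{mean}\big((x - \mu_B)^2\big)
    \hat r_\mu = m\,r_\mu + (1-m)\,\mu_B, \qquad \hat r_{\sigma^2} = m\,r_{\sigma^2} + (1-m)\,\sigma_B^2
    y = \gamma\,\frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta, \qquad (\mu, \sigma^2) = (\mu_B, \sigma_B^2) \text{ in training}, \; (\hat r_\mu, \hat r_{\sigma^2}) \text{ in inference}

The momentum-blended \hat r values are recorded in helpInfo under the _BN_RunningMean_Weight / _BN_RunningVariance_Weight names, which is how the transformer tool later wires them up as the running-stat updates.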
View File

@@ -16,7 +16,7 @@ class MNN_PUBLIC OpConverter {
public:
OpConverter() = default;
static MNN::Express::EXPRP convert(MNN::Express::EXPRP source);
static MNN::Express::EXPRP convert(MNN::Express::EXPRP source, std::map<std::string, MNN::Express::VARP>& helpInfo);
virtual ~OpConverter() = default;
static OpConverter* get(MNN::OpType type);

View File

@@ -9,80 +9,76 @@
#include "Transformer.hpp"
#include "OpConverter.hpp"
#include "MNN_generated.h"
#include <MNN/expr/ExprCreator.hpp>
using namespace MNN::Express;
namespace MNN {
namespace Train {
class TurnTrainable : public Express::Optimizer {
public:
TurnTrainable(Transformer::TrainConfig config) {
mConfig = std::move(config);
}
virtual Cost onMeasure(const std::vector<VARP>& outputs,
std::shared_ptr<Parameters> parameters = nullptr) override {
return Cost();
}
virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> p) override {
auto exprs = Variable::getExecuteOrder(outputs);
{
// Turn convolution be trainable convolution
for (auto expr : exprs) {
auto newExpr = OpConverter::convert(expr);
if (newExpr.get() != expr.get()) {
Expr::replace(expr, newExpr);
}
bool TurnTrainable::onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> p) {
auto exprs = Variable::getExecuteOrder(outputs);
{
auto isTraining = _Input({}, NCHW, halide_type_of<int>());
isTraining->setName("is_training");
trainInfo["is_training"] = isTraining;
isTraining = _Cast<float>(isTraining);
isTraining->setName("is_training_float");
trainInfo["is_training_float"] = isTraining;
trainInfo["one_float"] = _Scalar<float>(1.0f);
trainInfo["bn_momentum"] = _Scalar<float>(mConfig.extraParams["BatchNorm"]["momentum"]->f);
// Turn convolution be trainable convolution
for (auto expr : exprs) {
auto newExpr = OpConverter::convert(expr, trainInfo);
if (newExpr.get() != expr.get()) {
Expr::replace(expr, newExpr);
}
}
exprs = Variable::getExecuteOrder(outputs);
auto& noUpdateOps = mConfig.noUpdateOps;
auto& onlyUpdateOps = mConfig.onlyUpdateOps;
// Collect Const Variable and turn to Trainable
for (auto v : exprs) {
if (v->get() == nullptr && VARP::INPUT != v->inputType()) {
auto name = v->name();
auto info = v->outputInfo(0);
if (halide_type_float != info->type.code) {
continue;
}
}
exprs = Variable::getExecuteOrder(outputs);
auto& noUpdateOps = mConfig.noUpdateOps;
auto& onlyUpdateOps = mConfig.onlyUpdateOps;
// Collect Const Variable and turn to Trainable
for (auto v : exprs) {
if (v->get() == nullptr && VARP::INPUT != v->inputType()) {
auto name = v->name();
auto info = v->outputInfo(0);
if (halide_type_float != info->type.code) {
continue;
}
bool update;
if (!onlyUpdateOps.empty()) {
update = false;
for (auto limit : onlyUpdateOps) {
if (name.find(limit) != std::string::npos) {
update = true;
break;
}
}
} else {
update = true;
for (auto limit : noUpdateOps) {
if (name.find(limit) != std::string::npos) {
update = false;
break;
}
bool update;
if (!onlyUpdateOps.empty()) {
update = false;
for (auto limit : onlyUpdateOps) {
if (name.find(limit) != std::string::npos) {
update = true;
break;
}
}
auto va = Variable::create(v, 0);
if (update) {
MNN_PRINT("Add Variable: %s\n", name.c_str());
va.fix(VARP::TRAINABLE);
if (name.find("Weight") == std::string::npos && name.find("Bias") == std::string::npos) {
MNN_PRINT(">>>\ncheck mnn model if const '%s' is a learnable parameter in your original training model, ", name.c_str());
MNN_PRINT("if not, add it to transformConfig.json NoUpdateOps\n<<<\n");
} else {
update = true;
for (auto limit : noUpdateOps) {
if (name.find(limit) != std::string::npos) {
update = false;
break;
}
} else {
va.fix(VARP::CONSTANT);
}
}
auto va = Variable::create(v, 0);
if (update && name != "") {
MNN_PRINT("Add Variable: %s\n", name.c_str());
va.fix(VARP::TRAINABLE);
if (name.find("Weight") == std::string::npos && name.find("Bias") == std::string::npos) {
MNN_PRINT(">>>\ncheck mnn model if const '%s' is a learnable parameter in your original training model, ", name.c_str());
MNN_PRINT("if not, add it to transformConfig.json NoUpdateOps\n<<<\n");
}
} else {
va.fix(VARP::CONSTANT);
}
}
return true;
}
private:
Transformer::TrainConfig mConfig;
};
return true;
}
std::shared_ptr<Express::Optimizer> Transformer::turnModelToTrainable(TrainConfig config) {
std::shared_ptr<Express::Optimizer> res;
@@ -90,46 +86,156 @@ std::shared_ptr<Express::Optimizer> Transformer::turnModelToTrainable(TrainConfi
return res;
}
class InferOptimizer : public Express::Optimizer {
public:
InferOptimizer(){}
virtual Cost onMeasure(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
Cost c;
return c;
};
bool InferOptimizer::onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters) {
auto exprs = Variable::getExecuteOrder(outputs);
virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
auto exprs = Variable::getExecuteOrder(outputs);
for (auto& iter : exprs) {
auto op = iter->get();
if (nullptr == op) {
continue;
}
if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) {
continue;
}
auto inputExpr = iter->inputs()[0]->expr().first;
if (inputExpr->get() == nullptr) {
continue;
}
if (inputExpr->get()->type() != OpType_FloatToInt8) {
continue;
}
auto subInputExpr = inputExpr->inputs()[0]->expr().first;
if (subInputExpr->get() == nullptr) {
continue;
}
if (subInputExpr->get()->type() != OpType_Int8ToFloat) {
continue;
}
//MNN_PRINT("Find direct\n");
std::vector<VARP> newInputs = subInputExpr->inputs();
auto newExpr = Expr::create(iter->extra(), std::move(newInputs));
Expr::replace(iter, newExpr);
// convert trainable to const
for (auto& expr : exprs) {
if (expr->inputs().size() == 0 && expr->inputType() == VARP::InputType::TRAINABLE) {
auto newConst = Variable::create(expr);
newConst.fix(VARP::InputType::CONSTANT);
newConst->setName(expr->name());
auto newExpr = newConst->expr().first;
newExpr->setName(expr->name());
Expr::replace(expr, newExpr);
}
return true;
}
};
// merge bn after conv into conv
// convert single bn to scale
std::set<std::string> bnNames;
std::string pattern1 = "_MNN_BN_after_conv_first_op";
std::string pattern2 = "_MNN_single_BN_first_op";
for (auto& expr : exprs) {
if (expr->name().find(pattern1) != std::string::npos) {
std::string bnName = expr->name();
for (int i = 0; i < pattern1.size(); i++) {
bnName.pop_back();
}
bnNames.insert(bnName);
}
if (expr->name().find(pattern2) != std::string::npos) {
std::string bnName = expr->name();
for (int i = 0; i < pattern2.size(); i++) {
bnName.pop_back();
}
bnNames.insert(bnName);
}
}
std::map<std::string, std::map<std::string, EXPRP>> bnInfo;
for (auto& name : bnNames) {
for (auto& expr : exprs) {
auto inputs = expr->inputs();
if (expr->name() == name) {
bnInfo[name]["Self"] = expr;
}
if (inputs.size() == 0 && expr->name() == name + "_BN_RunningMean_Weight") {
bnInfo[name]["RunningMean"] = expr;
}
if (inputs.size() == 0 && expr->name() == name + "_BN_RunningVariance_Weight") {
bnInfo[name]["RunningVariance"] = expr;
}
if (inputs.size() == 0 && expr->name() == name + "_BN_Gamma_Weight") {
bnInfo[name]["Gamma"] = expr;
}
if (inputs.size() == 0 && expr->name() == name + "_BN_Beta_Bias") {
bnInfo[name]["Bias"] = expr;
}
if (inputs.size() == 0 && expr->name() == name + "_BN_Eps_Weight") {
bnInfo[name]["Eps"] = expr;
}
if (expr->name() == name + pattern1) {
bnInfo[name]["FirstOpAfterConv"] = expr;
}
if (expr->name() == name + pattern2) {
bnInfo[name]["FirstOpSingleBN"] = expr;
}
}
}
for (auto& bn : bnInfo) {
auto bnName = bn.first;
auto info = bn.second;
bool bnAfterConv = false;
if (info.find("FirstOpAfterConv") != info.end()) {
bnAfterConv = true;
}
auto rm = _Convert(Variable::create(info["RunningMean"]), NCHW);
auto rv = _Convert(Variable::create(info["RunningVariance"]), NCHW);
auto gamma = _Convert(Variable::create(info["Gamma"]), NCHW);
auto bias = _Convert(Variable::create(info["Bias"]), NCHW);
auto eps = Variable::create(info["Eps"]);
auto s = _Sqrt(rv + eps);
auto alpha = gamma / s;
auto beta = bias - rm / s * gamma;
if (bnAfterConv) {
auto firstOp = info["FirstOpAfterConv"];
auto convExpr = firstOp->inputs()[0]->expr().first;
if (convExpr->get() == nullptr || convExpr->get()->type() != OpType_Convolution) {
continue;
}
auto convInput = convExpr->inputs()[0];
auto w = convExpr->inputs()[1];
auto b = convExpr->inputs()[2];
auto nw = w * _Reshape(alpha, {b->getInfo()->dim[0], 1, 1, 1});
nw.fix(w->expr().first->inputType());
nw->setName(w->name());
auto nb = _Reshape(alpha, {b->getInfo()->dim}) * b + _Reshape(beta, b->getInfo()->dim);
nb.fix(b->expr().first->inputType());
nb->setName(b->name());
std::vector<VARP> newInputs = {convInput, nw, nb};
auto newConv = Expr::create(convExpr->extra(), std::move(newInputs));
Expr::replace(info["Self"], newConv);
} else {
auto firstOp = info["FirstOpSingleBN"];
auto inputs = firstOp->inputs();
std::vector<float> scales, p;
for (int i = 0; i < beta->getInfo()->size; i++) {
scales.push_back(alpha->readMap<float>()[i]);
p.push_back(beta->readMap<float>()[i]);
}
auto res = _Scale(inputs[0], beta->getInfo()->size, std::move(scales), std::move(p));
res->setName(info["Self"]->name());
Expr::replace(info["Self"], res->expr().first);
}
}
exprs = Variable::getExecuteOrder(outputs);
for (auto& iter : exprs) {
auto op = iter->get();
if (nullptr == op) {
continue;
}
if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) {
continue;
}
auto inputExpr = iter->inputs()[0]->expr().first;
if (inputExpr->get() == nullptr) {
continue;
}
if (inputExpr->get()->type() != OpType_FloatToInt8) {
continue;
}
auto subInputExpr = inputExpr->inputs()[0]->expr().first;
if (subInputExpr->get() == nullptr) {
continue;
}
if (subInputExpr->get()->type() != OpType_Int8ToFloat) {
continue;
}
//MNN_PRINT("Find direct\n");
std::vector<VARP> newInputs = subInputExpr->inputs();
auto newExpr = Expr::create(iter->extra(), std::move(newInputs));
Expr::replace(iter, newExpr);
}
return true;
}
std::shared_ptr<Express::Optimizer> Transformer::turnModelToInfer() {
return std::shared_ptr<Optimizer>(new InferOptimizer);

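InferOptimizer folds a trained BatchNorm back into the network: running statistics, gamma, beta, and eps collapse into one per-channel affine pair, which is merged into the preceding convolution (the _MNN_BN_after_conv_first_op case) or emitted as a standalone Scale op (the _MNN_single_BN_first_op case). Matching s = _Sqrt(rv + eps), alpha = gamma / s, beta = bias - rm / s * gamma above:

    \alpha = \frac{\gamma}{\sqrt{r_{\sigma^2} + \varepsilon}}, \qquad \beta = b_{\mathrm{bn}} - \frac{r_\mu\,\gamma}{\sqrt{r_{\sigma^2} + \varepsilon}}
    W' = \alpha \odot W, \qquad b'_{\mathrm{conv}} = \alpha \odot b_{\mathrm{conv}} + \beta

where \odot multiplies per output channel — exactly the nw and nb tensors constructed in the after-conv branch.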
View File

@@ -9,6 +9,7 @@
#ifndef Transformer_hpp
#define Transformer_hpp
#include <MNN/expr/Optimizer.hpp>
#include <MNN_generated.h>
namespace MNN {
namespace Train {
@@ -17,11 +18,40 @@ public:
struct TrainConfig {
std::vector<std::string> noUpdateOps;
std::vector<std::string> onlyUpdateOps;
std::map<std::string, std::map<std::string, MNN::AttributeT*>> extraParams;
};
static std::shared_ptr<Express::Optimizer> turnModelToTrainable(TrainConfig config);
static std::shared_ptr<Express::Optimizer> turnModelToInfer();
};
class TurnTrainable : public Express::Optimizer {
public:
TurnTrainable(Transformer::TrainConfig config) {
mConfig = std::move(config);
}
virtual Cost onMeasure(const std::vector<Express::VARP>& outputs,
std::shared_ptr<Parameters> parameters = nullptr) override {
return Cost();
}
virtual bool onExecute(const std::vector<Express::VARP>& outputs, std::shared_ptr<Parameters> p = nullptr) override;
public:
std::map<std::string, Express::VARP> trainInfo;
private:
Transformer::TrainConfig mConfig;
};
class InferOptimizer : public Express::Optimizer {
public:
InferOptimizer(){}
virtual Cost onMeasure(const std::vector<Express::VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
Cost c;
return c;
}
virtual bool onExecute(const std::vector<Express::VARP>& outputs, std::shared_ptr<Parameters> p = nullptr) override;
};
} // namespace Train
} // namespace MNN
#endif

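Putting the new declarations together, a hedged usage sketch based on the calls visible in transformerExecution.cpp above (types as declared in this header; error handling omitted):

    #include <vector>
    // Sketch only: mirrors the usage in transformerExecution.cpp.
    // Assumes the MNN training headers are on the include path.
    #include "Transformer.hpp"
    #include <MNN/expr/Expr.hpp>

    void makeTrainable(std::vector<MNN::Express::VARP> outputs) {
        MNN::Train::Transformer::TrainConfig config;
        auto bnMomentum = new MNN::AttributeT;      // stored raw in extraParams, as in the tool
        bnMomentum->f = 0.99f;
        config.extraParams["BatchNorm"]["momentum"] = bnMomentum;
        MNN::Train::TurnTrainable turn(config);
        turn.onExecute(outputs);                    // rewrites the graph in place
        auto isTraining = turn.trainInfo["is_training"]; // extra int input added by the pass
    }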
View File

@@ -9,6 +9,9 @@
"StopBackPropOps":[],
"type": "SGD"
},
"BatchNorm": {
"momentum":0.99
},
"Debug": {
"L2Norm": []
},