diff --git a/include/MNN/MNNForwardType.h b/include/MNN/MNNForwardType.h index 5ecb4b10f..eb30a7285 100644 --- a/include/MNN/MNNForwardType.h +++ b/include/MNN/MNNForwardType.h @@ -31,7 +31,7 @@ typedef enum { MNN_FORWARD_OPENGL = 6, MNN_FORWARD_VULKAN = 7, - /*Android 8.1's NNAPI, Not Support yet. CoreML Now*/ + /*Android 8.1's NNAPI, or CoreML on iOS*/ MNN_FORWARD_NN = 5, /*User can use API from Backend.hpp to add or search Backend*/ diff --git a/source/backend/cpu/BinaryUtils.hpp b/source/backend/cpu/BinaryUtils.hpp index a4e1d2c79..d4b98dcd8 100644 --- a/source/backend/cpu/BinaryUtils.hpp +++ b/source/backend/cpu/BinaryUtils.hpp @@ -45,10 +45,17 @@ struct BinaryRealDiv { } }; +/** + Integer modulo whose non-zero result takes the sign of the divisor y. Ref: onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.cc :: Modulus + */ template struct BinaryModInt { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { - return x - (x / y) * y; + auto res = x % y; + if ((res < 0 && y > 0) || (res > 0 && y < 0)) { + res += y; + } + return (_ErrorCode)res; } }; diff --git a/source/backend/cpu/CPURuntime.cpp b/source/backend/cpu/CPURuntime.cpp index d00b63ce2..eb848c010 100644 --- a/source/backend/cpu/CPURuntime.cpp +++ b/source/backend/cpu/CPURuntime.cpp @@ -54,6 +54,10 @@ #define CPUINFO_ARM_LINUX_FEATURE_FPHP UINT32_C(0x00000200) #define CPUINFO_ARM_LINUX_FEATURE_ASIMDHP UINT32_C(0x00000400) #define CPUINFO_ARM_LINUX_FEATURE_ASIMDDP UINT32_C(0x00100000) +#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000) +#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000) +#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002) + #endif /* __linux__ && __aarch64__ */ #ifdef __ANDROID__ @@ -1564,4 +1568,4 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) { #endif MNN_PRINT("The device support dot:%d, support fp16:%d, support i8mm: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm); -} \ No newline at end of file +} diff --git a/source/geometry/GeometryReduce.cpp 
b/source/geometry/GeometryReduce.cpp index b0d9c0928..c2a3bb411 100644 --- a/source/geometry/GeometryReduce.cpp +++ b/source/geometry/GeometryReduce.cpp @@ -19,14 +19,23 @@ public: auto reduct = op->main_as_ReductionParam(); auto reductOp = reduct->operation(); // prod([]) = 1 - if (inputs[0]->elementSize() == 0 && reductOp == ReductionType_PROD) { + if (inputs[0]->elementSize() == 0) { if(!context.allocTensor(outputs[0])) { return false; } + float res; + switch (reductOp) { + case ReductionType_PROD: + res = 1.0f; + break; + default: + res = 0.0f; + break; + } if (outputs[0]->getType() == halide_type_of()) { - outputs[0]->host()[0] = 1.f; + outputs[0]->host()[0] = (float)res; } else { - outputs[0]->host()[0] = 1; + outputs[0]->host()[0] = (int)res; } return true; } diff --git a/source/shape/ShapeWhere.cpp b/source/shape/ShapeWhere.cpp index 69bebab48..251c80c11 100644 --- a/source/shape/ShapeWhere.cpp +++ b/source/shape/ShapeWhere.cpp @@ -41,7 +41,6 @@ class WhereSizeComputer : public SizeComputer { } // For zeroshape input if (nullptr == inputs[0]->host()) { - ob.dimensions = 1; ob.dim[0].extent = 0; return true; } diff --git a/source/utils/InitNet.cpp b/source/utils/InitNet.cpp index 0aae477f2..94adfb41e 100644 --- a/source/utils/InitNet.cpp +++ b/source/utils/InitNet.cpp @@ -226,6 +226,9 @@ void setInputOutputForOps(std::vector>& allTensors, cons if (des->usage == Tensor::InsideDescribe::CONSTANT) { continue; } + if (des->usage == Tensor::InsideDescribe::TRAINABLE) { + continue; + } des->usage = Tensor::InsideDescribe::INPUT; } for (auto index : output) { diff --git a/test/op/BinaryOPTest.cpp b/test/op/BinaryOPTest.cpp index 71af2ffe3..0cbce1bb5 100644 --- a/test/op/BinaryOPTest.cpp +++ b/test/op/BinaryOPTest.cpp @@ -185,9 +185,26 @@ public: {4, 2}, {2}, {4, 2}); } }; -class ModTest : public BinaryTestCommon { +class ModTestInt : public BinaryTestCommon { public: - virtual ~ModTest() = default; + virtual ~ModTestInt() = default; + virtual bool run(int 
precision) { + std::vector x = { + -4, 7, 5, 4, -7, 8 + }; + std::vector y = { + 2, -3, 8, -2, 3, 5 + }; + std::vector z = { + 0, -2, 5, 0, 2, 3 + }; + return test(_Mod, "ModTestInt", 0, + x,y,z, {6}, {6}, {6}); + } +}; +class ModTestFloat : public BinaryTestCommon { +public: + virtual ~ModTestFloat() = default; virtual bool run(int precision) { std::vector x = { 1.1f, 2.3f, 3.5f, 4.7f, 5.9f, 6.2f, 7.4f, 8.6f @@ -201,7 +218,7 @@ public: z[i + j * 2] = FP32Converter[precision](fmodf(FP32Converter[precision](x[i+j*2]), FP32Converter[precision](y[i]))); } } - return test(_Mod, "ModTest", 0, + return test(_Mod, "ModTestFloat", 0, x,y,z, {4, 2}, {2}, {4, 2}); } @@ -368,7 +385,8 @@ MNNTestSuiteRegister(SquaredDifferenceTest, "op/binary/squareddifference"); MNNTestSuiteRegister(EqualTest, "op/binary/equal"); MNNTestSuiteRegister(LessEqualTest, "op/binary/lessequal"); MNNTestSuiteRegister(FloorModTest, "op/binary/floormod"); -MNNTestSuiteRegister(ModTest, "op/binary/mod"); +MNNTestSuiteRegister(ModTestFloat, "op/binary/mod_float"); +MNNTestSuiteRegister(ModTestInt, "op/binary/mod_int"); MNNTestSuiteRegister(Atan2Test, "op/binary/atan2"); MNNTestSuiteRegister(LogicalOrTest, "op/binary/logicalor"); MNNTestSuiteRegister(NotEqualTest, "op/binary/notqual"); diff --git a/tools/converter/source/onnx/IfOnnx.cpp b/tools/converter/source/onnx/IfOnnx.cpp index 1f5c7fc46..6d91425e0 100644 --- a/tools/converter/source/onnx/IfOnnx.cpp +++ b/tools/converter/source/onnx/IfOnnx.cpp @@ -21,15 +21,14 @@ MNN::OpParameter IfOnnx::type() { void IfOnnx::run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode, OnnxScope* scope) { auto param = new MNN::IfParamT; - dstOp->name += "/If"; - param->then_graph = dstOp->name + "/then"; - param->else_graph = dstOp->name + "/else"; const ::onnx::GraphProto *thenG = nullptr, *elseG = nullptr; for (const auto& attr : onnxNode->attribute()) { if (attr.name() == "then_branch") { thenG = &attr.g(); + param->then_graph = thenG->name(); } else if (attr.name() == 
"else_branch") { elseG = &attr.g(); + param->else_graph = elseG->name(); } } if (thenG == nullptr || elseG == nullptr) { @@ -50,8 +49,21 @@ void IfOnnx::run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode, std::transform(graph->output().begin(), graph->output().end(), outputs.begin(), [](const ::onnx::ValueInfoProto& p) { return p.name(); }); return std::make_pair(inputs, outputs); }; - auto thenInOuts = dealWithSubGraph(thenG, param->then_graph); - auto elseInOuts = dealWithSubGraph(elseG, param->else_graph); + std::pair, std::vector> thenInOuts, elseInOuts; + MNN_ASSERT(thenG->node_size() > 0 || elseG->node_size() > 0); + if (thenG->node_size() > 0) { + thenInOuts = dealWithSubGraph(thenG, param->then_graph); + } + if (elseG->node_size() > 0) { + elseInOuts = dealWithSubGraph(elseG, param->else_graph); + } + if (thenG->node_size() == 0) { + thenInOuts = elseInOuts; + param->then_graph = param->else_graph; + } else if (elseG->node_size() == 0) { + elseInOuts = thenInOuts; + param->else_graph = param->then_graph; + } auto thenInputs = thenInOuts.first, thenOutputs = thenInOuts.second; auto elseInputs = elseInOuts.first, elseOutputs = elseInOuts.second; diff --git a/tools/cpp/CMakeLists.txt b/tools/cpp/CMakeLists.txt index f793a9f08..8d1af615e 100644 --- a/tools/cpp/CMakeLists.txt +++ b/tools/cpp/CMakeLists.txt @@ -7,6 +7,7 @@ add_executable(SequenceModuleTest.out ${CMAKE_CURRENT_LIST_DIR}/SequenceModuleTe list(APPEND MNN_CPP_TOOLS SequenceModuleTest.out) add_executable(mergeInplaceForCPU ${CMAKE_CURRENT_LIST_DIR}/mergeInplaceForCPU.cpp) +list(APPEND MNN_CPP_TOOLS mergeInplaceForCPU) add_executable(MNNV2Basic.out ${CMAKE_CURRENT_LIST_DIR}/MNNV2Basic.cpp) list(APPEND MNN_CPP_TOOLS MNNV2Basic.out) diff --git a/tools/train/CMakeLists.txt b/tools/train/CMakeLists.txt index a580fc24d..992afd9ed 100644 --- a/tools/train/CMakeLists.txt +++ b/tools/train/CMakeLists.txt @@ -37,7 +37,7 @@ target_compile_definitions(MNNTrain PRIVATE STB_IMAGE_STATIC STB_IMAGE_IMPLEMENT # 
executables set(MNN_TRAIN_TOOLS "") -add_executable(transformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/transformerExecution.cpp) +add_executable(transformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/transformerExecution.cpp ${TRANSFORMER}) add_executable(train.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/train.cpp ${SCHEMA} ${BASIC_INCLUDE}) add_executable(rawDataTransform.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/rawDataTransform.cpp ${SCHEMA} ${BASIC_INCLUDE}) add_executable(dataTransformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/dataTransformer.cpp ${SCHEMA} ${BASIC_INCLUDE}) diff --git a/tools/train/source/exec/transformerExecution.cpp b/tools/train/source/exec/transformerExecution.cpp index 034227372..09ec75d28 100644 --- a/tools/train/source/exec/transformerExecution.cpp +++ b/tools/train/source/exec/transformerExecution.cpp @@ -22,6 +22,7 @@ #define MNN_OPEN_TIME_TRACE #include #include "rapidjson/document.h" +#include using namespace MNN; using namespace MNN::Express; @@ -82,6 +83,14 @@ int main(int argc, const char* argv[]) { MNN_PRINT("optimizer type: %s\n", optimizerType.c_str()); } } + auto bnMomentum = new MNN::AttributeT; + bnMomentum->f = 0.99; + if (configObject.HasMember("BatchNorm")) { + auto bnConfig = configObject["BatchNorm"].GetObject(); + if (bnConfig.HasMember("momentum")) { + bnMomentum->f = bnConfig["momentum"].GetFloat(); + } + } const char* inputModeFileName = argv[1]; FUNC_PRINT_ALL(inputModeFileName, s); std::map inputVars; @@ -94,7 +103,10 @@ int main(int argc, const char* argv[]) { Transformer::TrainConfig trainConfig; trainConfig.noUpdateOps = std::move(noUpdateOps); trainConfig.onlyUpdateOps = std::move(onlyUpdateOps); - Transformer::turnModelToTrainable(trainConfig)->onExecute(Variable::mapToSequence(outputVars)); + trainConfig.extraParams["BatchNorm"]["momentum"] = bnMomentum; + auto turnTrainable = Train::TurnTrainable(trainConfig); + turnTrainable.onExecute(Variable::mapToSequence(outputVars)); + auto trainInfo = 
turnTrainable.trainInfo; if (configObject.HasMember("Shape")) { auto shapeArray = configObject["Shape"].GetObject(); for (auto shapeIter = shapeArray.begin(); shapeIter != shapeArray.end(); shapeIter++) { @@ -189,10 +201,17 @@ int main(int argc, const char* argv[]) { stepPlus1->setName("optimize_step+1"); varUpdateMap[step] = stepPlus1; + std::map extraInputs; + extraInputs["LearningRate"] = "float"; + extraInputs["WeightDecay"] = "float"; + if (trainInfo["is_training_float"]->linkNumber() > 0) { + extraInputs["is_training"] = "int, 0 or 1"; + } + if (optimizerType == "SGD") { auto momentum = _Input(); momentum->setName("Momentum"); - MNN_PRINT(">>>\nextra input tensors for SGD: LearningRate, WeightDecay, Momentum\n<<<\n"); + extraInputs["Momentum"] = "float"; for (auto iter : gradMap) { auto p = iter.first; @@ -200,6 +219,18 @@ int main(int argc, const char* argv[]) { auto grad = iter.second; grad->setName(p->name()+"_grad"); + if (p->name().find("_BN_RunningMean_Weight") != string::npos) { + varUpdateMap[p] = trainInfo[p->name()]; + continue; // not update running stats + } + if (p->name().find("_BN_RunningVariance_Weight") != string::npos) { + varUpdateMap[p] = trainInfo[p->name()]; + continue; // not update running stats + } + if (p->name().find("_BN_Eps_Weight") != string::npos) { + continue; // not update eps + } + auto pInfo = p->getInfo(); auto pDims = pInfo->dim; @@ -216,10 +247,14 @@ int main(int argc, const char* argv[]) { VARP history = _Const(0.0f, pDims, pInfo->order); history->setName(p->name() + "_momentum"); history.fix(VARP::TRAINABLE); - auto newHistory = _Multiply(learningRate, gradWithDecay) + momentum * history; + + auto newHistory = gradWithDecay + momentum * history; newHistory->setName("update_" + history->name()); - auto updateValue = _Subtract(p, history); + auto finalGrad = learningRate * history; + finalGrad->setName(p->name() + "_final_grad"); + + auto updateValue = _Subtract(p, finalGrad); updateValue->setName("update_" + p->name()); 
varUpdateMap[p] = updateValue; varUpdateMap[history] = newHistory; @@ -231,7 +266,10 @@ int main(int argc, const char* argv[]) { beta2->setName("Beta2"); auto eps = _Input(); eps->setName("Eps"); - MNN_PRINT(">>>\nextra input tensors for ADAM: LearningRate, WeightDecay, Beta1, Beta2, Eps\n<<<\n"); + + extraInputs["Beta1"] = "float"; + extraInputs["Beta2"] = "float"; + extraInputs["Eps"] = "float"; auto correction = _Sqrt(_Const(1.0f, {}, NCHW) - _Pow(beta2, step)) / (_Const(1.0f, {}, NCHW) - _Pow(beta1, step)); correction->setName("correction"); @@ -242,6 +280,18 @@ int main(int argc, const char* argv[]) { auto grad = iter.second; grad->setName(p->name()+"_grad"); + if (p->name().find("_BN_RunningMean_Weight") != string::npos) { + varUpdateMap[p] = trainInfo[p->name()]; + continue; // not update running stats + } + if (p->name().find("_BN_RunningVariance_Weight") != string::npos) { + varUpdateMap[p] = trainInfo[p->name()]; + continue; // not update running stats + } + if (p->name().find("_BN_Eps_Weight") != string::npos) { + continue; // not update eps + } + auto pInfo = p->getInfo(); auto pDims = pInfo->dim; @@ -280,6 +330,12 @@ int main(int argc, const char* argv[]) { MNN_ERROR("error: don't support optimizer type: %s\n", optimizerType.c_str()); } + MNN_PRINT(">>>\nextra input tensors for %s:\n\n", optimizerType.c_str()); + for (auto& input : extraInputs) { + MNN_PRINT("name: %s, \ttype: %s\n", input.first.c_str(), input.second.c_str()); + } + MNN_PRINT("<<<\n"); + std::unique_ptr netStruct(new MNN::NetT); std::vector resultOutputs; for (auto output : outputVars) { @@ -298,7 +354,21 @@ int main(int argc, const char* argv[]) { for (int j = 0; j < netStruct->oplists.size(); ++j) { auto& opSub = netStruct->oplists[j]; if (opSub->name == iter.first->name()) { + auto indexOri = op->outputIndexes; op->outputIndexes = opSub->outputIndexes; + + if ((opSub->name.find("_BN_RunningMean_Weight") != string::npos) || (opSub->name.find("_BN_RunningVariance_Weight") != 
string::npos)) { + for (int k = 0; k < netStruct->oplists.size(); ++k) { + auto& opSubSub = netStruct->oplists[k]; + if (opSubSub->inputIndexes.size() > 0) { + for (int kk = 0; kk < opSubSub->inputIndexes.size(); kk++) { + if (opSubSub->inputIndexes[kk] == indexOri[0]) { + opSubSub->inputIndexes[kk] = opSub->outputIndexes[0]; + } + } + } + } + } } } } diff --git a/tools/train/source/optimizer/SGD.cpp b/tools/train/source/optimizer/SGD.cpp index 9d07acaad..1792a2e6a 100644 --- a/tools/train/source/optimizer/SGD.cpp +++ b/tools/train/source/optimizer/SGD.cpp @@ -70,7 +70,7 @@ Express::VARP SGD::regularizeParameters(Express::VARP param, Express::VARP grad) Express::VARP SGD::onComputeUpdateValue(Express::VARP param, Express::VARP grad) { auto lr = _Const(mLearningRate, {}, NCHW); - mHistory[param] = lr * grad + _Const(mMomentum, {}, NCHW) * mHistory[param]; + mHistory[param] = lr * (grad + _Const(mMomentum, {}, NCHW) * mHistory[param]); mHistory[param].fix(Express::VARP::CONSTANT); //FUNC_PRINT_ALL(_ReduceMax(grad)->readMap()[0], f); return mHistory[param]; diff --git a/tools/train/source/transformer/OpConverter.cpp b/tools/train/source/transformer/OpConverter.cpp index 563347a82..b37c83e53 100644 --- a/tools/train/source/transformer/OpConverter.cpp +++ b/tools/train/source/transformer/OpConverter.cpp @@ -31,12 +31,88 @@ void OpConverter::insert(MNN::OpType type, OpConverter* converter) { converterMap.insert(std::make_pair(type, converter)); } -EXPRP OpConverter::convert(EXPRP source) { +EXPRP OpConverter::convert(EXPRP source, std::map& helpInfo) { auto opOrigin = source->get(); if (nullptr == opOrigin) { return source; } std::unique_ptr op(opOrigin->UnPack()); + if (op->type == OpType_BatchNorm) { + printf("transform batchnorm: %s\n", source->name().c_str()); + + auto params = op->main.AsBatchNorm(); + auto channels = params->channels; + + auto input = source->inputs()[0]; + if (input->getInfo()->dim.size() != 4) { + printf("only support BatchNorm with 4-D 
input\n"); + return nullptr; + } + auto preExpr = input->expr().first; + bool cond = (preExpr->get() != nullptr) && (preExpr->get()->type() == OpType_Convolution) && (preExpr->inputs().size() == 3); + auto oriInputOrder = input->getInfo()->order; + if (oriInputOrder == NC4HW4) { + input = _Convert(input, NCHW); + if (cond) input->setName(source->name() + "_MNN_BN_after_conv_first_op"); + else input->setName(source->name() + "_MNN_single_BN_first_op"); + } + auto inputOrder = input->getInfo()->order; + + std::vector reduceDims = {0, 2, 3}; + std::vector statShape = {1, channels, 1, 1}; + if (inputOrder == NHWC) { + reduceDims = {0, 1, 2}; + statShape = {1, 1, 1, channels}; + } + + auto rMean = _Const((void*)params->meanData.data(), statShape, inputOrder); + rMean->setName(source->name() + "_BN_RunningMean_Weight"); + auto rVar = _Const((void*)params->varData.data(), statShape, inputOrder); + rVar->setName(source->name() + "_BN_RunningVariance_Weight"); + auto w = _Const((void*)params->slopeData.data(), statShape, inputOrder); + w->setName(source->name() + "_BN_Gamma_Weight"); + auto b = _Const((void*)params->biasData.data(), statShape, inputOrder); + b->setName(source->name() + "_BN_Beta_Bias"); + auto eps = _Scalar(params->epsilon); + eps->setName(source->name() + "_BN_Eps_Weight"); + + auto meanX = _ReduceMean(input, reduceDims, true); + meanX->setName(source->name() + "_BN_xmean"); + auto varX = _ReduceMean(_Square(input - meanX), reduceDims, true); + varX->setName(source->name() + "_BN_xvariance"); + + auto isTraining = helpInfo["is_training_float"]; + auto one = helpInfo["one_float"]; + auto momentum = helpInfo["bn_momentum"] * isTraining + (one - isTraining) * one; + + auto mMean = momentum * rMean + (one - momentum) * meanX; + mMean->setName(source->name() + "_BN_momentum_mean"); + helpInfo[rMean->name()] = mMean; + auto mVar = momentum * rVar + (one - momentum) * varX; + mVar->setName(source->name() + "_BN_momentum_variance"); + helpInfo[rVar->name()] = 
mVar; + + auto meanFinal = isTraining * meanX + (one - isTraining) * mMean; + meanFinal->setName(source->name() + "_BN_mean_final"); + auto varFinal = isTraining * varX + (one - isTraining) * mVar; + varFinal->setName(source->name() + "_BN_variance_final"); + auto stdFinal = _Sqrt(varFinal + eps); + + auto subMean = input - meanFinal; + if (oriInputOrder != NC4HW4) { + if (cond) subMean->setName(source->name() + "_MNN_BN_after_conv_first_op"); + else subMean->setName(source->name() + "_MNN_single_BN_first_op"); + } + auto normed = subMean / stdFinal; + auto res = normed * w + b; + + if (oriInputOrder == NC4HW4) { + res = _Convert(res, oriInputOrder); + } + res->setName(source->name()); + return res->expr().first; + } + if (op->type != OpType_Convolution && op->type != OpType_ConvolutionDepthwise) { return source; } diff --git a/tools/train/source/transformer/OpConverter.hpp b/tools/train/source/transformer/OpConverter.hpp index 819803aff..fafa401db 100644 --- a/tools/train/source/transformer/OpConverter.hpp +++ b/tools/train/source/transformer/OpConverter.hpp @@ -16,7 +16,7 @@ class MNN_PUBLIC OpConverter { public: OpConverter() = default; - static MNN::Express::EXPRP convert(MNN::Express::EXPRP source); + static MNN::Express::EXPRP convert(MNN::Express::EXPRP source, std::map& helpInfo); virtual ~OpConverter() = default; static OpConverter* get(MNN::OpType type); diff --git a/tools/train/source/transformer/Transformer.cpp b/tools/train/source/transformer/Transformer.cpp index 1a28a56b6..ab097f395 100644 --- a/tools/train/source/transformer/Transformer.cpp +++ b/tools/train/source/transformer/Transformer.cpp @@ -9,80 +9,76 @@ #include "Transformer.hpp" #include "OpConverter.hpp" #include "MNN_generated.h" +#include using namespace MNN::Express; namespace MNN { namespace Train { -class TurnTrainable : public Express::Optimizer { -public: - TurnTrainable(Transformer::TrainConfig config) { - mConfig = std::move(config); - } - virtual Cost onMeasure(const std::vector& 
outputs, - std::shared_ptr parameters = nullptr) override { - return Cost(); - } - virtual bool onExecute(const std::vector& outputs, std::shared_ptr p) override { - auto exprs = Variable::getExecuteOrder(outputs); - { - // Turn convolution be trainable convolution - for (auto expr : exprs) { - auto newExpr = OpConverter::convert(expr); - if (newExpr.get() != expr.get()) { - Expr::replace(expr, newExpr); - } +bool TurnTrainable::onExecute(const std::vector& outputs, std::shared_ptr p) { + auto exprs = Variable::getExecuteOrder(outputs); + { + auto isTraining = _Input({}, NCHW, halide_type_of()); + isTraining->setName("is_training"); + trainInfo["is_training"] = isTraining; + isTraining = _Cast(isTraining); + isTraining->setName("is_training_float"); + trainInfo["is_training_float"] = isTraining; + trainInfo["one_float"] = _Scalar(1.0f); + trainInfo["bn_momentum"] = _Scalar(mConfig.extraParams["BatchNorm"]["momentum"]->f); + // Turn convolution be trainable convolution + for (auto expr : exprs) { + auto newExpr = OpConverter::convert(expr, trainInfo); + if (newExpr.get() != expr.get()) { + Expr::replace(expr, newExpr); } } - exprs = Variable::getExecuteOrder(outputs); - auto& noUpdateOps = mConfig.noUpdateOps; - auto& onlyUpdateOps = mConfig.onlyUpdateOps; - // Collect Const Variable and turn to Trainable - for (auto v : exprs) { - if (v->get() == nullptr && VARP::INPUT != v->inputType()) { - auto name = v->name(); - auto info = v->outputInfo(0); - if (halide_type_float != info->type.code) { - continue; - } + } + exprs = Variable::getExecuteOrder(outputs); + auto& noUpdateOps = mConfig.noUpdateOps; + auto& onlyUpdateOps = mConfig.onlyUpdateOps; + // Collect Const Variable and turn to Trainable + for (auto v : exprs) { + if (v->get() == nullptr && VARP::INPUT != v->inputType()) { + auto name = v->name(); + auto info = v->outputInfo(0); + if (halide_type_float != info->type.code) { + continue; + } - bool update; - if (!onlyUpdateOps.empty()) { - update = false; - for 
(auto limit : onlyUpdateOps) { - if (name.find(limit) != std::string::npos) { - update = true; - break; - } - } - } else { - update = true; - for (auto limit : noUpdateOps) { - if (name.find(limit) != std::string::npos) { - update = false; - break; - } + bool update; + if (!onlyUpdateOps.empty()) { + update = false; + for (auto limit : onlyUpdateOps) { + if (name.find(limit) != std::string::npos) { + update = true; + break; } } - - auto va = Variable::create(v, 0); - if (update) { - MNN_PRINT("Add Variable: %s\n", name.c_str()); - va.fix(VARP::TRAINABLE); - if (name.find("Weight") == std::string::npos && name.find("Bias") == std::string::npos) { - MNN_PRINT(">>>\ncheck mnn model if const '%s' is a learnable parameter in your original training model, ", name.c_str()); - MNN_PRINT("if not, add it to transformConfig.json NoUpdateOps\n<<<\n"); + } else { + update = true; + for (auto limit : noUpdateOps) { + if (name.find(limit) != std::string::npos) { + update = false; + break; } - } else { - va.fix(VARP::CONSTANT); } } + + auto va = Variable::create(v, 0); + if (update && name != "") { + MNN_PRINT("Add Variable: %s\n", name.c_str()); + va.fix(VARP::TRAINABLE); + if (name.find("Weight") == std::string::npos && name.find("Bias") == std::string::npos) { + MNN_PRINT(">>>\ncheck mnn model if const '%s' is a learnable parameter in your original training model, ", name.c_str()); + MNN_PRINT("if not, add it to transformConfig.json NoUpdateOps\n<<<\n"); + } + } else { + va.fix(VARP::CONSTANT); + } } - return true; } - -private: - Transformer::TrainConfig mConfig; -}; + return true; +} std::shared_ptr Transformer::turnModelToTrainable(TrainConfig config) { std::shared_ptr res; @@ -90,46 +86,156 @@ std::shared_ptr Transformer::turnModelToTrainable(TrainConfi return res; } -class InferOptimizer : public Express::Optimizer { -public: - InferOptimizer(){} - virtual Cost onMeasure(const std::vector& outputs, std::shared_ptr parameters = nullptr) override { - Cost c; - return c; - }; 
+bool InferOptimizer::onExecute(const std::vector& outputs, std::shared_ptr parameters) { + auto exprs = Variable::getExecuteOrder(outputs); - virtual bool onExecute(const std::vector& outputs, std::shared_ptr parameters = nullptr) override { - auto exprs = Variable::getExecuteOrder(outputs); - for (auto& iter : exprs) { - auto op = iter->get(); - if (nullptr == op) { - continue; - } - if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) { - continue; - } - auto inputExpr = iter->inputs()[0]->expr().first; - if (inputExpr->get() == nullptr) { - continue; - } - if (inputExpr->get()->type() != OpType_FloatToInt8) { - continue; - } - auto subInputExpr = inputExpr->inputs()[0]->expr().first; - if (subInputExpr->get() == nullptr) { - continue; - } - if (subInputExpr->get()->type() != OpType_Int8ToFloat) { - continue; - } - //MNN_PRINT("Find direct\n"); - std::vector newInputs = subInputExpr->inputs(); - auto newExpr = Expr::create(iter->extra(), std::move(newInputs)); - Expr::replace(iter, newExpr); + // convert trainable to const + for (auto& expr : exprs) { + if (expr->inputs().size() == 0 && expr->inputType() == VARP::InputType::TRAINABLE) { + auto newConst = Variable::create(expr); + newConst.fix(VARP::InputType::CONSTANT); + newConst->setName(expr->name()); + auto newExpr = newConst->expr().first; + newExpr->setName(expr->name()); + Expr::replace(expr, newExpr); } - return true; } -}; + + // merge bn after conv into conv + // convert single bn to scale + std::set bnNames; + std::string pattern1 = "_MNN_BN_after_conv_first_op"; + std::string pattern2 = "_MNN_single_BN_first_op"; + for (auto& expr : exprs) { + if (expr->name().find(pattern1) != std::string::npos) { + std::string bnName = expr->name(); + for (int i = 0; i < pattern1.size(); i++) { + bnName.pop_back(); + } + bnNames.insert(bnName); + } + if (expr->name().find(pattern2) != std::string::npos) { + std::string bnName = expr->name(); + for (int i = 0; i < pattern2.size(); i++) { + 
bnName.pop_back(); + } + bnNames.insert(bnName); + } + } + + std::map> bnInfo; + for (auto& name : bnNames) { + for (auto& expr : exprs) { + auto inputs = expr->inputs(); + if (expr->name() == name) { + bnInfo[name]["Self"] = expr; + } + if (inputs.size() == 0 && expr->name() == name + "_BN_RunningMean_Weight") { + bnInfo[name]["RunningMean"] = expr; + } + if (inputs.size() == 0 && expr->name() == name + "_BN_RunningVariance_Weight") { + bnInfo[name]["RunningVariance"] = expr; + } + if (inputs.size() == 0 && expr->name() == name + "_BN_Gamma_Weight") { + bnInfo[name]["Gamma"] = expr; + } + if (inputs.size() == 0 && expr->name() == name + "_BN_Beta_Bias") { + bnInfo[name]["Bias"] = expr; + } + if (inputs.size() == 0 && expr->name() == name + "_BN_Eps_Weight") { + bnInfo[name]["Eps"] = expr; + } + if (expr->name() == name + pattern1) { + bnInfo[name]["FirstOpAfterConv"] = expr; + } + if (expr->name() == name + pattern2) { + bnInfo[name]["FirstOpSingleBN"] = expr; + } + } + } + for (auto& bn : bnInfo) { + auto bnName = bn.first; + auto info = bn.second; + + bool bnAfterConv = false; + if (info.find("FirstOpAfterConv") != info.end()) { + bnAfterConv = true; + } + + auto rm = _Convert(Variable::create(info["RunningMean"]), NCHW); + auto rv = _Convert(Variable::create(info["RunningVariance"]), NCHW); + auto gamma = _Convert(Variable::create(info["Gamma"]), NCHW); + auto bias = _Convert(Variable::create(info["Bias"]), NCHW); + auto eps = Variable::create(info["Eps"]); + + auto s = _Sqrt(rv + eps); + auto alpha = gamma / s; + auto beta = bias - rm / s * gamma; + + if (bnAfterConv) { + auto firstOp = info["FirstOpAfterConv"]; + auto convExpr = firstOp->inputs()[0]->expr().first; + if (convExpr->get() == nullptr || convExpr->get()->type() != OpType_Convolution) { + continue; + } + auto convInput = convExpr->inputs()[0]; + auto w = convExpr->inputs()[1]; + auto b = convExpr->inputs()[2]; + + auto nw = w * _Reshape(alpha, {b->getInfo()->dim[0], 1, 1, 1}); + 
nw.fix(w->expr().first->inputType()); + nw->setName(w->name()); + auto nb = _Reshape(alpha, {b->getInfo()->dim}) * b + _Reshape(beta, b->getInfo()->dim); + nb.fix(b->expr().first->inputType()); + nb->setName(b->name()); + + std::vector newInputs = {convInput, nw, nb}; + auto newConv = Expr::create(convExpr->extra(), std::move(newInputs)); + Expr::replace(info["Self"], newConv); + } else { + auto firstOp = info["FirstOpSingleBN"]; + auto inputs = firstOp->inputs(); + std::vector scales, p; + for (int i = 0; i < beta->getInfo()->size; i++) { + scales.push_back(alpha->readMap()[i]); + p.push_back(beta->readMap()[i]); + } + auto res = _Scale(inputs[0], beta->getInfo()->size, std::move(scales), std::move(p)); + res->setName(info["Self"]->name()); + Expr::replace(info["Self"], res->expr().first); + } + } + + exprs = Variable::getExecuteOrder(outputs); + for (auto& iter : exprs) { + auto op = iter->get(); + if (nullptr == op) { + continue; + } + if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) { + continue; + } + auto inputExpr = iter->inputs()[0]->expr().first; + if (inputExpr->get() == nullptr) { + continue; + } + if (inputExpr->get()->type() != OpType_FloatToInt8) { + continue; + } + auto subInputExpr = inputExpr->inputs()[0]->expr().first; + if (subInputExpr->get() == nullptr) { + continue; + } + if (subInputExpr->get()->type() != OpType_Int8ToFloat) { + continue; + } + //MNN_PRINT("Find direct\n"); + std::vector newInputs = subInputExpr->inputs(); + auto newExpr = Expr::create(iter->extra(), std::move(newInputs)); + Expr::replace(iter, newExpr); + } + return true; +} std::shared_ptr Transformer::turnModelToInfer() { return std::shared_ptr(new InferOptimizer); diff --git a/tools/train/source/transformer/Transformer.hpp b/tools/train/source/transformer/Transformer.hpp index 65b897b64..bef1768ed 100644 --- a/tools/train/source/transformer/Transformer.hpp +++ b/tools/train/source/transformer/Transformer.hpp @@ -9,6 +9,7 @@ #ifndef 
Transformer_hpp #define Transformer_hpp #include +#include namespace MNN { namespace Train { @@ -17,11 +18,40 @@ public: struct TrainConfig { std::vector noUpdateOps; std::vector onlyUpdateOps; + std::map> extraParams; }; static std::shared_ptr turnModelToTrainable(TrainConfig config); static std::shared_ptr turnModelToInfer(); }; + +class TurnTrainable : public Express::Optimizer { +public: + TurnTrainable(Transformer::TrainConfig config) { + mConfig = std::move(config); + } + virtual Cost onMeasure(const std::vector& outputs, + std::shared_ptr parameters = nullptr) override { + return Cost(); + } + virtual bool onExecute(const std::vector& outputs, std::shared_ptr p = nullptr) override; + +public: + std::map trainInfo; + +private: + Transformer::TrainConfig mConfig; +}; + +class InferOptimizer : public Express::Optimizer { +public: + InferOptimizer(){} + virtual Cost onMeasure(const std::vector& outputs, std::shared_ptr parameters = nullptr) override { + Cost c; + return c; + } + virtual bool onExecute(const std::vector& outputs, std::shared_ptr p = nullptr) override; +}; } // namespace Train } // namespace MNN #endif diff --git a/tools/train/transformConfig.json b/tools/train/transformConfig.json index da3ee9ff5..7f9751dc7 100644 --- a/tools/train/transformConfig.json +++ b/tools/train/transformConfig.json @@ -9,6 +9,9 @@ "StopBackPropOps":[], "type": "SGD" }, + "BatchNorm": { + "momentum":0.99 + }, "Debug": { "L2Norm": [] },