mirror of https://github.com/alibaba/MNN.git
[MNN:Sync] A few bugfixes
1. Support ONNX If with an empty subgraph (this occurs when the condition is statically always true or always false)
2. Fix wrong dimension computation in the Where operator under zeroshape
3. Fix Reduce computation over zeroshape inputs for the non-PROD cases
4. Fix a build error on aarch64-linux
5. Fix an incorrect NNAPI comment in the header file
6. Fix several training-related issues
parent 767086816d
commit ad5d243c9f

@@ -31,7 +31,7 @@ typedef enum {
     MNN_FORWARD_OPENGL = 6,
     MNN_FORWARD_VULKAN = 7,
 
-    /*Android 8.1's NNAPI, Not Support yet. CoreML Now*/
+    /*Android 8.1's NNAPI or CoreML for ios*/
     MNN_FORWARD_NN = 5,
 
     /*User can use API from Backend.hpp to add or search Backend*/

@@ -45,10 +45,17 @@ struct BinaryRealDiv {
     }
 };
 
+/**
+ Ref from onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.cc :: Modulus
+ */
 template <typename _Arg1, typename _Arg2, typename _ErrorCode>
 struct BinaryModInt {
     _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
-        return x - (x / y) * y;
+        auto res = x % y;
+        if ((res < 0 && y > 0) || (res > 0 && y < 0)) {
+            res += y;
+        }
+        return (_ErrorCode)res;
     }
 };
 
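Note on the Mod change above: the old expression x - (x / y) * y is truncated
remainder (the result takes the dividend's sign, like C++ '%'), while ONNX
integer Mod with fmod=0 follows flooring semantics, where the result takes the
divisor's sign. A minimal standalone sketch of the new behaviour (the helper
name onnxMod is ours, not MNN's):

    #include <cstdio>

    // Integer mod whose result follows the divisor's sign, mirroring the
    // new BinaryModInt above; C++ '%' follows the dividend's sign instead.
    static int onnxMod(int x, int y) {
        int res = x % y;
        if ((res < 0 && y > 0) || (res > 0 && y < 0)) {
            res += y; // shift the remainder into the divisor's sign range
        }
        return res;
    }

    int main() {
        // Same vectors as the new ModTestInt case; expected: 0 -2 5 0 2 3
        int x[] = {-4, 7, 5, 4, -7, 8};
        int y[] = {2, -3, 8, -2, 3, 5};
        for (int i = 0; i < 6; ++i) {
            printf("%d mod %d = %d\n", x[i], y[i], onnxMod(x[i], y[i]));
        }
        return 0;
    }
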
@@ -54,6 +54,10 @@
 #define CPUINFO_ARM_LINUX_FEATURE_FPHP UINT32_C(0x00000200)
 #define CPUINFO_ARM_LINUX_FEATURE_ASIMDHP UINT32_C(0x00000400)
 #define CPUINFO_ARM_LINUX_FEATURE_ASIMDDP UINT32_C(0x00100000)
+#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000)
+#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000)
+#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002)
+
 #endif /* __linux__ && __aarch64__ */
 
 #ifdef __ANDROID__

@@ -1564,4 +1568,4 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) {
 #endif
 
     MNN_PRINT("The device support dot:%d, support fp16:%d, support i8mm: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm);
-}
+}
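Note on the new cpuinfo masks: the values line up with the Linux aarch64
HWCAP/HWCAP2 auxiliary-vector bits (ASIMDDP is HWCAP bit 20 and SVE bit 22;
I8MM is HWCAP2 bit 13 and SVE2 bit 1). A hedged standalone probe written
against that assumption, not MNN's detection code:

    #include <cstdio>
    #include <sys/auxv.h>

    // Linux/aarch64 only: query the kernel's capability words directly.
    // ASIMDDP and SVE live in AT_HWCAP; I8MM and SVE2 live in AT_HWCAP2.
    int main() {
        unsigned long hwcap  = getauxval(AT_HWCAP);
        unsigned long hwcap2 = getauxval(AT_HWCAP2);
        printf("dot:%d sve:%d i8mm:%d sve2:%d\n",
               (hwcap  & (1ul << 20)) != 0,  // ASIMDDP -> 0x00100000
               (hwcap  & (1ul << 22)) != 0,  // SVE     -> 0x00400000
               (hwcap2 & (1ul << 13)) != 0,  // I8MM    -> 0x00002000
               (hwcap2 & (1ul <<  1)) != 0); // SVE2    -> 0x00000002
        return 0;
    }
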
@@ -19,14 +19,23 @@ public:
         auto reduct = op->main_as_ReductionParam();
         auto reductOp = reduct->operation();
         // prod([]) = 1
-        if (inputs[0]->elementSize() == 0 && reductOp == ReductionType_PROD) {
+        if (inputs[0]->elementSize() == 0) {
             if(!context.allocTensor(outputs[0])) {
                 return false;
             }
+            float res;
+            switch (reductOp) {
+                case ReductionType_PROD:
+                    res = 1.0f;
+                    break;
+                default:
+                    res = 0.0f;
+                    break;
+            }
             if (outputs[0]->getType() == halide_type_of<float>()) {
-                outputs[0]->host<float>()[0] = 1.f;
+                outputs[0]->host<float>()[0] = (float)res;
             } else {
-                outputs[0]->host<int>()[0] = 1;
+                outputs[0]->host<int>()[0] = (int)res;
             }
             return true;
         }
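Note on the Reduce fix above: reducing an empty tensor should return the
reduction's identity element, 1 for PROD; the new switch falls back to 0 for
the other reduction types this computer handles (e.g. sum([]) = 0). A tiny
sketch of that rule:

    #include <cstdio>

    // Identity value written for a reduction over zero elements,
    // mirroring the new switch above: 1 for PROD, 0 otherwise.
    static float reduceIdentity(bool isProd) {
        return isProd ? 1.0f : 0.0f;
    }

    int main() {
        printf("prod([]) = %.1f, sum([]) = %.1f\n",
               reduceIdentity(true), reduceIdentity(false));
        return 0;
    }
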
@@ -41,7 +41,6 @@ class WhereSizeComputer : public SizeComputer {
         }
         // For zeroshape input
         if (nullptr == inputs[0]->host<void>()) {
-            ob.dimensions = 1;
             ob.dim[0].extent = 0;
             return true;
         }
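Note on the Where fix: in index form, Where/NonZero yields one row of indices
per non-zero element, so an empty input should produce zero rows while keeping
the output's column dimension; forcing the output down to 1-D is what the
removed line did. A plain C++ sketch of the intended shape (our reading of the
trimmed context, not MNN API):

    #include <cstdio>
    #include <vector>

    // Where in index form maps an input of rank R with N non-zero
    // elements to an {N, R} tensor; an empty input keeps rank 2 with N = 0.
    static std::vector<int> whereShape(int nonZeroCount, int inputRank) {
        return {nonZeroCount, inputRank};
    }

    int main() {
        auto s = whereShape(0, 3); // empty input of rank 3
        printf("output shape: {%d, %d}\n", s[0], s[1]); // {0, 3}, still 2-D
        return 0;
    }
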
@@ -226,6 +226,9 @@ void setInputOutputForOps(std::vector<std::shared_ptr<Tensor>>& allTensors, cons
         if (des->usage == Tensor::InsideDescribe::CONSTANT) {
             continue;
         }
+        if (des->usage == Tensor::InsideDescribe::TRAINABLE) {
+            continue;
+        }
         des->usage = Tensor::InsideDescribe::INPUT;
     }
     for (auto index : output) {

@@ -185,9 +185,26 @@ public:
                               {4, 2}, {2}, {4, 2});
     }
 };
-class ModTest : public BinaryTestCommon {
+class ModTestInt : public BinaryTestCommon {
 public:
-    virtual ~ModTest() = default;
+    virtual ~ModTestInt() = default;
     virtual bool run(int precision) {
+        std::vector<int> x = {
+            -4, 7, 5, 4, -7, 8
+        };
+        std::vector<int> y = {
+            2, -3, 8, -2, 3, 5
+        };
+        std::vector<int> z = {
+            0, -2, 5, 0, 2, 3
+        };
+        return test<int, int>(_Mod, "ModTestFloat", 0,
+                              x,y,z, {6}, {6}, {6});
+    }
+};
+class ModTestFloat : public BinaryTestCommon {
+public:
+    virtual ~ModTestFloat() = default;
+    virtual bool run(int precision) {
         std::vector<float> x = {
             1.1f, 2.3f, 3.5f, 4.7f, 5.9f, 6.2f, 7.4f, 8.6f

@@ -201,7 +218,7 @@ public:
                 z[i + j * 2] = FP32Converter[precision](fmodf(FP32Converter[precision](x[i+j*2]), FP32Converter[precision](y[i])));
             }
         }
-        return test<float, float>(_Mod, "ModTest", 0,
+        return test<float, float>(_Mod, "ModTestFloat", 0,
                               x,y,z,
                               {4, 2}, {2}, {4, 2});
     }

@@ -368,7 +385,8 @@ MNNTestSuiteRegister(SquaredDifferenceTest, "op/binary/squareddifference");
 MNNTestSuiteRegister(EqualTest, "op/binary/equal");
 MNNTestSuiteRegister(LessEqualTest, "op/binary/lessequal");
 MNNTestSuiteRegister(FloorModTest, "op/binary/floormod");
-MNNTestSuiteRegister(ModTest, "op/binary/mod");
+MNNTestSuiteRegister(ModTestFloat, "op/binary/mod_float");
+MNNTestSuiteRegister(ModTestInt, "op/binary/mod_int");
 MNNTestSuiteRegister(Atan2Test, "op/binary/atan2");
 MNNTestSuiteRegister(LogicalOrTest, "op/binary/logicalor");
 MNNTestSuiteRegister(NotEqualTest, "op/binary/notqual");

@@ -21,15 +21,14 @@ MNN::OpParameter IfOnnx::type() {
 void IfOnnx::run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode,
                  OnnxScope* scope) {
     auto param = new MNN::IfParamT;
-    dstOp->name += "/If";
-    param->then_graph = dstOp->name + "/then";
-    param->else_graph = dstOp->name + "/else";
     const ::onnx::GraphProto *thenG = nullptr, *elseG = nullptr;
     for (const auto& attr : onnxNode->attribute()) {
         if (attr.name() == "then_branch") {
             thenG = &attr.g();
+            param->then_graph = thenG->name();
         } else if (attr.name() == "else_branch") {
             elseG = &attr.g();
+            param->else_graph = elseG->name();
         }
     }
     if (thenG == nullptr || elseG == nullptr) {

@@ -50,8 +49,21 @@ void IfOnnx::run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode,
         std::transform(graph->output().begin(), graph->output().end(), outputs.begin(), [](const ::onnx::ValueInfoProto& p) { return p.name(); });
         return std::make_pair(inputs, outputs);
     };
-    auto thenInOuts = dealWithSubGraph(thenG, param->then_graph);
-    auto elseInOuts = dealWithSubGraph(elseG, param->else_graph);
+    std::pair<std::vector<std::string>, std::vector<std::string>> thenInOuts, elseInOuts;
+    MNN_ASSERT(thenG->node_size() > 0 || elseG->node_size() > 0);
+    if (thenG->node_size() > 0) {
+        thenInOuts = dealWithSubGraph(thenG, param->then_graph);
+    }
+    if (elseG->node_size() > 0) {
+        elseInOuts = dealWithSubGraph(elseG, param->else_graph);
+    }
+    if (thenG->node_size() == 0) {
+        thenInOuts = elseInOuts;
+        param->then_graph = param->else_graph;
+    } else if (elseG->node_size() == 0) {
+        elseInOuts = thenInOuts;
+        param->else_graph = param->then_graph;
+    }
     auto thenInputs = thenInOuts.first, thenOutputs = thenInOuts.second;
     auto elseInputs = elseInOuts.first, elseOutputs = elseInOuts.second;
 
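Note on the If fix: when one branch of an ONNX If has no nodes (the condition
is statically always true or always false), the converter now reuses the
non-empty branch for both sides instead of failing. A standalone sketch of
that fallback (types and names are illustrative, not MNN API):

    #include <string>
    #include <utility>
    #include <vector>

    using InOuts = std::pair<std::vector<std::string>, std::vector<std::string>>;

    // Mirror of the fallback above: an empty branch borrows the other
    // branch's inputs/outputs (the real code also copies the graph name).
    // The MNN_ASSERT in the diff guarantees at least one branch is non-empty.
    static InOuts resolveBranch(int ownNodeCount, const InOuts& own, const InOuts& other) {
        return ownNodeCount > 0 ? own : other;
    }

    int main() {
        InOuts thenIO{{"x"}, {"y_then"}};
        InOuts elseIO{}; // empty else branch: condition is always true
        auto resolvedElse = resolveBranch(0, elseIO, thenIO);
        return resolvedElse.second == thenIO.second ? 0 : 1;
    }
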
@@ -7,6 +7,7 @@ add_executable(SequenceModuleTest.out ${CMAKE_CURRENT_LIST_DIR}/SequenceModuleTe
 list(APPEND MNN_CPP_TOOLS SequenceModuleTest.out)
 
 add_executable(mergeInplaceForCPU ${CMAKE_CURRENT_LIST_DIR}/mergeInplaceForCPU.cpp)
+list(APPEND MNN_CPP_TOOLS mergeInplaceForCPU)
 
 add_executable(MNNV2Basic.out ${CMAKE_CURRENT_LIST_DIR}/MNNV2Basic.cpp)
 list(APPEND MNN_CPP_TOOLS MNNV2Basic.out)

@@ -37,7 +37,7 @@ target_compile_definitions(MNNTrain PRIVATE STB_IMAGE_STATIC STB_IMAGE_IMPLEMENT
 
 # executables
 set(MNN_TRAIN_TOOLS "")
-add_executable(transformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/transformerExecution.cpp)
+add_executable(transformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/transformerExecution.cpp ${TRANSFORMER})
 add_executable(train.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/train.cpp ${SCHEMA} ${BASIC_INCLUDE})
 add_executable(rawDataTransform.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/rawDataTransform.cpp ${SCHEMA} ${BASIC_INCLUDE})
 add_executable(dataTransformer.out ${CMAKE_CURRENT_LIST_DIR}/source/exec/dataTransformer.cpp ${SCHEMA} ${BASIC_INCLUDE})

@@ -22,6 +22,7 @@
 #define MNN_OPEN_TIME_TRACE
 #include <MNN/AutoTime.hpp>
 #include "rapidjson/document.h"
+#include <algorithm>
 
 using namespace MNN;
 using namespace MNN::Express;

@@ -82,6 +83,14 @@ int main(int argc, const char* argv[]) {
             MNN_PRINT("optimizer type: %s\n", optimizerType.c_str());
         }
     }
+    auto bnMomentum = new MNN::AttributeT;
+    bnMomentum->f = 0.99;
+    if (configObject.HasMember("BatchNorm")) {
+        auto bnConfig = configObject["BatchNorm"].GetObject();
+        if (bnConfig.HasMember("momentum")) {
+            bnMomentum->f = bnConfig["momentum"].GetFloat();
+        }
+    }
     const char* inputModeFileName = argv[1];
     FUNC_PRINT_ALL(inputModeFileName, s);
     std::map<std::string, VARP> inputVars;

@@ -94,7 +103,10 @@ int main(int argc, const char* argv[]) {
     Transformer::TrainConfig trainConfig;
     trainConfig.noUpdateOps = std::move(noUpdateOps);
     trainConfig.onlyUpdateOps = std::move(onlyUpdateOps);
-    Transformer::turnModelToTrainable(trainConfig)->onExecute(Variable::mapToSequence(outputVars));
+    trainConfig.extraParams["BatchNorm"]["momentum"] = bnMomentum;
+    auto turnTrainable = Train::TurnTrainable(trainConfig);
+    turnTrainable.onExecute(Variable::mapToSequence(outputVars));
+    auto trainInfo = turnTrainable.trainInfo;
     if (configObject.HasMember("Shape")) {
         auto shapeArray = configObject["Shape"].GetObject();
         for (auto shapeIter = shapeArray.begin(); shapeIter != shapeArray.end(); shapeIter++) {

@@ -189,10 +201,17 @@ int main(int argc, const char* argv[]) {
     stepPlus1->setName("optimize_step+1");
     varUpdateMap[step] = stepPlus1;
 
+    std::map<std::string, std::string> extraInputs;
+    extraInputs["LearningRate"] = "float";
+    extraInputs["WeightDecay"] = "float";
+    if (trainInfo["is_training_float"]->linkNumber() > 0) {
+        extraInputs["is_training"] = "int, 0 or 1";
+    }
+
     if (optimizerType == "SGD") {
         auto momentum = _Input();
         momentum->setName("Momentum");
-        MNN_PRINT(">>>\nextra input tensors for SGD: LearningRate, WeightDecay, Momentum\n<<<\n");
+        extraInputs["Momentum"] = "float";
 
         for (auto iter : gradMap) {
             auto p = iter.first;

@@ -200,6 +219,18 @@ int main(int argc, const char* argv[]) {
             auto grad = iter.second;
             grad->setName(p->name()+"_grad");
 
+            if (p->name().find("_BN_RunningMean_Weight") != string::npos) {
+                varUpdateMap[p] = trainInfo[p->name()];
+                continue; // not update running stats
+            }
+            if (p->name().find("_BN_RunningVariance_Weight") != string::npos) {
+                varUpdateMap[p] = trainInfo[p->name()];
+                continue; // not update running stats
+            }
+            if (p->name().find("_BN_Eps_Weight") != string::npos) {
+                continue; // not update eps
+            }
+
             auto pInfo = p->getInfo();
             auto pDims = pInfo->dim;
 
@@ -216,10 +247,14 @@ int main(int argc, const char* argv[]) {
             VARP history = _Const(0.0f, pDims, pInfo->order);
             history->setName(p->name() + "_momentum");
             history.fix(VARP::TRAINABLE);
-            auto newHistory = _Multiply(learningRate, gradWithDecay) + momentum * history;
+
+            auto newHistory = gradWithDecay + momentum * history;
             newHistory->setName("update_" + history->name());
 
-            auto updateValue = _Subtract(p, history);
+            auto finalGrad = learningRate * history;
+            finalGrad->setName(p->name() + "_final_grad");
+
+            auto updateValue = _Subtract(p, finalGrad);
             updateValue->setName("update_" + p->name());
             varUpdateMap[p] = updateValue;
             varUpdateMap[history] = newHistory;

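Note on the SGD update rework above: the momentum buffer previously stored
lr*grad + momentum*history, baking the learning rate into the state; it now
stores gradWithDecay + momentum*history, and the learning rate is applied only
in the final step lr*history that is subtracted from the parameter. With a
constant learning rate the two recurrences coincide; they diverge as soon as
the LearningRate input changes between steps. A scalar sketch:

    #include <cstdio>

    int main() {
        const float m = 0.9f, g = 0.5f; // toy momentum and constant gradient
        float lr = 0.1f;
        float pOld = 1.0f, vOld = 0.0f; // old rule: v carries lr
        float pNew = 1.0f, vNew = 0.0f; // new rule: lr applied at the end
        for (int i = 0; i < 4; ++i) {
            if (i == 2) lr = 0.01f;                        // lr schedule kicks in
            vOld = lr * g + m * vOld;  pOld -= vOld;       // stale lr lingers in vOld
            vNew = g + m * vNew;       pNew -= lr * vNew;  // whole velocity rescaled
            printf("step %d: old p=%.5f, new p=%.5f\n", i, pOld, pNew);
        }
        return 0;
    }
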
@@ -231,7 +266,10 @@ int main(int argc, const char* argv[]) {
         beta2->setName("Beta2");
         auto eps = _Input();
         eps->setName("Eps");
-        MNN_PRINT(">>>\nextra input tensors for ADAM: LearningRate, WeightDecay, Beta1, Beta2, Eps\n<<<\n");
+
+        extraInputs["Beta1"] = "float";
+        extraInputs["Beta2"] = "float";
+        extraInputs["Eps"] = "float";
 
         auto correction = _Sqrt(_Const(1.0f, {}, NCHW) - _Pow(beta2, step)) / (_Const(1.0f, {}, NCHW) - _Pow(beta1, step));
         correction->setName("correction");

@@ -242,6 +280,18 @@ int main(int argc, const char* argv[]) {
             auto grad = iter.second;
             grad->setName(p->name()+"_grad");
 
+            if (p->name().find("_BN_RunningMean_Weight") != string::npos) {
+                varUpdateMap[p] = trainInfo[p->name()];
+                continue; // not update running stats
+            }
+            if (p->name().find("_BN_RunningVariance_Weight") != string::npos) {
+                varUpdateMap[p] = trainInfo[p->name()];
+                continue; // not update running stats
+            }
+            if (p->name().find("_BN_Eps_Weight") != string::npos) {
+                continue; // not update eps
+            }
+
             auto pInfo = p->getInfo();
             auto pDims = pInfo->dim;
 
@@ -280,6 +330,12 @@ int main(int argc, const char* argv[]) {
         MNN_ERROR("error: don't support optimizer type: %s\n", optimizerType.c_str());
     }
 
+    MNN_PRINT(">>>\nextra input tensors for %s:\n\n", optimizerType.c_str());
+    for (auto& input : extraInputs) {
+        MNN_PRINT("name: %s, \ttype: %s\n", input.first.c_str(), input.second.c_str());
+    }
+    MNN_PRINT("<<<\n");
+
     std::unique_ptr<MNN::NetT> netStruct(new MNN::NetT);
     std::vector<VARP> resultOutputs;
     for (auto output : outputVars) {

@@ -298,7 +354,21 @@ int main(int argc, const char* argv[]) {
         for (int j = 0; j < netStruct->oplists.size(); ++j) {
             auto& opSub = netStruct->oplists[j];
             if (opSub->name == iter.first->name()) {
+                auto indexOri = op->outputIndexes;
                 op->outputIndexes = opSub->outputIndexes;
+
+                if ((opSub->name.find("_BN_RunningMean_Weight") != string::npos) || (opSub->name.find("_BN_RunningVariance_Weight") != string::npos)) {
+                    for (int k = 0; k < netStruct->oplists.size(); ++k) {
+                        auto& opSubSub = netStruct->oplists[k];
+                        if (opSubSub->inputIndexes.size() > 0) {
+                            for (int kk = 0; kk < opSubSub->inputIndexes.size(); kk++) {
+                                if (opSubSub->inputIndexes[kk] == indexOri[0]) {
+                                    opSubSub->inputIndexes[kk] = opSub->outputIndexes[0];
+                                }
+                            }
+                        }
+                    }
+                }
             }
         }
     }

@@ -70,7 +70,7 @@ Express::VARP SGD::regularizeParameters(Express::VARP param, Express::VARP grad)
 
 Express::VARP SGD::onComputeUpdateValue(Express::VARP param, Express::VARP grad) {
     auto lr = _Const(mLearningRate, {}, NCHW);
-    mHistory[param] = lr * grad + _Const(mMomentum, {}, NCHW) * mHistory[param];
+    mHistory[param] = lr * (grad + _Const(mMomentum, {}, NCHW) * mHistory[param]);
     mHistory[param].fix(Express::VARP::CONSTANT);
     //FUNC_PRINT_ALL(_ReduceMax(grad)->readMap<float>()[0], f);
     return mHistory[param];
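Note on the SGD.cpp change: moving the parenthesis means the momentum term is
now scaled by the learning rate as well, so the history effectively decays by
lr*momentum per step rather than by momentum alone. A one-step scalar check:

    #include <cstdio>

    int main() {
        const float lr = 0.1f, m = 0.9f, g = 0.5f;
        float vOld = 0.2f, vNew = 0.2f;   // same starting history
        vOld = lr * g + m * vOld;         // old: v = lr*g + m*v   -> 0.2300
        vNew = lr * (g + m * vNew);       // new: v = lr*(g + m*v) -> 0.0680
        printf("old v=%.4f, new v=%.4f\n", vOld, vNew);
        return 0;
    }
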
@@ -31,12 +31,88 @@ void OpConverter::insert(MNN::OpType type, OpConverter* converter) {
     converterMap.insert(std::make_pair(type, converter));
 }
 
-EXPRP OpConverter::convert(EXPRP source) {
+EXPRP OpConverter::convert(EXPRP source, std::map<std::string, MNN::Express::VARP>& helpInfo) {
     auto opOrigin = source->get();
     if (nullptr == opOrigin) {
         return source;
     }
     std::unique_ptr<MNN::OpT> op(opOrigin->UnPack());
+    if (op->type == OpType_BatchNorm) {
+        printf("transform batchnorm: %s\n", source->name().c_str());
+
+        auto params = op->main.AsBatchNorm();
+        auto channels = params->channels;
+
+        auto input = source->inputs()[0];
+        if (input->getInfo()->dim.size() != 4) {
+            printf("only support BatchNorm with 4-D input\n");
+            return nullptr;
+        }
+        auto preExpr = input->expr().first;
+        bool cond = (preExpr->get() != nullptr) && (preExpr->get()->type() == OpType_Convolution) && (preExpr->inputs().size() == 3);
+        auto oriInputOrder = input->getInfo()->order;
+        if (oriInputOrder == NC4HW4) {
+            input = _Convert(input, NCHW);
+            if (cond) input->setName(source->name() + "_MNN_BN_after_conv_first_op");
+            else input->setName(source->name() + "_MNN_single_BN_first_op");
+        }
+        auto inputOrder = input->getInfo()->order;
+
+        std::vector<int> reduceDims = {0, 2, 3};
+        std::vector<int> statShape = {1, channels, 1, 1};
+        if (inputOrder == NHWC) {
+            reduceDims = {0, 1, 2};
+            statShape = {1, 1, 1, channels};
+        }
+
+        auto rMean = _Const((void*)params->meanData.data(), statShape, inputOrder);
+        rMean->setName(source->name() + "_BN_RunningMean_Weight");
+        auto rVar = _Const((void*)params->varData.data(), statShape, inputOrder);
+        rVar->setName(source->name() + "_BN_RunningVariance_Weight");
+        auto w = _Const((void*)params->slopeData.data(), statShape, inputOrder);
+        w->setName(source->name() + "_BN_Gamma_Weight");
+        auto b = _Const((void*)params->biasData.data(), statShape, inputOrder);
+        b->setName(source->name() + "_BN_Beta_Bias");
+        auto eps = _Scalar<float>(params->epsilon);
+        eps->setName(source->name() + "_BN_Eps_Weight");
+
+        auto meanX = _ReduceMean(input, reduceDims, true);
+        meanX->setName(source->name() + "_BN_xmean");
+        auto varX = _ReduceMean(_Square(input - meanX), reduceDims, true);
+        varX->setName(source->name() + "_BN_xvariance");
+
+        auto isTraining = helpInfo["is_training_float"];
+        auto one = helpInfo["one_float"];
+        auto momentum = helpInfo["bn_momentum"] * isTraining + (one - isTraining) * one;
+
+        auto mMean = momentum * rMean + (one - momentum) * meanX;
+        mMean->setName(source->name() + "_BN_momentum_mean");
+        helpInfo[rMean->name()] = mMean;
+        auto mVar = momentum * rVar + (one - momentum) * varX;
+        mVar->setName(source->name() + "_BN_momentum_variance");
+        helpInfo[rVar->name()] = mVar;
+
+        auto meanFinal = isTraining * meanX + (one - isTraining) * mMean;
+        meanFinal->setName(source->name() + "_BN_mean_final");
+        auto varFinal = isTraining * varX + (one - isTraining) * mVar;
+        varFinal->setName(source->name() + "_BN_variance_final");
+        auto stdFinal = _Sqrt(varFinal + eps);
+
+        auto subMean = input - meanFinal;
+        if (oriInputOrder != NC4HW4) {
+            if (cond) subMean->setName(source->name() + "_MNN_BN_after_conv_first_op");
+            else subMean->setName(source->name() + "_MNN_single_BN_first_op");
+        }
+        auto normed = subMean / stdFinal;
+        auto res = normed * w + b;
+
+        if (oriInputOrder == NC4HW4) {
+            res = _Convert(res, oriInputOrder);
+        }
+        res->setName(source->name());
+        return res->expr().first;
+    }
+
     if (op->type != OpType_Convolution && op->type != OpType_ConvolutionDepthwise) {
         return source;
     }
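Note on the BatchNorm rewrite above: is_training_float is a 0/1 tensor, so
momentum*isTraining + (one - isTraining)*one evaluates to the configured
momentum while training and to exactly 1 at inference, which freezes the
running statistics; meanFinal/varFinal likewise select batch statistics while
training and running statistics at inference. A scalar sketch of that
branchless select:

    #include <cstdio>

    int main() {
        const float bnMomentum = 0.99f;
        const float modes[2] = {1.0f, 0.0f}; // training, then inference
        for (int i = 0; i < 2; ++i) {
            float isTraining = modes[i];
            float momentum = bnMomentum * isTraining + (1.0f - isTraining) * 1.0f;
            float running = 0.5f, batch = 0.8f;
            // momentum == 1 makes this a no-op: running stats stay frozen
            float updated = momentum * running + (1.0f - momentum) * batch;
            // training uses batch stats, inference the (frozen) running stats
            float used = isTraining * batch + (1.0f - isTraining) * updated;
            printf("isTraining=%.0f: momentum=%.2f updated=%.4f used=%.4f\n",
                   isTraining, momentum, updated, used);
        }
        return 0;
    }
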
@@ -16,7 +16,7 @@ class MNN_PUBLIC OpConverter {
 public:
     OpConverter() = default;
 
-    static MNN::Express::EXPRP convert(MNN::Express::EXPRP source);
+    static MNN::Express::EXPRP convert(MNN::Express::EXPRP source, std::map<std::string, MNN::Express::VARP>& helpInfo);
 
     virtual ~OpConverter() = default;
     static OpConverter* get(MNN::OpType type);

@@ -9,80 +9,76 @@
 #include "Transformer.hpp"
 #include "OpConverter.hpp"
 #include "MNN_generated.h"
 #include <MNN/expr/ExprCreator.hpp>
 using namespace MNN::Express;
 namespace MNN {
 namespace Train {
 
-class TurnTrainable : public Express::Optimizer {
-public:
-    TurnTrainable(Transformer::TrainConfig config) {
-        mConfig = std::move(config);
-    }
-    virtual Cost onMeasure(const std::vector<VARP>& outputs,
-                           std::shared_ptr<Parameters> parameters = nullptr) override {
-        return Cost();
-    }
-    virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> p) override {
-        auto exprs = Variable::getExecuteOrder(outputs);
-        {
-            // Turn convolution be trainable convolution
-            for (auto expr : exprs) {
-                auto newExpr = OpConverter::convert(expr);
-                if (newExpr.get() != expr.get()) {
-                    Expr::replace(expr, newExpr);
-                }
-            }
-        }
-        exprs = Variable::getExecuteOrder(outputs);
-        auto& noUpdateOps = mConfig.noUpdateOps;
-        auto& onlyUpdateOps = mConfig.onlyUpdateOps;
-        // Collect Const Variable and turn to Trainable
-        for (auto v : exprs) {
-            if (v->get() == nullptr && VARP::INPUT != v->inputType()) {
-                auto name = v->name();
-                auto info = v->outputInfo(0);
-                if (halide_type_float != info->type.code) {
-                    continue;
-                }
-
-                bool update;
-                if (!onlyUpdateOps.empty()) {
-                    update = false;
-                    for (auto limit : onlyUpdateOps) {
-                        if (name.find(limit) != std::string::npos) {
-                            update = true;
-                            break;
-                        }
-                    }
-                } else {
-                    update = true;
-                    for (auto limit : noUpdateOps) {
-                        if (name.find(limit) != std::string::npos) {
-                            update = false;
-                            break;
-                        }
-                    }
-                }
-
-                auto va = Variable::create(v, 0);
-                if (update) {
-                    MNN_PRINT("Add Variable: %s\n", name.c_str());
-                    va.fix(VARP::TRAINABLE);
-                    if (name.find("Weight") == std::string::npos && name.find("Bias") == std::string::npos) {
-                        MNN_PRINT(">>>\ncheck mnn model if const '%s' is a learnable parameter in your original training model, ", name.c_str());
-                        MNN_PRINT("if not, add it to transformConfig.json NoUpdateOps\n<<<\n");
-                    }
-                } else {
-                    va.fix(VARP::CONSTANT);
-                }
-            }
-        }
-        return true;
-    }
-
-private:
-    Transformer::TrainConfig mConfig;
-};
+bool TurnTrainable::onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> p) {
+    auto exprs = Variable::getExecuteOrder(outputs);
+    {
+        auto isTraining = _Input({}, NCHW, halide_type_of<int>());
+        isTraining->setName("is_training");
+        trainInfo["is_training"] = isTraining;
+        isTraining = _Cast<float>(isTraining);
+        isTraining->setName("is_training_float");
+        trainInfo["is_training_float"] = isTraining;
+        trainInfo["one_float"] = _Scalar<float>(1.0f);
+        trainInfo["bn_momentum"] = _Scalar<float>(mConfig.extraParams["BatchNorm"]["momentum"]->f);
+        // Turn convolution be trainable convolution
+        for (auto expr : exprs) {
+            auto newExpr = OpConverter::convert(expr, trainInfo);
+            if (newExpr.get() != expr.get()) {
+                Expr::replace(expr, newExpr);
+            }
+        }
+    }
+    exprs = Variable::getExecuteOrder(outputs);
+    auto& noUpdateOps = mConfig.noUpdateOps;
+    auto& onlyUpdateOps = mConfig.onlyUpdateOps;
+    // Collect Const Variable and turn to Trainable
+    for (auto v : exprs) {
+        if (v->get() == nullptr && VARP::INPUT != v->inputType()) {
+            auto name = v->name();
+            auto info = v->outputInfo(0);
+            if (halide_type_float != info->type.code) {
+                continue;
+            }
+
+            bool update;
+            if (!onlyUpdateOps.empty()) {
+                update = false;
+                for (auto limit : onlyUpdateOps) {
+                    if (name.find(limit) != std::string::npos) {
+                        update = true;
+                        break;
+                    }
+                }
+            } else {
+                update = true;
+                for (auto limit : noUpdateOps) {
+                    if (name.find(limit) != std::string::npos) {
+                        update = false;
+                        break;
+                    }
+                }
+            }
+
+            auto va = Variable::create(v, 0);
+            if (update && name != "") {
+                MNN_PRINT("Add Variable: %s\n", name.c_str());
+                va.fix(VARP::TRAINABLE);
+                if (name.find("Weight") == std::string::npos && name.find("Bias") == std::string::npos) {
+                    MNN_PRINT(">>>\ncheck mnn model if const '%s' is a learnable parameter in your original training model, ", name.c_str());
+                    MNN_PRINT("if not, add it to transformConfig.json NoUpdateOps\n<<<\n");
+                }
+            } else {
+                va.fix(VARP::CONSTANT);
+            }
+        }
+    }
+    return true;
+}
 
 std::shared_ptr<Express::Optimizer> Transformer::turnModelToTrainable(TrainConfig config) {
     std::shared_ptr<Express::Optimizer> res;

@@ -90,46 +86,156 @@ std::shared_ptr<Express::Optimizer> Transformer::turnModelToTrainable(TrainConfi
     return res;
 }
 
-class InferOptimizer : public Express::Optimizer {
-public:
-    InferOptimizer(){}
-    virtual Cost onMeasure(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
-        Cost c;
-        return c;
-    };
-
-    virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
-        auto exprs = Variable::getExecuteOrder(outputs);
-        for (auto& iter : exprs) {
-            auto op = iter->get();
-            if (nullptr == op) {
-                continue;
-            }
-            if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) {
-                continue;
-            }
-            auto inputExpr = iter->inputs()[0]->expr().first;
-            if (inputExpr->get() == nullptr) {
-                continue;
-            }
-            if (inputExpr->get()->type() != OpType_FloatToInt8) {
-                continue;
-            }
-            auto subInputExpr = inputExpr->inputs()[0]->expr().first;
-            if (subInputExpr->get() == nullptr) {
-                continue;
-            }
-            if (subInputExpr->get()->type() != OpType_Int8ToFloat) {
-                continue;
-            }
-            //MNN_PRINT("Find direct\n");
-            std::vector<VARP> newInputs = subInputExpr->inputs();
-            auto newExpr = Expr::create(iter->extra(), std::move(newInputs));
-            Expr::replace(iter, newExpr);
-        }
-        return true;
-    }
-};
+bool InferOptimizer::onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters) {
+    auto exprs = Variable::getExecuteOrder(outputs);
+
+    // convert trainable to const
+    for (auto& expr : exprs) {
+        if (expr->inputs().size() == 0 && expr->inputType() == VARP::InputType::TRAINABLE) {
+            auto newConst = Variable::create(expr);
+            newConst.fix(VARP::InputType::CONSTANT);
+            newConst->setName(expr->name());
+            auto newExpr = newConst->expr().first;
+            newExpr->setName(expr->name());
+            Expr::replace(expr, newExpr);
+        }
+    }
+
+    // merge bn after conv into conv
+    // convert single bn to scale
+    std::set<std::string> bnNames;
+    std::string pattern1 = "_MNN_BN_after_conv_first_op";
+    std::string pattern2 = "_MNN_single_BN_first_op";
+    for (auto& expr : exprs) {
+        if (expr->name().find(pattern1) != std::string::npos) {
+            std::string bnName = expr->name();
+            for (int i = 0; i < pattern1.size(); i++) {
+                bnName.pop_back();
+            }
+            bnNames.insert(bnName);
+        }
+        if (expr->name().find(pattern2) != std::string::npos) {
+            std::string bnName = expr->name();
+            for (int i = 0; i < pattern2.size(); i++) {
+                bnName.pop_back();
+            }
+            bnNames.insert(bnName);
+        }
+    }
+
+    std::map<std::string, std::map<std::string, EXPRP>> bnInfo;
+    for (auto& name : bnNames) {
+        for (auto& expr : exprs) {
+            auto inputs = expr->inputs();
+            if (expr->name() == name) {
+                bnInfo[name]["Self"] = expr;
+            }
+            if (inputs.size() == 0 && expr->name() == name + "_BN_RunningMean_Weight") {
+                bnInfo[name]["RunningMean"] = expr;
+            }
+            if (inputs.size() == 0 && expr->name() == name + "_BN_RunningVariance_Weight") {
+                bnInfo[name]["RunningVariance"] = expr;
+            }
+            if (inputs.size() == 0 && expr->name() == name + "_BN_Gamma_Weight") {
+                bnInfo[name]["Gamma"] = expr;
+            }
+            if (inputs.size() == 0 && expr->name() == name + "_BN_Beta_Bias") {
+                bnInfo[name]["Bias"] = expr;
+            }
+            if (inputs.size() == 0 && expr->name() == name + "_BN_Eps_Weight") {
+                bnInfo[name]["Eps"] = expr;
+            }
+            if (expr->name() == name + pattern1) {
+                bnInfo[name]["FirstOpAfterConv"] = expr;
+            }
+            if (expr->name() == name + pattern2) {
+                bnInfo[name]["FirstOpSingleBN"] = expr;
+            }
+        }
+    }
+    for (auto& bn : bnInfo) {
+        auto bnName = bn.first;
+        auto info = bn.second;
+
+        bool bnAfterConv = false;
+        if (info.find("FirstOpAfterConv") != info.end()) {
+            bnAfterConv = true;
+        }
+
+        auto rm = _Convert(Variable::create(info["RunningMean"]), NCHW);
+        auto rv = _Convert(Variable::create(info["RunningVariance"]), NCHW);
+        auto gamma = _Convert(Variable::create(info["Gamma"]), NCHW);
+        auto bias = _Convert(Variable::create(info["Bias"]), NCHW);
+        auto eps = Variable::create(info["Eps"]);
+
+        auto s = _Sqrt(rv + eps);
+        auto alpha = gamma / s;
+        auto beta = bias - rm / s * gamma;
+
+        if (bnAfterConv) {
+            auto firstOp = info["FirstOpAfterConv"];
+            auto convExpr = firstOp->inputs()[0]->expr().first;
+            if (convExpr->get() == nullptr || convExpr->get()->type() != OpType_Convolution) {
+                continue;
+            }
+            auto convInput = convExpr->inputs()[0];
+            auto w = convExpr->inputs()[1];
+            auto b = convExpr->inputs()[2];
+
+            auto nw = w * _Reshape(alpha, {b->getInfo()->dim[0], 1, 1, 1});
+            nw.fix(w->expr().first->inputType());
+            nw->setName(w->name());
+            auto nb = _Reshape(alpha, {b->getInfo()->dim}) * b + _Reshape(beta, b->getInfo()->dim);
+            nb.fix(b->expr().first->inputType());
+            nb->setName(b->name());
+
+            std::vector<VARP> newInputs = {convInput, nw, nb};
+            auto newConv = Expr::create(convExpr->extra(), std::move(newInputs));
+            Expr::replace(info["Self"], newConv);
+        } else {
+            auto firstOp = info["FirstOpSingleBN"];
+            auto inputs = firstOp->inputs();
+            std::vector<float> scales, p;
+            for (int i = 0; i < beta->getInfo()->size; i++) {
+                scales.push_back(alpha->readMap<float>()[i]);
+                p.push_back(beta->readMap<float>()[i]);
+            }
+            auto res = _Scale(inputs[0], beta->getInfo()->size, std::move(scales), std::move(p));
+            res->setName(info["Self"]->name());
+            Expr::replace(info["Self"], res->expr().first);
+        }
+    }
+
+    exprs = Variable::getExecuteOrder(outputs);
+    for (auto& iter : exprs) {
+        auto op = iter->get();
+        if (nullptr == op) {
+            continue;
+        }
+        if (op->type() != OpType_ConvInt8 && op->type() != OpType_DepthwiseConvInt8) {
+            continue;
+        }
+        auto inputExpr = iter->inputs()[0]->expr().first;
+        if (inputExpr->get() == nullptr) {
+            continue;
+        }
+        if (inputExpr->get()->type() != OpType_FloatToInt8) {
+            continue;
+        }
+        auto subInputExpr = inputExpr->inputs()[0]->expr().first;
+        if (subInputExpr->get() == nullptr) {
+            continue;
+        }
+        if (subInputExpr->get()->type() != OpType_Int8ToFloat) {
+            continue;
+        }
+        //MNN_PRINT("Find direct\n");
+        std::vector<VARP> newInputs = subInputExpr->inputs();
+        auto newExpr = Expr::create(iter->extra(), std::move(newInputs));
+        Expr::replace(iter, newExpr);
+    }
+    return true;
+}
 
 std::shared_ptr<Express::Optimizer> Transformer::turnModelToInfer() {
     return std::shared_ptr<Optimizer>(new InferOptimizer);

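Note on the BN folding above: with s = sqrt(runningVar + eps), the transform
computes alpha = gamma/s and beta = bias - runningMean/s*gamma, so the whole
BN collapses to y = alpha*x + beta; after a convolution that folds into
W' = alpha*W per output channel and b' = alpha*b + beta, and a standalone BN
becomes a Scale op. A per-channel numeric check of the identity:

    #include <cmath>
    #include <cstdio>

    int main() {
        float mean = 0.1f, var = 0.04f, gamma = 1.2f, bias = 0.3f, eps = 1e-5f;
        float s = std::sqrt(var + eps);
        float alpha = gamma / s;              // scale absorbed into the conv weight
        float beta = bias - mean / s * gamma; // shift absorbed into the conv bias
        float x = 0.7f;                       // any activation value
        float bn = gamma * (x - mean) / s + bias;
        float folded = alpha * x + beta;
        printf("bn=%f folded=%f\n", bn, folded); // equal up to rounding
        return 0;
    }
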
@@ -9,6 +9,7 @@
 #ifndef Transformer_hpp
 #define Transformer_hpp
 #include <MNN/expr/Optimizer.hpp>
+#include <MNN_generated.h>
 
 namespace MNN {
 namespace Train {

@@ -17,11 +18,40 @@ public:
     struct TrainConfig {
         std::vector<std::string> noUpdateOps;
         std::vector<std::string> onlyUpdateOps;
+        std::map<std::string, std::map<std::string, MNN::AttributeT*>> extraParams;
     };
 
     static std::shared_ptr<Express::Optimizer> turnModelToTrainable(TrainConfig config);
     static std::shared_ptr<Express::Optimizer> turnModelToInfer();
 };
+
+class TurnTrainable : public Express::Optimizer {
+public:
+    TurnTrainable(Transformer::TrainConfig config) {
+        mConfig = std::move(config);
+    }
+    virtual Cost onMeasure(const std::vector<Express::VARP>& outputs,
+                           std::shared_ptr<Parameters> parameters = nullptr) override {
+        return Cost();
+    }
+    virtual bool onExecute(const std::vector<Express::VARP>& outputs, std::shared_ptr<Parameters> p = nullptr) override;
+
+public:
+    std::map<std::string, Express::VARP> trainInfo;
+
+private:
+    Transformer::TrainConfig mConfig;
+};
+
+class InferOptimizer : public Express::Optimizer {
+public:
+    InferOptimizer(){}
+    virtual Cost onMeasure(const std::vector<Express::VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) override {
+        Cost c;
+        return c;
+    }
+    virtual bool onExecute(const std::vector<Express::VARP>& outputs, std::shared_ptr<Parameters> p = nullptr) override;
+};
 } // namespace Train
 } // namespace MNN
 #endif

@@ -9,6 +9,9 @@
     "StopBackPropOps":[],
     "type": "SGD"
   },
+  "BatchNorm": {
+    "momentum":0.99
+  },
   "Debug": {
     "L2Norm": []
   },