mirror of https://github.com/alibaba/MNN.git
164 lines
5.6 KiB
C++
164 lines
5.6 KiB
C++
#include <math.h>
|
|
#include <MNN/expr/ExprCreator.hpp>
|
|
#include <MNN/expr/Module.hpp>
|
|
#include "MNNTestSuite.h"
|
|
#include "TestUtils.h"
|
|
using namespace MNN;
|
|
using namespace MNN::Express;
|
|
|
|
class ModuleShapeInfer : public MNNTestCase {
|
|
public:
|
|
static float _reduceSum(const float* zPtr, int size) {
|
|
float summer = 0.0f;
|
|
for (int i=0; i<size; ++i) {
|
|
summer+=zPtr[i];
|
|
}
|
|
return summer;
|
|
}
|
|
virtual bool run(int precision) {
|
|
auto executor = cloneCurrentExecutor();
|
|
ExecutorScope scope(executor);
|
|
std::vector<VARP> empty;
|
|
// Make Net
|
|
auto x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>());
|
|
x->setName("x");
|
|
auto y = x * x;
|
|
VARP starts;
|
|
VARP sizes;
|
|
{
|
|
std::vector<int> sta = {0, 0, 1, 1};
|
|
std::vector<int> siz = {1, 1, 1, 1};
|
|
starts = _Const(sta.data(), {4}, NCHW, halide_type_of<int>());
|
|
sizes = _Const(siz.data(), {4}, NCHW, halide_type_of<int>());
|
|
}
|
|
auto z = _Slice(y, starts, sizes);
|
|
z->setName("z");
|
|
auto buffer = Variable::save({z});
|
|
ScheduleConfig config;
|
|
BackendConfig bnConfig;
|
|
bnConfig.precision = MNN::BackendConfig::Precision_Low;
|
|
config.backendConfig = &bnConfig;
|
|
std::shared_ptr<Executor::RuntimeManager> rt(Executor::RuntimeManager::createRuntimeManager(config), Executor::RuntimeManager::destroy);
|
|
std::shared_ptr<Module> net0(Module::load({"x"}, {"z"}, (const uint8_t*)buffer.data(), buffer.size(), rt), Module::destroy);
|
|
std::shared_ptr<Module> net1(Module::load({"x"}, {"z"}, (const uint8_t*)buffer.data(), buffer.size(), rt), Module::destroy);
|
|
x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>());
|
|
// Run Init Value
|
|
auto inputPtr = x->writeMap<float>();
|
|
for (int i=0; i<x->getInfo()->size; ++i) {
|
|
inputPtr[i] = i;
|
|
}
|
|
y = net0->onForward({x})[0];
|
|
auto yPtr = y->readMap<float>();
|
|
auto ySize = y->getInfo()->size;
|
|
auto valueFirst = _reduceSum(yPtr, ySize);
|
|
for (int i=0; i<x->getInfo()->size; ++i) {
|
|
inputPtr[i] = x->getInfo()->size - i;
|
|
}
|
|
y = net0->onForward({x})[0];
|
|
yPtr = y->readMap<float>();
|
|
auto valueSecond = _reduceSum(yPtr, ySize);
|
|
|
|
// Shape Infer mode
|
|
auto code = net1->traceOrOptimize(Interpreter::Module_Forward_Separate);
|
|
if (0 != code) {
|
|
FUNC_PRINT(1);
|
|
return false;
|
|
}
|
|
for (int i=0; i<x->getInfo()->size; ++i) {
|
|
inputPtr[i] = i;
|
|
}
|
|
y = net1->onForward({x})[0];
|
|
yPtr = y->readMap<float>();
|
|
auto tmp = net1->onForward(empty);
|
|
if (tmp.size() > 0) {
|
|
FUNC_PRINT(1);
|
|
return false;
|
|
}
|
|
if (_reduceSum(yPtr, ySize) != valueFirst) {
|
|
FUNC_PRINT(1);
|
|
return false;
|
|
}
|
|
for (int i=0; i<x->getInfo()->size; ++i) {
|
|
inputPtr[i] = x->getInfo()->size - i;
|
|
}
|
|
net1->onForward(empty);
|
|
if (_reduceSum(yPtr, ySize) != valueSecond) {
|
|
FUNC_PRINT(1);
|
|
return false;
|
|
}
|
|
net1->traceOrOptimize(MNN::Interpreter::Module_Forward_Combine);
|
|
for (int i=0; i<x->getInfo()->size; ++i) {
|
|
inputPtr[i] = i;
|
|
}
|
|
y = net1->onForward({x})[0];
|
|
yPtr = y->readMap<float>();
|
|
if(_reduceSum(yPtr, ySize) != valueFirst) {
|
|
FUNC_PRINT(1);
|
|
return false;
|
|
}
|
|
for (int i=0; i<x->getInfo()->size; ++i) {
|
|
inputPtr[i] = x->getInfo()->size - i;
|
|
}
|
|
y = net1->onForward({x})[0];
|
|
yPtr = y->readMap<float>();
|
|
if(_reduceSum(yPtr, ySize) != valueSecond) {
|
|
FUNC_PRINT(1);
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
};
|
|
|
|
class VariableSaveLoad: public MNNTestCase { // Verify the order of load is the same as the order of save
|
|
public:
|
|
virtual bool run(int precision) {
|
|
std::vector<MNN::Express::VARP> vars;
|
|
std::vector<int32_t> contents(4);
|
|
std::string file = "file.txt";
|
|
contents[0] = 0;
|
|
contents[1] = 1;
|
|
contents[2] = 2;
|
|
contents[3] = 3;
|
|
|
|
for (auto number: contents) {
|
|
auto var = MNN::Express::_Const(&number, {1}, MNN::Express::NHWC, halide_type_of<int32_t>());
|
|
if (var->getInfo() == nullptr) {
|
|
MNN_PRINT("error\n");
|
|
return false;
|
|
}
|
|
vars.emplace_back(var);
|
|
}
|
|
|
|
MNN::Express::Variable::save(vars, file.c_str());
|
|
|
|
auto readVars = MNN::Express::Variable::load(file.c_str());
|
|
std::vector<int32_t> readContents;
|
|
for (auto var_: readVars) {
|
|
auto var_ptr = var_->getInfo();
|
|
if (var_ptr == nullptr) {
|
|
MNN_PRINT("error\n");
|
|
return false;
|
|
}
|
|
readContents.push_back(var_->readMap<int32_t>()[0]);
|
|
}
|
|
for (int i = 0; i < 4; ++i) {
|
|
if (readContents[i] != contents[i]) {
|
|
MNN_PRINT("error %d: read %d, expect %d\n", i, readContents[i], contents[i]);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
int result = std::remove(file.c_str());
|
|
if (result == 0) {
|
|
MNN_PRINT("delete file success\n");
|
|
return true;
|
|
} else {
|
|
MNN_PRINT("delete file failed\n");
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
};
|
|
// Registers the ModuleShapeInfer test case under the name "expr/ModuleShapeInfer".
MNNTestSuiteRegister(ModuleShapeInfer, "expr/ModuleShapeInfer");
|
|
// Registers the VariableSaveLoad test case under the name "variable/saveLoad".
MNNTestSuiteRegister(VariableSaveLoad, "variable/saveLoad");
|