//
// testModel.cpp
// MNN
//
// Created by MNN on 2019/01/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#define MNN_OPEN_TIME_TRACE
#include <MNN/MNNDefine.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <MNN/AutoTime.hpp>
#include <MNN/Interpreter.hpp>
#include <MNN/Tensor.hpp>
#include <fstream>
#include <map>
#include <sstream>
#include "core/Backend.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
#define NONE "\e[0m"
#define RED "\e[0;31m"
#define GREEN "\e[0;32m"
#define L_GREEN "\e[1;32m"
#define BLUE "\e[0;34m"
#define L_BLUE "\e[1;34m"
#define BOLD "\e[1m"
/**
 * Parse a value of type T from a C string via stream extraction.
 * Trailing garbage after the value is ignored; an unparsable string
 * yields a zero/default value (stream extraction semantics).
 */
template <typename T>
inline T stringConvert(const char* number) {
    T parsed{};
    std::istringstream input(number);
    input >> parsed;
    return parsed;
}
/**
 * Read whitespace-separated numeric values from a text file into a new host
 * tensor with the same shape and dimension type as `shape`.
 *
 * @param shape Tensor supplying dimensions and dimension type for the result.
 * @param path  Path to a text file containing one value per token.
 * @return Newly allocated tensor (caller owns and must delete, typically via
 *         shared_ptr as done in main), or nullptr when the file cannot be
 *         opened. If the file holds fewer values than elementSize(), the
 *         remaining elements are filled with 0.
 */
MNN::Tensor* createTensor(const MNN::Tensor* shape, const char* path) {
    std::ifstream stream(path);
    if (stream.fail()) {
        return nullptr;
    }
    auto result = new MNN::Tensor(shape, shape->getDimensionType());
    auto data   = result->host<float>();
    for (int i = 0; i < result->elementSize(); ++i) {
        double temp = 0.0;
        if (!(stream >> temp)) {
            // Short or malformed file: pad the rest with zeros instead of
            // reading from a failed stream each iteration.
            temp = 0.0;
        }
        data[i] = static_cast<float>(temp);
    }
    // stream closed by ifstream destructor (RAII)
    return result;
}
int main(int argc, const char* argv[]) {
2020-11-05 16:41:56 +08:00
2019-04-17 10:49:11 +08:00
// check given & expect
const char* modelPath = argv[1];
const char* givenName = argv[2];
const char* expectName = argv[3];
MNN_PRINT("Testing model %s, input: %s, output: %s\n", modelPath, givenName, expectName);
// create net
auto type = MNN_FORWARD_CPU;
if (argc > 4) {
type = (MNNForwardType)stringConvert<int>(argv[4]);
2019-04-17 10:49:11 +08:00
}
auto tolerance = 0.1f;
if (argc > 5) {
tolerance = stringConvert<float>(argv[5]);
2019-04-17 10:49:11 +08:00
}
2020-11-05 16:41:56 +08:00
MNN::BackendConfig::PrecisionMode precision = MNN::BackendConfig::Precision_High;
if (argc > 6) {
precision = (MNN::BackendConfig::PrecisionMode)stringConvert<int>(argv[6]);
}
2019-04-17 10:49:11 +08:00
std::shared_ptr<MNN::Interpreter> net =
std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(modelPath));
// create session
MNN::ScheduleConfig config;
config.type = type;
2019-12-27 22:16:57 +08:00
MNN::BackendConfig backendConfig;
2020-11-05 16:41:56 +08:00
backendConfig.precision = precision;
2019-12-27 22:16:57 +08:00
config.backendConfig = &backendConfig;
auto session = net->createSession(config);
2021-06-21 15:13:10 +08:00
// input dims
std::vector<int> inputDims;
if (argc > 7) {
std::string inputShape(argv[7]);
const char* delim = "x";
std::ptrdiff_t p1 = 0, p2;
while (1) {
p2 = inputShape.find(delim, p1);
if (p2 != std::string::npos) {
inputDims.push_back(atoi(inputShape.substr(p1, p2 - p1).c_str()));
p1 = p2 + 1;
} else {
inputDims.push_back(atoi(inputShape.substr(p1).c_str()));
break;
}
}
}
for (auto dim : inputDims) {
MNN_PRINT("%d ", dim);
}
MNN_PRINT("\n");
2019-04-17 10:49:11 +08:00
auto allInput = net->getSessionInputAll(session);
for (auto& iter : allInput) {
2020-11-05 16:41:56 +08:00
auto inputTensor = iter.second;
2021-06-21 15:13:10 +08:00
if (!inputDims.empty()) {
MNN_PRINT("===========> Resize Tensor...\n");
net->resizeTensor(inputTensor, inputDims);
net->resizeSession(session);
}
2020-11-05 16:41:56 +08:00
auto size = inputTensor->size();
if (size <= 0) {
continue;
2019-04-17 10:49:11 +08:00
}
2020-11-05 16:41:56 +08:00
MNN::Tensor tempTensor(inputTensor, inputTensor->getDimensionType());
::memset(tempTensor.host<void>(), 0, tempTensor.size());
inputTensor->copyFromHostTensor(&tempTensor);
2019-04-17 10:49:11 +08:00
}
// write input tensor
auto inputTensor = net->getSessionInput(session, NULL);
std::shared_ptr<MNN::Tensor> givenTensor(createTensor(inputTensor, givenName));
2019-04-17 10:49:11 +08:00
if (!givenTensor) {
#if defined(_MSC_VER)
printf("Failed to open input file %s.\n", givenName);
#else
2019-04-17 10:49:11 +08:00
printf(RED "Failed to open input file %s.\n" NONE, givenName);
#endif
2019-04-17 10:49:11 +08:00
return -1;
}
// First time
inputTensor->copyFromHostTensor(givenTensor.get());
2019-04-17 10:49:11 +08:00
// infer
net->runSession(session);
// read expect tensor
auto outputTensor = net->getSessionOutput(session, NULL);
std::shared_ptr<MNN::Tensor> expectTensor(createTensor(outputTensor, expectName));
if (!expectTensor.get()) {
#if defined(_MSC_VER)
printf("Failed to open expect file %s.\n", expectName);
#else
2019-04-17 10:49:11 +08:00
printf(RED "Failed to open expect file %s.\n" NONE, expectName);
#endif
2019-04-17 10:49:11 +08:00
return -1;
}
// compare output with expect
bool correct = MNN::TensorUtils::compareTensors(outputTensor, expectTensor.get(), tolerance, true);
if (!correct) {
#if defined(_MSC_VER)
printf("Test Failed %s!\n", modelPath);
#else
printf(RED "Test Failed %s!\n" NONE, modelPath);
#endif
return -1;
} else {
printf("First run pass\n");
}
// Run Second time
inputTensor->copyFromHostTensor(givenTensor.get());
// infer
net->runSession(session);
// read expect tensor
std::shared_ptr<MNN::Tensor> expectTensor2(createTensor(outputTensor, expectName));
correct = MNN::TensorUtils::compareTensors(outputTensor, expectTensor2.get(), tolerance, true);
2019-04-17 10:49:11 +08:00
if (correct) {
#if defined(_MSC_VER)
printf("Test %s Correct!\n", modelPath);
#else
2019-04-17 10:49:11 +08:00
printf(GREEN BOLD "Test %s Correct!\n" NONE, modelPath);
#endif
2019-04-17 10:49:11 +08:00
} else {
#if defined(_MSC_VER)
printf("Test Failed %s!\n", modelPath);
#else
2019-04-17 10:49:11 +08:00
printf(RED "Test Failed %s!\n" NONE, modelPath);
#endif
2019-04-17 10:49:11 +08:00
}
return 0;
}