mirror of https://github.com/alibaba/MNN.git
172 lines
4.6 KiB
C++
172 lines
4.6 KiB
C++
|
|
//
|
||
|
|
// testModel_expr.cpp
|
||
|
|
// MNN
|
||
|
|
//
|
||
|
|
// Created by MNN on 2021/08/09.
|
||
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
||
|
|
//
|
||
|
|
|
||
|
|
#define MNN_OPEN_TIME_TRACE
|
||
|
|
|
||
|
|
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <MNN/MNNDefine.h>
#include <MNN/AutoTime.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Module.hpp>
|
||
|
|
// ANSI terminal escape sequences used to colorize console output.
// ESC is spelled "\033" (octal 0x1B) rather than "\e": "\e" is a GCC/Clang
// extension that is not valid ISO C++ (MSVC rejects it, -pedantic warns).
#define NONE "\033[0m"
#define RED "\033[0;31m"
#define GREEN "\033[0;32m"
#define L_GREEN "\033[1;32m"
#define BLUE "\033[0;34m"
#define L_BLUE "\033[1;34m"
#define BOLD "\033[1m"
|
||
|
|
|
||
|
|
void log_result(bool correct) {
|
||
|
|
if (correct) {
|
||
|
|
#if defined(_MSC_VER)
|
||
|
|
std::cout << "Correct!" << std::endl;
|
||
|
|
#else
|
||
|
|
std::cout << GREEN << BOLD << "Correct!" << NONE << std::endl;
|
||
|
|
#endif
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Parses a value of type T from the C string `number` via stream extraction.
///
/// Only the leading parsable prefix is consumed ("7 trailing" -> 7).
/// Returns a value-initialized T (0 for arithmetic types) when `number` is null
/// or contains nothing parsable. The original left `v` uninitialized: for an
/// empty string the stream sentry fails and extraction never writes `v`, so the
/// old code returned an indeterminate value.
template<typename T>
inline T stringConvert(const char* number) {
    T value{};
    if (nullptr == number) {
        return value;
    }
    std::istringstream is(number);
    is >> value;
    return value;
}
|
||
|
|
|
||
|
|
template <typename T>
|
||
|
|
static bool compareImpl(MNN::Express::VARP x, MNN::Express::VARP y, int size, double tolerance) {
|
||
|
|
#define _ABS(a) (a < 0 ? -a : a)
|
||
|
|
#define _MAX(a, b) (a > b ? a : b)
|
||
|
|
auto px = x->readMap<T>();
|
||
|
|
auto py = y->readMap<T>();
|
||
|
|
// get max if using overall torelance
|
||
|
|
T maxValue = _ABS(py[0]);
|
||
|
|
for (int i = 1; i < size; i++) {
|
||
|
|
maxValue = _MAX(maxValue, _ABS(py[i]));
|
||
|
|
}
|
||
|
|
// compare
|
||
|
|
for (int i = 0; i < size; i++) {
|
||
|
|
T vx = px[i], vy = py[i];
|
||
|
|
if (_ABS(vx - vy) < tolerance * maxValue) {
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
std::cout << i << ": " << vx << " != " << vy << std::endl;
|
||
|
|
return false;
|
||
|
|
}
|
||
|
|
return true;
|
||
|
|
#undef _ABS
|
||
|
|
#undef _MAX
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Compares computed variable `x` against reference `y` with relative `tolerance`.
///
/// Dispatches on y's element type: int32 and uint8 are compared natively, and
/// every other type is read as float. Before comparing, it verifies that both
/// variables exist and have shape info, the same element count, and the same
/// element type. Without these checks `x` could be read out of bounds or
/// reinterpreted as the wrong type.
///
/// @return true if all elements match within tolerance; false (with a
///         diagnostic) on any mismatch or missing info.
static bool compareVar(MNN::Express::VARP x, MNN::Express::VARP y, double tolerance) {
    if (nullptr == x.get() || nullptr == y.get()) {
        MNN_ERROR("compareVar: variable is null\n");
        return false;
    }
    auto xInfo = x->getInfo();
    auto yInfo = y->getInfo();
    if (nullptr == xInfo || nullptr == yInfo) {
        MNN_ERROR("compareVar: can't get variable info\n");
        return false;
    }
    if (xInfo->size != yInfo->size) {
        MNN_ERROR("compareVar: size not match, got %d, expect %d\n",
                  static_cast<int>(xInfo->size), static_cast<int>(yInfo->size));
        return false;
    }
    const auto dtype = yInfo->type;
    if (!(xInfo->type == dtype)) {
        MNN_ERROR("compareVar: type not match, got (code %d, bits %d), expect (code %d, bits %d)\n",
                  static_cast<int>(xInfo->type.code), static_cast<int>(xInfo->type.bits),
                  static_cast<int>(dtype.code), static_cast<int>(dtype.bits));
        return false;
    }
    const int size = static_cast<int>(yInfo->size);
    if (dtype == halide_type_of<int32_t>()) {
        return compareImpl<int32_t>(x, y, size, tolerance);
    }
    if (dtype == halide_type_of<uint8_t>()) {
        return compareImpl<uint8_t>(x, y, size, tolerance);
    }
    return compareImpl<float>(x, y, size, tolerance);
}
|
||
|
|
|
||
|
|
using namespace MNN::Express;
|
||
|
|
int main(int argc, const char* argv[]) {
|
||
|
|
if (argc < 4) {
|
||
|
|
MNN_PRINT("Usage: ./testModel_expr.out model.mnn input.mnn output.mnn [type] [tolerance] [precision]\n");
|
||
|
|
return 0;
|
||
|
|
}
|
||
|
|
// check given & expect
|
||
|
|
const char* modelPath = argv[1];
|
||
|
|
const char* inputName = argv[2];
|
||
|
|
const char* outputName = argv[3];
|
||
|
|
MNN_PRINT("Testing model %s, input: %s, output: %s\n", modelPath, inputName, outputName);
|
||
|
|
|
||
|
|
// create net
|
||
|
|
auto type = MNN_FORWARD_CPU;
|
||
|
|
if (argc > 4) {
|
||
|
|
type = (MNNForwardType)stringConvert<int>(argv[4]);
|
||
|
|
}
|
||
|
|
auto tolerance = 0.1f;
|
||
|
|
if (argc > 5) {
|
||
|
|
tolerance = stringConvert<float>(argv[5]);
|
||
|
|
}
|
||
|
|
MNN::BackendConfig::PrecisionMode precision = MNN::BackendConfig::Precision_High;
|
||
|
|
if (argc > 6) {
|
||
|
|
precision = (MNN::BackendConfig::PrecisionMode)stringConvert<int>(argv[6]);
|
||
|
|
}
|
||
|
|
auto inputVars = Variable::load(inputName);
|
||
|
|
auto outputVars = Variable::load(outputName);
|
||
|
|
std::vector<std::string> inputNames;
|
||
|
|
std::vector<std::string> outputNames;
|
||
|
|
for (auto v : inputVars) {
|
||
|
|
inputNames.emplace_back(v->name());
|
||
|
|
}
|
||
|
|
for (auto v : outputVars) {
|
||
|
|
outputNames.emplace_back(v->name());
|
||
|
|
}
|
||
|
|
if (inputVars.empty()) {
|
||
|
|
MNN_ERROR("Input is Error\n");
|
||
|
|
return 0;
|
||
|
|
}
|
||
|
|
if (outputVars.empty()) {
|
||
|
|
MNN_ERROR("Output is Error\n");
|
||
|
|
return 0;
|
||
|
|
}
|
||
|
|
Module::Config config;
|
||
|
|
config.rearrange = true;
|
||
|
|
std::shared_ptr<Module> m(Module::load(inputNames, outputNames, modelPath));
|
||
|
|
if (nullptr == m) {
|
||
|
|
MNN_ERROR("Model is Error\n");
|
||
|
|
return 0;
|
||
|
|
}
|
||
|
|
// First
|
||
|
|
auto outputs = m->onForward(inputVars);
|
||
|
|
if (outputs.size() != outputVars.size()) {
|
||
|
|
MNN_ERROR("Number not match\n");
|
||
|
|
return 0;
|
||
|
|
}
|
||
|
|
bool success = true;
|
||
|
|
for (int i=0; i<outputVars.size(); ++i) {
|
||
|
|
success = compareVar(outputs[i], outputVars[i], tolerance);
|
||
|
|
if (!success) {
|
||
|
|
MNN_ERROR("Error for %s\n", outputVars[i]->name().c_str());
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if (!success) {
|
||
|
|
return 0;
|
||
|
|
}
|
||
|
|
outputs = m->onForward(inputVars);
|
||
|
|
for (int i=0; i<outputVars.size(); ++i) {
|
||
|
|
success = compareVar(outputs[i], outputVars[i], tolerance);
|
||
|
|
if (!success) {
|
||
|
|
MNN_ERROR("Error for %s\n", outputVars[i]->name().c_str());
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if (!success) {
|
||
|
|
MNN_ERROR("Error for test second\n");
|
||
|
|
return 0;
|
||
|
|
}
|
||
|
|
log_result(success);
|
||
|
|
return 0;
|
||
|
|
}
|