| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  testModel.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/01/22.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #define MNN_OPEN_TIME_TRACE
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | #include <MNN/MNNDefine.h>
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #include <math.h>
 | 
					
						
							|  |  |  | #include <stdio.h>
 | 
					
						
							|  |  |  | #include <stdlib.h>
 | 
					
						
							|  |  |  | #include <string.h>
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | #include <MNN/AutoTime.hpp>
 | 
					
						
							|  |  |  | #include <MNN/Interpreter.hpp>
 | 
					
						
							|  |  |  | #include <MNN/Tensor.hpp>
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #include <fstream>
 | 
					
						
							|  |  |  | #include <map>
 | 
					
						
							|  |  |  | #include <sstream>
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/Backend.hpp"
 | 
					
						
							|  |  |  | #include "core/Macro.h"
 | 
					
						
							|  |  |  | #include "core/TensorUtils.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | #define NONE "\e[0m"
 | 
					
						
							|  |  |  | #define RED "\e[0;31m"
 | 
					
						
							|  |  |  | #define GREEN "\e[0;32m"
 | 
					
						
							|  |  |  | #define L_GREEN "\e[1;32m"
 | 
					
						
							|  |  |  | #define BLUE "\e[0;34m"
 | 
					
						
							|  |  |  | #define L_BLUE "\e[1;34m"
 | 
					
						
							|  |  |  | #define BOLD "\e[1m"
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-26 21:24:38 +08:00
										 |  |  | template<typename T> | 
					
						
							|  |  |  | inline T stringConvert(const char* number) { | 
					
						
							|  |  |  |     std::istringstream os(number); | 
					
						
							|  |  |  |     T v; | 
					
						
							|  |  |  |     os >> v; | 
					
						
							|  |  |  |     return v; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | MNN::Tensor* createTensor(const MNN::Tensor* shape, const char* path) { | 
					
						
							|  |  |  |     std::ifstream stream(path); | 
					
						
							|  |  |  |     if (stream.fail()) { | 
					
						
							|  |  |  |         return NULL; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto result = new MNN::Tensor(shape, shape->getDimensionType()); | 
					
						
							|  |  |  |     auto data   = result->host<float>(); | 
					
						
							|  |  |  |     for (int i = 0; i < result->elementSize(); ++i) { | 
					
						
							|  |  |  |         double temp = 0.0f; | 
					
						
							|  |  |  |         stream >> temp; | 
					
						
							|  |  |  |         data[i] = temp; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     stream.close(); | 
					
						
							|  |  |  |     return result; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | int main(int argc, const char* argv[]) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     // check given & expect
 | 
					
						
							|  |  |  |     const char* modelPath  = argv[1]; | 
					
						
							|  |  |  |     const char* givenName  = argv[2]; | 
					
						
							|  |  |  |     const char* expectName = argv[3]; | 
					
						
							|  |  |  |     MNN_PRINT("Testing model %s, input: %s, output: %s\n", modelPath, givenName, expectName); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // create net
 | 
					
						
							|  |  |  |     auto type = MNN_FORWARD_CPU; | 
					
						
							|  |  |  |     if (argc > 4) { | 
					
						
							| 
									
										
										
										
											2020-02-26 21:24:38 +08:00
										 |  |  |         type = (MNNForwardType)stringConvert<int>(argv[4]); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     auto tolerance = 0.1f; | 
					
						
							|  |  |  |     if (argc > 5) { | 
					
						
							| 
									
										
										
										
											2020-02-26 21:24:38 +08:00
										 |  |  |         tolerance = stringConvert<float>(argv[5]); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     MNN::BackendConfig::PrecisionMode precision = MNN::BackendConfig::Precision_High; | 
					
						
							|  |  |  |     if (argc > 6) { | 
					
						
							|  |  |  |         precision = (MNN::BackendConfig::PrecisionMode)stringConvert<int>(argv[6]); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     std::shared_ptr<MNN::Interpreter> net = | 
					
						
							|  |  |  |         std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(modelPath)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // create session
 | 
					
						
							|  |  |  |     MNN::ScheduleConfig config; | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  |     config.type = type; | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     MNN::BackendConfig backendConfig; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     backendConfig.precision = precision; | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     config.backendConfig = &backendConfig; | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  |     auto session         = net->createSession(config); | 
					
						
							| 
									
										
										
										
											2021-06-21 15:13:10 +08:00
										 |  |  |      | 
					
						
							|  |  |  |     // input dims
 | 
					
						
							|  |  |  |     std::vector<int> inputDims; | 
					
						
							|  |  |  |     if (argc > 7) { | 
					
						
							|  |  |  |         std::string inputShape(argv[7]); | 
					
						
							|  |  |  |         const char* delim = "x"; | 
					
						
							|  |  |  |         std::ptrdiff_t p1 = 0, p2; | 
					
						
							|  |  |  |         while (1) { | 
					
						
							|  |  |  |             p2 = inputShape.find(delim, p1); | 
					
						
							|  |  |  |             if (p2 != std::string::npos) { | 
					
						
							|  |  |  |                 inputDims.push_back(atoi(inputShape.substr(p1, p2 - p1).c_str())); | 
					
						
							|  |  |  |                 p1 = p2 + 1; | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 inputDims.push_back(atoi(inputShape.substr(p1).c_str())); | 
					
						
							|  |  |  |                 break; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (auto dim : inputDims) { | 
					
						
							|  |  |  |         MNN_PRINT("%d ", dim); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     MNN_PRINT("\n"); | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |      | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     auto allInput = net->getSessionInputAll(session); | 
					
						
							|  |  |  |     for (auto& iter : allInput) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         auto inputTensor = iter.second; | 
					
						
							| 
									
										
										
										
											2021-06-21 15:13:10 +08:00
										 |  |  |          | 
					
						
							|  |  |  |         if (!inputDims.empty()) { | 
					
						
							|  |  |  |             MNN_PRINT("===========> Resize Tensor...\n"); | 
					
						
							|  |  |  |             net->resizeTensor(inputTensor, inputDims); | 
					
						
							|  |  |  |             net->resizeSession(session); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |          | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         auto size = inputTensor->size(); | 
					
						
							|  |  |  |         if (size <= 0) { | 
					
						
							|  |  |  |             continue; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         MNN::Tensor tempTensor(inputTensor, inputTensor->getDimensionType()); | 
					
						
							|  |  |  |         ::memset(tempTensor.host<void>(), 0, tempTensor.size()); | 
					
						
							|  |  |  |         inputTensor->copyFromHostTensor(&tempTensor); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // write input tensor
 | 
					
						
							|  |  |  |     auto inputTensor = net->getSessionInput(session, NULL); | 
					
						
							| 
									
										
										
										
											2020-06-05 17:34:53 +08:00
										 |  |  |     std::shared_ptr<MNN::Tensor> givenTensor(createTensor(inputTensor, givenName)); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     if (!givenTensor) { | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #if defined(_MSC_VER)
 | 
					
						
							|  |  |  |         printf("Failed to open input file %s.\n", givenName); | 
					
						
							|  |  |  | #else
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         printf(RED "Failed to open input file %s.\n" NONE, givenName); | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         return -1; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-06-05 17:34:53 +08:00
										 |  |  |     // First time
 | 
					
						
							|  |  |  |     inputTensor->copyFromHostTensor(givenTensor.get()); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     // infer
 | 
					
						
							|  |  |  |     net->runSession(session); | 
					
						
							|  |  |  |     // read expect tensor
 | 
					
						
							|  |  |  |     auto outputTensor = net->getSessionOutput(session, NULL); | 
					
						
							|  |  |  |     std::shared_ptr<MNN::Tensor> expectTensor(createTensor(outputTensor, expectName)); | 
					
						
							|  |  |  |     if (!expectTensor.get()) { | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #if defined(_MSC_VER)
 | 
					
						
							|  |  |  |         printf("Failed to open expect file %s.\n", expectName); | 
					
						
							|  |  |  | #else
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         printf(RED "Failed to open expect file %s.\n" NONE, expectName); | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         return -1; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // compare output with expect
 | 
					
						
							|  |  |  |     bool correct = MNN::TensorUtils::compareTensors(outputTensor, expectTensor.get(), tolerance, true); | 
					
						
							| 
									
										
										
										
											2020-06-05 17:34:53 +08:00
										 |  |  |     if (!correct) { | 
					
						
							|  |  |  | #if defined(_MSC_VER)
 | 
					
						
							|  |  |  |         printf("Test Failed %s!\n", modelPath); | 
					
						
							|  |  |  | #else
 | 
					
						
							|  |  |  |         printf(RED "Test Failed %s!\n" NONE, modelPath); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |         return -1; | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |         printf("First run pass\n"); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     // Run Second time
 | 
					
						
							|  |  |  |     inputTensor->copyFromHostTensor(givenTensor.get()); | 
					
						
							|  |  |  |     // infer
 | 
					
						
							|  |  |  |     net->runSession(session); | 
					
						
							|  |  |  |     // read expect tensor
 | 
					
						
							|  |  |  |     std::shared_ptr<MNN::Tensor> expectTensor2(createTensor(outputTensor, expectName)); | 
					
						
							|  |  |  |     correct = MNN::TensorUtils::compareTensors(outputTensor, expectTensor2.get(), tolerance, true); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     if (correct) { | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #if defined(_MSC_VER)
 | 
					
						
							|  |  |  |         printf("Test %s Correct!\n", modelPath); | 
					
						
							|  |  |  | #else
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         printf(GREEN BOLD "Test %s Correct!\n" NONE, modelPath); | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } else { | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #if defined(_MSC_VER)
 | 
					
						
							|  |  |  |         printf("Test Failed %s!\n", modelPath); | 
					
						
							|  |  |  | #else
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         printf(RED "Test Failed %s!\n" NONE, modelPath); | 
					
						
							| 
									
										
										
										
											2019-06-10 21:08:55 +08:00
										 |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     return 0; | 
					
						
							|  |  |  | } |