| 
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 |  |  | #include <math.h>
 | 
					
						
							|  |  |  | #include <MNN/expr/ExprCreator.hpp>
 | 
					
						
							|  |  |  | #include <MNN/expr/Module.hpp>
 | 
					
						
							|  |  |  | #include "MNNTestSuite.h"
 | 
					
						
							| 
									
										
										
										
											2025-08-01 17:19:36 +08:00
										 |  |  | #include "TestUtils.h"
 | 
					
						
							| 
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 |  |  | using namespace MNN; | 
					
						
							|  |  |  | using namespace MNN::Express; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ModuleShapeInfer : public MNNTestCase { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     static float _reduceSum(const float* zPtr, int size) { | 
					
						
							|  |  |  |         float summer = 0.0f; | 
					
						
							|  |  |  |         for (int i=0; i<size; ++i) { | 
					
						
							|  |  |  |             summer+=zPtr[i]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return summer; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2025-08-01 17:19:36 +08:00
										 |  |  |         auto executor = cloneCurrentExecutor(); | 
					
						
							|  |  |  |         ExecutorScope scope(executor); | 
					
						
							| 
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 |  |  |         std::vector<VARP> empty; | 
					
						
							|  |  |  |         // Make Net
 | 
					
						
							|  |  |  |         auto x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>()); | 
					
						
							|  |  |  |         x->setName("x"); | 
					
						
							|  |  |  |         auto y = x * x; | 
					
						
							|  |  |  |         VARP starts; | 
					
						
							|  |  |  |         VARP sizes; | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             std::vector<int> sta = {0, 0, 1, 1}; | 
					
						
							|  |  |  |             std::vector<int> siz = {1, 1, 1, 1}; | 
					
						
							|  |  |  |             starts = _Const(sta.data(), {4}, NCHW, halide_type_of<int>()); | 
					
						
							|  |  |  |             sizes = _Const(siz.data(), {4}, NCHW, halide_type_of<int>()); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto z = _Slice(y, starts, sizes); | 
					
						
							|  |  |  |         z->setName("z"); | 
					
						
							|  |  |  |         auto buffer = Variable::save({z}); | 
					
						
							|  |  |  |         ScheduleConfig config; | 
					
						
							|  |  |  |         BackendConfig bnConfig; | 
					
						
							|  |  |  |         bnConfig.precision = MNN::BackendConfig::Precision_Low; | 
					
						
							|  |  |  |         config.backendConfig = &bnConfig; | 
					
						
							|  |  |  |         std::shared_ptr<Executor::RuntimeManager> rt(Executor::RuntimeManager::createRuntimeManager(config), Executor::RuntimeManager::destroy); | 
					
						
							|  |  |  |         std::shared_ptr<Module> net0(Module::load({"x"}, {"z"}, (const uint8_t*)buffer.data(), buffer.size(), rt), Module::destroy); | 
					
						
							|  |  |  |         std::shared_ptr<Module> net1(Module::load({"x"}, {"z"}, (const uint8_t*)buffer.data(), buffer.size(), rt), Module::destroy); | 
					
						
							|  |  |  |         x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>()); | 
					
						
							|  |  |  |         // Run Init Value
 | 
					
						
							|  |  |  |         auto inputPtr = x->writeMap<float>(); | 
					
						
							|  |  |  |         for (int i=0; i<x->getInfo()->size; ++i) { | 
					
						
							|  |  |  |             inputPtr[i] = i; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         y = net0->onForward({x})[0]; | 
					
						
							|  |  |  |         auto yPtr = y->readMap<float>(); | 
					
						
							|  |  |  |         auto ySize = y->getInfo()->size; | 
					
						
							|  |  |  |         auto valueFirst = _reduceSum(yPtr, ySize); | 
					
						
							|  |  |  |         for (int i=0; i<x->getInfo()->size; ++i) { | 
					
						
							|  |  |  |             inputPtr[i] = x->getInfo()->size - i; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         y = net0->onForward({x})[0]; | 
					
						
							|  |  |  |         yPtr = y->readMap<float>(); | 
					
						
							|  |  |  |         auto valueSecond = _reduceSum(yPtr, ySize); | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         // Shape Infer mode
 | 
					
						
							|  |  |  |         auto code = net1->traceOrOptimize(Interpreter::Module_Forward_Separate); | 
					
						
							|  |  |  |         if (0 != code) { | 
					
						
							|  |  |  |             FUNC_PRINT(1); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i=0; i<x->getInfo()->size; ++i) { | 
					
						
							|  |  |  |             inputPtr[i] = i; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         y = net1->onForward({x})[0]; | 
					
						
							|  |  |  |         yPtr = y->readMap<float>(); | 
					
						
							|  |  |  |         auto tmp = net1->onForward(empty); | 
					
						
							|  |  |  |         if (tmp.size() > 0) { | 
					
						
							|  |  |  |             FUNC_PRINT(1); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (_reduceSum(yPtr, ySize) != valueFirst) { | 
					
						
							|  |  |  |             FUNC_PRINT(1); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i=0; i<x->getInfo()->size; ++i) { | 
					
						
							|  |  |  |             inputPtr[i] = x->getInfo()->size - i; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         net1->onForward(empty); | 
					
						
							|  |  |  |         if (_reduceSum(yPtr, ySize) != valueSecond) { | 
					
						
							|  |  |  |             FUNC_PRINT(1); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         net1->traceOrOptimize(MNN::Interpreter::Module_Forward_Combine); | 
					
						
							|  |  |  |         for (int i=0; i<x->getInfo()->size; ++i) { | 
					
						
							|  |  |  |             inputPtr[i] = i; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         y = net1->onForward({x})[0]; | 
					
						
							|  |  |  |         yPtr = y->readMap<float>(); | 
					
						
							|  |  |  |         if(_reduceSum(yPtr, ySize) != valueFirst) { | 
					
						
							|  |  |  |             FUNC_PRINT(1); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i=0; i<x->getInfo()->size; ++i) { | 
					
						
							|  |  |  |             inputPtr[i] = x->getInfo()->size - i; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         y = net1->onForward({x})[0]; | 
					
						
							|  |  |  |         yPtr = y->readMap<float>(); | 
					
						
							|  |  |  |         if(_reduceSum(yPtr, ySize) != valueSecond) { | 
					
						
							|  |  |  |             FUNC_PRINT(1); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2025-07-23 14:10:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class VariableSaveLoad: public MNNTestCase { // Verify the order of load is the same as the order of save
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         std::vector<MNN::Express::VARP> vars; | 
					
						
							|  |  |  |         std::vector<int32_t> contents(4); | 
					
						
							|  |  |  |         std::string file = "file.txt"; | 
					
						
							|  |  |  |         contents[0] = 0; | 
					
						
							|  |  |  |         contents[1] = 1; | 
					
						
							|  |  |  |         contents[2] = 2; | 
					
						
							|  |  |  |         contents[3] = 3; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for (auto number: contents) { | 
					
						
							|  |  |  |             auto var = MNN::Express::_Const(&number, {1}, MNN::Express::NHWC, halide_type_of<int32_t>()); | 
					
						
							|  |  |  |             if (var->getInfo() == nullptr) { | 
					
						
							|  |  |  |                 MNN_PRINT("error\n"); | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             vars.emplace_back(var); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         MNN::Express::Variable::save(vars, file.c_str()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto readVars = MNN::Express::Variable::load(file.c_str()); | 
					
						
							|  |  |  |         std::vector<int32_t> readContents; | 
					
						
							|  |  |  |         for (auto var_: readVars) { | 
					
						
							|  |  |  |             auto var_ptr = var_->getInfo(); | 
					
						
							|  |  |  |             if (var_ptr == nullptr) { | 
					
						
							|  |  |  |                 MNN_PRINT("error\n"); | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             readContents.push_back(var_->readMap<int32_t>()[0]); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = 0; i < 4; ++i) { | 
					
						
							|  |  |  |             if (readContents[i] != contents[i]) { | 
					
						
							|  |  |  |                 MNN_PRINT("error %d: read %d, expect %d\n", i, readContents[i], contents[i]); | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int result = std::remove(file.c_str()); | 
					
						
							|  |  |  |         if (result == 0) { | 
					
						
							|  |  |  |             MNN_PRINT("delete file success\n"); | 
					
						
							|  |  |  |             return true; | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             MNN_PRINT("delete file failed\n"); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2024-08-24 15:46:21 +08:00
// Register both test cases with MNNTestSuite under the names shown
// (presumably the names used to select them when running the suite).
MNNTestSuiteRegister(ModuleShapeInfer, "expr/ModuleShapeInfer");
MNNTestSuiteRegister(VariableSaveLoad, "variable/saveLoad");