| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  ZeroShapeTest.cpp
 | 
					
						
							|  |  |  | //  MNNTests
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/12/18.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <MNN/expr/Expr.hpp>
 | 
					
						
							|  |  |  | #include <MNN/expr/ExprCreator.hpp>
 | 
					
						
							|  |  |  | #include "MNNTestSuite.h"
 | 
					
						
							|  |  |  | #include "TestUtils.h"
 | 
					
						
							|  |  |  | #include "MNN_generated.h"
 | 
					
						
							|  |  |  | using namespace MNN::Express; | 
					
						
							| 
									
										
										
										
											2024-04-19 11:58:21 +08:00
										 |  |  | class ZeroShapeResizeTest : public MNNTestCase { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~ZeroShapeResizeTest() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         auto input = _Input({1, 1, 4, 1}, NHWC); | 
					
						
							|  |  |  |         input->setName("input"); | 
					
						
							|  |  |  |         input->writeMap<float>(); | 
					
						
							|  |  |  |         auto output    = _Reshape(input, {-1}); | 
					
						
							|  |  |  |         auto outputPtr = output->readMap<float>(); | 
					
						
							|  |  |  |         input->resize({1, 0, 4, 1}); | 
					
						
							|  |  |  |         input->writeMap<float>(); | 
					
						
							|  |  |  |         auto info = output->getInfo(); | 
					
						
							|  |  |  |         outputPtr = output->readMap<float>(); | 
					
						
							|  |  |  |         if (info->size != 0) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | class ZeroShapeTest : public MNNTestCase { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~ZeroShapeTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         auto input = _Input({1, 0, 4, 1}, NHWC); | 
					
						
							|  |  |  |         input->setName("input"); | 
					
						
							| 
									
										
										
										
											2022-02-18 11:30:27 +08:00
										 |  |  |         auto output    = _Reshape(input, {1, 0, -1}); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         auto info      = output->getInfo(); | 
					
						
							| 
									
										
										
										
											2022-02-18 11:30:27 +08:00
										 |  |  |         auto rightDims = std::vector<int>{1, 0, 0}; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         if (info->dim[0] != rightDims[0] || info->dim[1] != rightDims[1] || info->dim[2] != rightDims[2]) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | class ZeroShapeTest2 : public MNNTestCase { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~ZeroShapeTest2() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         auto input = _Input({1, -1, 4, 1}, NHWC); | 
					
						
							|  |  |  |         input->setName("input"); | 
					
						
							|  |  |  |         auto output = _Reshape(input, {0, 0, -1}); | 
					
						
							|  |  |  |         auto info   = output->getInfo(); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         if (nullptr != info) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | class ZeroShapeTest3 : public MNNTestCase { | 
					
						
							|  |  |  | public: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         auto input = _Input({1, 0, 4, 1}, NHWC); | 
					
						
							|  |  |  |         input->setName("input"); | 
					
						
							|  |  |  |         std::unique_ptr<MNN::OpT> op(new MNN::OpT); | 
					
						
							|  |  |  |         op->type = MNN::OpType_Unpack; | 
					
						
							|  |  |  |         op->main.value = new MNN::AxisT; | 
					
						
							|  |  |  |         op->main.type = MNN::OpParameter_Axis; | 
					
						
							|  |  |  |         op->main.AsAxis()->axis = 1; | 
					
						
							|  |  |  |         auto expr = Expr::create(op.get(), {input}, 3); | 
					
						
							|  |  |  |         auto output = Variable::create(expr, 0); | 
					
						
							|  |  |  |         auto info   = output->getInfo(); | 
					
						
							|  |  |  |         if (nullptr != info) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |             FUNC_PRINT(1); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto sliceOutput = _Split(input, {4}, 2); | 
					
						
							|  |  |  |         std::vector<int> dstDims = {1, 0, 1, 1}; | 
					
						
							|  |  |  |         for (auto s : sliceOutput) { | 
					
						
							|  |  |  |             auto info = s->getInfo(); | 
					
						
							|  |  |  |             if (info->dim != dstDims) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |                 FUNC_PRINT(1); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             auto ptr = s->readMap<float>(); | 
					
						
							|  |  |  |             if (nullptr != ptr) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |                 FUNC_PRINT(1); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         std::vector<int> padds = {0, 0, 1, 0, 0, 0, 0, 0}; | 
					
						
							|  |  |  |         auto paddings = _Const(padds.data(), {2, 4}, NHWC, halide_type_of<int>()); | 
					
						
							|  |  |  |         auto padOutput = _Pad(input, paddings); | 
					
						
							|  |  |  |         auto padinfo = padOutput->getInfo(); | 
					
						
							|  |  |  |         if (padinfo->dim != std::vector<int>{1, 1, 4, 1}) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |             FUNC_PRINT(1); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         input->writeMap<float>(); | 
					
						
							|  |  |  |         auto ptr = padOutput->readMap<float>(); | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |         if (nullptr == ptr) { | 
					
						
							|  |  |  |             FUNC_PRINT(1); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         for (int i = 0; i < padinfo->size; ++i) { | 
					
						
							|  |  |  |             if (ptr[i] > 0.000001f) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |                 FUNC_PRINT(1); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-08-30 13:26:34 +08:00
										 |  |  | class ZeroShapeTest4 : public MNNTestCase { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         auto input = _Input({}, NHWC); | 
					
						
							|  |  |  |         input->setName("input"); | 
					
						
							|  |  |  |         auto output = _Shape(input, NCHW); | 
					
						
							|  |  |  |         auto info   = output->getInfo(); | 
					
						
							|  |  |  |         if (nullptr == info) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (info->dim[0] != 0) { | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2024-04-19 11:58:21 +08:00
										 |  |  | MNNTestSuiteRegister(ZeroShapeResizeTest, "expr/zeroshaperesize"); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | MNNTestSuiteRegister(ZeroShapeTest, "expr/zeroshape"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(ZeroShapeTest2, "expr/zeroshape2"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(ZeroShapeTest3, "expr/zeroshape3"); | 
					
						
							| 
									
										
										
										
											2021-08-30 13:26:34 +08:00
										 |  |  | MNNTestSuiteRegister(ZeroShapeTest4, "expr/zeroshape4"); |