mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			75 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			75 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  AllAnyTest.cpp
 | |
| //  MNNTests
 | |
| //
 | |
| //  Created by MNN on 2019/09/10.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include <MNN/expr/ExprCreator.hpp>
 | |
| #include "MNNTestSuite.h"
 | |
| #include <MNN/expr/ExecutorScope.hpp>
 | |
| 
 | |
| using namespace MNN::Express;
 | |
| 
 | |
| class AllAnyTest : public MNNTestCase {
 | |
| public:
 | |
|     bool _run(int precision, bool lazy) {
 | |
|         auto y                = _Input({4}, NHWC, halide_type_of<int32_t>());
 | |
|         std::vector<int> seq0 = {1, 0, 0, 1};
 | |
|         std::vector<int> seq1 = {1, 1, 1, 1};
 | |
|         std::vector<int> seq2 = {0, 0, 0, 0};
 | |
|         auto yPtr             = y->writeMap<int32_t>();
 | |
|         ::memcpy(yPtr, seq0.data(), seq0.size() * sizeof(int32_t));
 | |
|         auto zAny     = _ReduceAny(y, {0});
 | |
|         auto zAll     = _ReduceAll(y, {0});
 | |
|         auto zAnyinfo = zAny->getInfo();
 | |
|         if (zAny->readMap<int32_t>()[0] != 1) {
 | |
|             FUNC_PRINT(1);
 | |
|             return false;
 | |
|         }
 | |
|         if (zAll->readMap<int32_t>()[0] != 0) {
 | |
|             FUNC_PRINT(1);
 | |
|             return false;
 | |
|         }
 | |
|         if (!lazy) {
 | |
|             return true;
 | |
|         }
 | |
|         // Call WriteMap to Refresh Compute
 | |
|         yPtr = y->writeMap<int32_t>();
 | |
|         ::memcpy(yPtr, seq1.data(), seq1.size() * sizeof(int32_t));
 | |
|         if (zAny->readMap<int32_t>()[0] != 1) {
 | |
|             FUNC_PRINT(1);
 | |
|             return false;
 | |
|         }
 | |
|         if (zAll->readMap<int32_t>()[0] != 1) {
 | |
|             FUNC_PRINT(1);
 | |
|             return false;
 | |
|         }
 | |
|         // Call WriteMap to Refresh Compute
 | |
|         yPtr = y->writeMap<int32_t>();
 | |
|         ::memcpy(yPtr, seq2.data(), seq2.size() * sizeof(int32_t));
 | |
|         if (zAny->readMap<int32_t>()[0] != 0) {
 | |
|             FUNC_PRINT(1);
 | |
|             return false;
 | |
|         }
 | |
|         if (zAll->readMap<int32_t>()[0] != 0) {
 | |
|             FUNC_PRINT(1);
 | |
|             return false;
 | |
|         }
 | |
|         return true;
 | |
|     }
 | |
|     virtual bool run(int precision) {
 | |
|         ExecutorScope::Current()->lazyEval = false;
 | |
|         auto res = _run(precision, false);
 | |
|         if (!res) {
 | |
|             FUNC_PRINT(1);
 | |
|             return false;
 | |
|         }
 | |
|         ExecutorScope::Current()->lazyEval = true;
 | |
|         res = _run(precision, true);
 | |
|         return res;
 | |
|     }
 | |
| };
 | |
| MNNTestSuiteRegister(AllAnyTest, "expr/AllAny");
 |