mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			99 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			99 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  MatrixBandTest.cpp
 | |
| //  MNNTests
 | |
| //
 | |
| //  Created by MNN on 2019/09/17.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| /*
 | |
|  Test Case From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/matrix-band-part
 | |
|  */
 | |
| #include <MNN/expr/ExprCreator.hpp>
 | |
| #include "MNNTestSuite.h"
 | |
| #include "MNN_generated.h"
 | |
| using namespace MNN::Express;
 | |
| 
 | |
| class MatrixBandTest : public MNNTestCase {
 | |
| public:
 | |
|     virtual bool run(int precision) {
 | |
|         std::unique_ptr<MNN::OpT> MatrixBandOp(new MNN::OpT);
 | |
|         MatrixBandOp->type        = MNN::OpType_MatrixBandPart;
 | |
|         auto matrix               = _Input({4, 4}, NHWC, halide_type_of<float>());
 | |
|         auto lower                = _Input({}, NHWC, halide_type_of<int32_t>());
 | |
|         auto upper                = _Input({}, NHWC, halide_type_of<int32_t>());
 | |
|         auto y                    = Variable::create(Expr::create(MatrixBandOp.get(), {matrix, lower, upper}));
 | |
|         std::vector<float> values = {0.0f,  1.0f,  2.0f, 3.0f, -1.0f, 0.0f,  1.0f,  2.0f,
 | |
|                                      -2.0f, -1.0f, 0.0f, 1.0f, -3.0f, -2.0f, -1.0f, 0.0f};
 | |
|         ::memcpy(matrix->writeMap<float>(), values.data(), values.size() * sizeof(float));
 | |
|         {
 | |
|             lower->writeMap<int>()[0] = 1;
 | |
|             upper->writeMap<int>()[0] = -1;
 | |
|             {
 | |
|                 auto yPtr = y->readMap<float>();
 | |
|                 for (int h = 0; h < 4; ++h) {
 | |
|                     for (int w = 0; w < 4; ++w) {
 | |
|                         auto computed = yPtr[4 * h + w];
 | |
|                         auto expected = 0.0f;
 | |
|                         if (h - w <= 1) {
 | |
|                             expected = values[4 * h + w];
 | |
|                         }
 | |
|                         if (computed != expected) {
 | |
|                             FUNC_PRINT(1);
 | |
|                             return false;
 | |
|                         }
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|         {
 | |
|             lower->writeMap<int>()[0] = 2;
 | |
|             upper->writeMap<int>()[0] = 1;
 | |
|             {
 | |
|                 auto yPtr = y->readMap<float>();
 | |
|                 for (int h = 0; h < 4; ++h) {
 | |
|                     for (int w = 0; w < 4; ++w) {
 | |
|                         auto computed = yPtr[4 * h + w];
 | |
|                         auto expected = 0.0f;
 | |
|                         if ((h - w) <= 2 && (w - h) <= 1) {
 | |
|                             expected = values[4 * h + w];
 | |
|                         }
 | |
|                         if (computed != expected) {
 | |
|                             FUNC_PRINT(1);
 | |
|                             return false;
 | |
|                         }
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|         {
 | |
|             matrix->resize({3, 5, 5});
 | |
|             auto matrixPtr = matrix->writeMap<float>();
 | |
|             for (int i = 0; i < matrix->getInfo()->size; ++i) {
 | |
|                 matrixPtr[i] = (float)i;
 | |
|             }
 | |
|             lower->writeMap<int>()[0] = 2;
 | |
|             upper->writeMap<int>()[0] = 1;
 | |
|             auto yPtr                 = y->readMap<float>();
 | |
|             for (int z = 0; z < 3; ++z) {
 | |
|                 for (int h = 0; h < 5; ++h) {
 | |
|                     for (int w = 0; w < 5; ++w) {
 | |
|                         auto index    = w + 5 * h + 5 * 5 * z;
 | |
|                         auto computed = yPtr[index];
 | |
|                         auto expected = 0.0f;
 | |
|                         if ((h - w) <= 2 && (w - h) <= 1) {
 | |
|                             expected = (float)(index);
 | |
|                         }
 | |
|                         if (computed != expected) {
 | |
|                             FUNC_PRINT(1);
 | |
|                             return false;
 | |
|                         }
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
|         }
 | |
|         return true;
 | |
|     }
 | |
| };
 | |
| MNNTestSuiteRegister(MatrixBandTest, "expr/MatrixBand");
 |