| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  BinaryOPTest.cpp
 | 
					
						
							|  |  |  | //  MNNTests
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/01/15.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include <MNN/expr/Expr.hpp>
 | 
					
						
							|  |  |  | #include <MNN/expr/ExprCreator.hpp>
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #include "MNNTestSuite.h"
 | 
					
						
							|  |  |  | #include "TestUtils.h"
 | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | #include "MNN_generated.h"
 | 
					
						
							|  |  |  | #include "core/TensorUtils.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | using namespace MNN; | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | using namespace MNN::Express; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | using namespace std; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class BinaryTestCommon : public MNNTestCase { | 
					
						
							|  |  |  | protected: | 
					
						
							|  |  |  |     template<typename Tin, typename Tout> | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  |     bool test(VARP (*opFunc)(VARP, VARP), string name, float threshold, | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |               const vector<Tin>& data_x, const vector<Tin>& data_y, const vector<Tout>& data_out, | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  |               const vector<int>& shape_x, const vector<int>& shape_y, const vector<int>& shape_out, const vector<float> quantScales={}, const vector<float> zeroPoints={}) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         int size_x = 1, size_y = 1, size_out = 1; | 
					
						
							|  |  |  |         for (int i = 0; i < shape_x.size(); ++i) { | 
					
						
							|  |  |  |             size_x *= shape_x[i]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = 0; i < shape_y.size(); ++i) { | 
					
						
							|  |  |  |             size_y *= shape_y[i]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (int i = 0; i < shape_y.size(); ++i) { | 
					
						
							|  |  |  |             size_out *= shape_out[i]; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         auto input_x = _Input(shape_x, NCHW, halide_type_of<Tin>()); | 
					
						
							|  |  |  |         auto input_y = _Input(shape_y, NCHW, halide_type_of<Tin>()); | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |         input_x->setName("input_x"); | 
					
						
							|  |  |  |         input_y->setName("input_y"); | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  |         if (quantScales.size() > 1) { | 
					
						
							|  |  |  |         input_x->writeScaleMap(quantScales[0], zeroPoints[0]); | 
					
						
							|  |  |  |         input_y->writeScaleMap(quantScales[1], zeroPoints[1]); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         // set input data
 | 
					
						
							|  |  |  |         auto ptr_x = input_x->template writeMap<Tin>(); | 
					
						
							|  |  |  |         auto ptr_y = input_y->template writeMap<Tin>(); | 
					
						
							|  |  |  |         memcpy(ptr_x, data_x.data(), size_x * sizeof(Tin)); | 
					
						
							|  |  |  |         memcpy(ptr_y, data_y.data(), size_y * sizeof(Tin)); | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         input_x->unMap(); | 
					
						
							|  |  |  |         input_y->unMap(); | 
					
						
							|  |  |  |         auto output = opFunc(input_x, input_y); | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  |         if (quantScales.size() > 0){ | 
					
						
							|  |  |  |             output->writeScaleMap(quantScales[2], zeroPoints[2]); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         auto gotOutput = output->template readMap<Tout>(); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         auto shape_got = output->getInfo()->dim; | 
					
						
							|  |  |  |         if (shape_got.size() != shape_out.size()) { | 
					
						
							|  |  |  |             MNN_ERROR("%s shape compute error!\n", name.c_str()); | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         for (int i = 0; i < shape_got.size(); i++) { | 
					
						
							|  |  |  |             if (shape_got[i] != shape_out[i]) { | 
					
						
							|  |  |  |                 MNN_ERROR("%s shape compute error!\n", name.c_str()); | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  |         if (quantScales.size() > 0) { | 
					
						
							|  |  |  |             for (int i = 0; i < size_out; ++i) { | 
					
						
							|  |  |  |                 auto error = (int32_t)data_out[i] - (int32_t)gotOutput[i]; | 
					
						
							|  |  |  |                 if (error * error > 1) { | 
					
						
							|  |  |  |                     MNN_PRINT("%s Test error: compute result=%d, right value=%d\n", name.c_str(), (int32_t)gotOutput[i], (int32_t)data_out[i]); | 
					
						
							|  |  |  |                     return false; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             return true; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |         if (!checkVectorByRelativeError<Tout>(gotOutput, data_out.data(), size_out, threshold)) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             MNN_ERROR("%s test failed!\n", name.c_str()); | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class AddTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~AddTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_Add, "AddTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0}, {1.0, 2.0, 3.0, 4.0}, {0.0, 0.0, 0.0, 0.0}, | 
					
						
							|  |  |  |                     {4}, {4}, {4}); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class AddInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  |     public:  | 
					
						
							|  |  |  |         virtual ~AddInt8Test() = default; | 
					
						
							|  |  |  |         virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp2 = {1.1, 2.2, 3.3, 4.6}, inp1 = {2}; | 
					
						
							|  |  |  |             vector<float> rightResult = {3.1, 4.2, 5.3, 6.6}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return test<float, float>(_Add, "AddInt8Test", 0.01, inp1, inp2, rightResult, {1}, {4}, {4}, {0.4, 0.4, 0.4}, | 
					
						
							|  |  |  |                                   {0., 0., 0.}); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class SubtractTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~SubtractTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_Subtract, "SubtractTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0}, {1.0, 2.0, 3.0, 4.0}, {-2.0, -4.0, -6.0, -8.0}, | 
					
						
							|  |  |  |                     {4}, {4}, {4}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class SubtractInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  |         virtual ~SubtractInt8Test() = default; | 
					
						
							|  |  |  |         virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6,1.1, 2.2, 3.3, 4.6,1.1, 2.2, 3.3, 4.6}, inp2 = {5.7}; | 
					
						
							|  |  |  |         vector<float> rightResult = {-4.6, -3.5, -2.4, -1.1, -4.6, -3.5, -2.4, -1.1, -4.6, -3.5, -2.4, | 
					
						
							|  |  |  |                                     -1.1, -4.6, -3.5, -2.4, -1.1}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return test<float, float>(_Subtract, "SubtractInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {4, 4}, {1}, {4, 4}, {0.4, 0.4, 0.4}, {0., 0., 0.}); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class MultiplyTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~MultiplyTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_Multiply, "MultiplyTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0}, {1.0, 2.0, 3.0, 4.0}, {-1.0, -4.0, -9.0, -16.0}, | 
					
						
							|  |  |  |                     {4}, {4}, {4}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class MultiplyInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~MultiplyInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {1.1, 2.2, 3.3, 4.6}, inp2 = {5.7, 2.5, 0.25, 0.43}; | 
					
						
							|  |  |  |         vector<float> rightResult = {6.27 , 5.5  , 0.825, 1.978}; | 
					
						
							|  |  |  |         return test<float, float>(_Multiply, "MultiplyInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {4}, {4}, {4}, {0.4, 0.4, 0.16}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class DivideTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~DivideTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_Divide, "DivideTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0}, {2.0, 4.0, 6.0, 8.0}, {-0.5, -0.5, -0.5, -0.5}, | 
					
						
							|  |  |  |                     {4}, {4}, {4}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class DivideInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~DivideInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {1.1, 2.2, 3.3, 4.6}, inp2 = {5.7, 2.5, 2.6, 1.88}; | 
					
						
							|  |  |  |         vector<float> rightResult = {0.19298,  0.88, 1.269, 2.4468}; | 
					
						
							|  |  |  |         return test<float, float>(_Divide, "DivideInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {4}, {4}, {4}, {0.4, 0.4, 1.0}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class PowTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~PowTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |         float errorScale = precision <= MNN::BackendConfig::Precision_High ? 1 : 10; | 
					
						
							|  |  |  |         return test<float, float>(_Pow, "PowTest", 0.01 * errorScale, | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |                     {-1.0, -2.0, -3.0, -4.0}, {2.0, 4.0, 6.0, 4.0}, {1.0, 16.0, 729.0, 256.0}, | 
					
						
							|  |  |  |                     {4}, {4}, {4}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class PowInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~PowInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {-1.0, -2.0, -3.0, -4.0}, inp2 = {2.0, 4.0, 2, 4.0}; | 
					
						
							|  |  |  |         vector<float> rightResult = {1, 16, 8, 0}; | 
					
						
							|  |  |  |         return test<float, float>(_Pow, "PowInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {4}, {4}, {4}, {1.0, 1.0, 1.0}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class MinimumTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~MinimumTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_Minimum, "MinimumTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0}, {1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}, | 
					
						
							|  |  |  |                     {4}, {4}, {4}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class MinimumInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~MinimumInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {-1.2, -5.0, 8, 10}, inp2 = {9.3, 3.1, 11.0, 2.9}; | 
					
						
							|  |  |  |         vector<float> rightResult = {-1.2, -5.0, 8, 2.9}; | 
					
						
							|  |  |  |         return test<float, float>(_Minimum, "MinimumInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {4}, {4}, {4}, {0.4, 0.4, 0.4}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class MaximumTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~MaximumTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-28 18:02:10 +08:00
										 |  |  |         return test<float, float>(MNN::Express::_Maximum, "MaximumTest", 0.01, | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |                     {-1.0, -2.0, -3.0, -4.0}, {2.0, 4.0, 6.0, 8.0}, {2.0, 4.0, 6.0, 8.0}, | 
					
						
							|  |  |  |                     {4}, {4}, {4}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class MaximumInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~MaximumInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {-1, -5, 8, 10}, inp2 = {9}; | 
					
						
							|  |  |  |         vector<float> rightResult = {9, 9, 9, 10}; | 
					
						
							|  |  |  |         return test<float, float>(_Maximum, "MaximumInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {4}, {1}, {4}, {0.4, 0.4, 0.4}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class BiasAddTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~BiasAddTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_BiasAdd, "BiasAddTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0}, | 
					
						
							|  |  |  |                     {1.0, 2.0}, | 
					
						
							|  |  |  |                     {0.0, 0.0, -2.0, -2.0, -4.0, -4.0, -6.0, -6.0}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class GreaterTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~GreaterTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, int>(_Greater, "GreaterTest", 0, | 
					
						
							|  |  |  |                     {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, | 
					
						
							|  |  |  |                     {3.0, 4.0}, | 
					
						
							|  |  |  |                     {0, 0, 0, 0, 1, 1, 1, 1}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class GreaterEqualTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~GreaterEqualTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, int>(_GreaterEqual, "GreaterEqualTest", 0, | 
					
						
							|  |  |  |                     {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, | 
					
						
							|  |  |  |                     {3.0, 4.0}, | 
					
						
							|  |  |  |                     {0, 0, 1, 1, 1, 1, 1, 1}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class LessTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~LessTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, int>(_Less, "LessTest", 0, | 
					
						
							|  |  |  |                     {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, | 
					
						
							|  |  |  |                     {3.0, 4.0}, | 
					
						
							|  |  |  |                     {1, 1, 0, 0, 0, 0, 0, 0}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class FloorDivTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~FloorDivTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_FloorDiv, "FloorDivTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.1}, | 
					
						
							|  |  |  |                     {3.0, 4.0}, | 
					
						
							|  |  |  |                     {-1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 2.0, 2.0}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class FloorDivInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~FloorDivInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {-3.98, 17.5, 25.4, 6.7}, inp2 = {3}; | 
					
						
							|  |  |  |         vector<float> rightResult = {-2, 5, 8, 2}; | 
					
						
							|  |  |  |         return test<float, float>(_FloorDiv, "FloorDivInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {4}, {1}, {4}, {0.4, 0.4, 1}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2022-12-04 15:17:36 +08:00
										 |  |  | class ModTestInt : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2022-12-04 15:17:36 +08:00
										 |  |  |     virtual ~ModTestInt() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         std::vector<int> x = { | 
					
						
							|  |  |  |             -4, 7, 5, 4, -7, 8 | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         std::vector<int> y = { | 
					
						
							|  |  |  |             2, -3, 8, -2, 3, 5 | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         std::vector<int> z = { | 
					
						
							|  |  |  |             0, -2,  5,  0,  2,  3 | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         return test<int, int>(_Mod, "ModTestFloat", 0, | 
					
						
							|  |  |  |                               x,y,z, {6}, {6}, {6}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | class ModTestFloat : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~ModTestFloat() = default; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         std::vector<float> x = { | 
					
						
							|  |  |  |             1.1f, 2.3f, 3.5f, 4.7f, 5.9f, 6.2f, 7.4f, 8.6f | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         std::vector<float> y = { | 
					
						
							|  |  |  |             0.4f, 0.6f | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         std::vector<float> z(x.size()); | 
					
						
							|  |  |  |         for (int i=0; i<2; ++i) { | 
					
						
							|  |  |  |             for (int j=0; j<4; ++j) { | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |                 z[i + j * 2] = FP32Converter[precision](fmodf(FP32Converter[precision](x[i+j*2]), FP32Converter[precision](y[i]))); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2022-12-04 15:17:36 +08:00
										 |  |  |         return test<float, float>(_Mod, "ModTestFloat", 0, | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |                     x,y,z, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class SquaredDifferenceTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~SquaredDifferenceTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_SquaredDifference, "SquaredDifferenceTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.001}, | 
					
						
							|  |  |  |                     {3.0, 4.0}, | 
					
						
							|  |  |  |                     {16.0, 36.0, 36.0, 64.0, 4.0, 4.0, 16.0, 16.0}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class SquaredDifferenceInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~SquaredDifferenceInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         vector<float> inp1 = {-1, -2, -3, -4, 5, 6, 7, 8, -1, -2, -3, -4, 5, 6, 7, 8, -1, -2, -3, -4, 5, 6, 7, 8, -1, -2, -3, -4, 5, 6, 7, 8}, inp2 = {3}; | 
					
						
							|  |  |  |         vector<float> rightResult = {16, 25, 36, 49, 4, 9, 16, 25, 16, 25, 36, 49, 4, 9, 16, 25, 16, 25, 36, 49, 4, 9, 16, 25, 16, 25, 36, 49, 4, 9, 16, 25}; | 
					
						
							|  |  |  |         return test<float, float>(_SquaredDifference, "SquaredDifferenceInt8Test", 0.01, inp1, inp2, rightResult, | 
					
						
							|  |  |  |                                   {8, 4}, {1}, {8, 4}, {1, 1, 1}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class EqualTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~EqualTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, int>(_Equal, "EqualTest", 0, | 
					
						
							|  |  |  |                     {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, | 
					
						
							|  |  |  |                     {3.0, 4.0}, | 
					
						
							|  |  |  |                     {0, 0, 1, 1, 0, 0, 0, 0}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class LessEqualTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~LessEqualTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, int>(_LessEqual, "LessEqualTest", 0, | 
					
						
							|  |  |  |                     {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, | 
					
						
							|  |  |  |                     {3.0, 4.0}, | 
					
						
							|  |  |  |                     {1, 1, 1, 1, 0, 0, 0, 0}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class FloorModTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~FloorModTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_FloorMod, "FloorModTest", 0.01, | 
					
						
							|  |  |  |                     {-1.0f, -2.0f, -3.0f, -4.0f, 5.0f, 6.0f, 7.0f, 8.1f}, | 
					
						
							|  |  |  |                     {3.0f, 4.0f}, | 
					
						
							|  |  |  |                     {2.0f, 2.0f, 0.0f, 0.0f, 2.0f, 2.0f, 1.0f, 0.1f}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class FloorModInt8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~FloorModInt8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         return test<float, float>(_FloorMod, "FloorModInt8Test", 0.01, | 
					
						
							|  |  |  |                     {-1, -3, 5, 7}, | 
					
						
							|  |  |  |                     {3.0f}, {2, 0, 2, 1}, | 
					
						
							|  |  |  |                                   {4}, {1}, {4}, {0.3, 0.3, 0.3}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class Atan2Test : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2020-02-27 23:01:24 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~Atan2Test() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_Atan2, "Atan2Test", 0.01, | 
					
						
							|  |  |  |                     {-1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.0}, | 
					
						
							| 
									
										
										
										
											2023-03-20 11:32:29 +08:00
										 |  |  |                     {3.0, -4.0}, | 
					
						
							|  |  |  |                     {-0.32175055, -2.67794504, -0.7853982, -2.35619449, 1.0303768, 2.15879893, 1.1659045, 2.03444394}, | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2020-02-27 23:01:24 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class Atan2Int8Test : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~Atan2Int8Test() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         return test<float, float>(_Atan2, "Atan2Int8Test", 0.01, | 
					
						
							|  |  |  |                     {-1, -3, 5, 7}, | 
					
						
							|  |  |  |                     {3}, {-1, 0, 2, 1}, | 
					
						
							|  |  |  |                                   {4}, {1}, {4}, {1, 1, 1}, {0., 0., 0.}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class LogicalOrTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2020-02-27 23:01:24 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~LogicalOrTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<int, int>(_LogicalOr, "LogicalOrTest", 0, | 
					
						
							|  |  |  |                     {true, false, true, false, false, true, true, false}, | 
					
						
							|  |  |  |                     {true, false}, | 
					
						
							|  |  |  |                     {true, false, true, false, true, true, true, false}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							| 
									
										
										
										
											2020-02-27 23:01:24 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class NotEqualTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2020-02-27 23:01:24 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~NotEqualTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<int, int>(_NotEqual, "NotEqualTest", 0, | 
					
						
							|  |  |  |                     {true, false, true, false, false, true, true, false}, | 
					
						
							|  |  |  |                     {true, false}, | 
					
						
							|  |  |  |                     {false, false, false, false, true, true, false, false}, | 
					
						
							|  |  |  |                     {4, 2}, {2}, {4, 2}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  | class BitwiseAndTest : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~BitwiseAndTest() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         return test<int, int>(_BitwiseAnd, "BitwiseAndTest", 0, | 
					
						
							|  |  |  |                     {1, 2, 3, 4, 5, 6, 7, 8}, | 
					
						
							|  |  |  |                     {8, 7, 6, 5, 4, 3, 2, 1}, | 
					
						
							|  |  |  |                     {0, 2, 2, 4, 4, 2, 2, 0}, | 
					
						
							|  |  |  |                     {8}, {8}, {8}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | class BitwiseOrTest : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~BitwiseOrTest() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         return test<int, int>(_BitwiseOr, "BitwiseOrTest", 0, | 
					
						
							|  |  |  |                     {1, 2, 3, 4, 5, 6, 7, 8}, | 
					
						
							|  |  |  |                     {8, 7, 6, 5, 4, 3, 2, 1}, | 
					
						
							|  |  |  |                     {9, 7, 7, 5, 5, 7, 7, 9}, | 
					
						
							|  |  |  |                     {8}, {8}, {8}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | class BitwiseXorTest : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~BitwiseXorTest() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         return test<int, int>(_BitwiseXor, "BitwiseXorTest", 0, | 
					
						
							|  |  |  |                     {1, 2, 3, 4, 5, 6, 7, 8}, | 
					
						
							|  |  |  |                     {8, 7, 6, 5, 4, 3, 2, 1}, | 
					
						
							|  |  |  |                     {9, 5, 5, 1, 1, 5, 5, 9}, | 
					
						
							|  |  |  |                     {8}, {8}, {8}); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | class BinaryReluTest : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~BinaryReluTest() = default; | 
					
						
							|  |  |  |     virtual bool run(int precision) { | 
					
						
							|  |  |  |         std::vector<float> input0_data = { | 
					
						
							|  |  |  |             1.0, 2.0, 3.0, | 
					
						
							|  |  |  |             4.0, 5.0, 6.0, | 
					
						
							|  |  |  |             7.0, 8.0, 9.0 | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         std::vector<float> input1_data = { | 
					
						
							|  |  |  |             -2.0, 2.0, -4.0, | 
					
						
							|  |  |  |             4.0, 5.0, -8.0, | 
					
						
							|  |  |  |             7.0, -18.0, 9.0 | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         std::vector<float> output_data = { | 
					
						
							|  |  |  |             0.0, 4.0, 0.0, | 
					
						
							|  |  |  |             8.0, 10.0, 0.0, | 
					
						
							|  |  |  |             14.0, 0.0, 18.0 | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         auto input_0 = _Input({1, 1, 3, 3}, NCHW, halide_type_of<float>()); | 
					
						
							|  |  |  |         auto input_1 = _Input({1, 1, 3, 3}, NCHW, halide_type_of<float>()); | 
					
						
							|  |  |  |         ::memcpy(input_0->writeMap<float>(), input0_data.data(), input0_data.size() * sizeof(float)); | 
					
						
							|  |  |  |         ::memcpy(input_1->writeMap<float>(), input1_data.data(), input1_data.size() * sizeof(float)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         std::unique_ptr<OpT> binaryOp(new OpT); | 
					
						
							|  |  |  |         binaryOp->type = OpType_BinaryOp; | 
					
						
							|  |  |  |         binaryOp->main.type = OpParameter_BinaryOp; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         binaryOp->main.value = new BinaryOpT; | 
					
						
							|  |  |  |         binaryOp->main.AsBinaryOp()->opType = BinaryOpOperation_ADD; | 
					
						
							|  |  |  |         binaryOp->main.AsBinaryOp()->activationType = 1;// Do Relu
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto output = Variable::create(Expr::create(binaryOp.get(), {input_0, input_1}, 1)); | 
					
						
							|  |  |  |         auto getOutput = output->readMap<float>(); | 
					
						
							|  |  |  |         if (!checkVectorByRelativeError<float>(getOutput, output_data.data(), output_data.size(), 0.001)) { | 
					
						
							|  |  |  |             MNN_ERROR("Binary-Relu fuse test failed!\n"); | 
					
						
							|  |  |  |             return false; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class BinaryBroadcastShapeTest : public BinaryTestCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual ~BinaryBroadcastShapeTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         vector<int> data_x(8, 1), data_y(8, 1), data_out(64, 2); | 
					
						
							|  |  |  |         vector<int> shape_x = {4, 1, 2, 1}, shape_y = {2, 1, 4}, shape_out = {4, 2, 2, 4}; | 
					
						
							|  |  |  |         return test<int, int>(_Add, "BinaryBroadcastShapeTest", 0, | 
					
						
							|  |  |  |                               data_x, data_y, data_out, shape_x, shape_y, shape_out); | 
					
						
							| 
									
										
										
										
											2020-02-27 23:01:24 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2020-05-28 15:10:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  | class SubtractBroastTest : public BinaryTestCommon { | 
					
						
							| 
									
										
										
										
											2020-05-28 15:10:53 +08:00
										 |  |  | public: | 
					
						
							|  |  |  |     virtual ~SubtractBroastTest() = default; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     virtual bool run(int precision) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         vector<float> data_x(560), data_y(20 * 560), data_out(20 * 560); | 
					
						
							|  |  |  |         vector<int> shape_x = {560}, shape_y = {1, 20, 560}, shape_out = {1, 20, 560}; | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |         auto func = FP32Converter[precision]; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         for (int i = 0; i < 560; ++i) { | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |             data_x[i]  = func(i / 1000.0f); | 
					
						
							| 
									
										
										
										
											2020-05-28 15:10:53 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         for (int i = 0; i < 560 * 20; ++i) { | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |             data_y[i]  = func(i / 1000.0f); | 
					
						
							| 
									
										
										
										
											2020-05-28 15:10:53 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         for (int i = 0; i < 20; ++i) { | 
					
						
							|  |  |  |             for (int j = 0; j < 560; ++j) { | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |                 data_out[j + i * 560] = func(data_x[j] - data_y[j + i * 560]); | 
					
						
							| 
									
										
										
										
											2020-05-28 15:10:53 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         return test<float, float>(_Subtract, "SubtractBroastTest", 0.01, | 
					
						
							|  |  |  |                                   data_x, data_y, data_out, shape_x, shape_y, shape_out); | 
					
						
							| 
									
										
										
										
											2020-05-28 15:10:53 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | // Float32 OpTest.
 | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  | MNNTestSuiteRegister(BinaryBroadcastShapeTest, "op/binary/broadcastShapeTest"); | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | MNNTestSuiteRegister(AddTest, "op/binary/add"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(SubtractTest, "op/binary/subtract"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(MultiplyTest, "op/binary/multiply"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(DivideTest, "op/binary/divide"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(PowTest, "op/binary/pow"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(MinimumTest, "op/binary/minimum"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(MaximumTest, "op/binary/maximum"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(BiasAddTest, "op/binary/biasadd"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(GreaterTest, "op/binary/greater"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(GreaterEqualTest, "op/binary/greaterequal"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(LessTest, "op/binary/less"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(FloorDivTest, "op/binary/floordiv"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(SquaredDifferenceTest, "op/binary/squareddifference"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(EqualTest, "op/binary/equal"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(LessEqualTest, "op/binary/lessequal"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(FloorModTest, "op/binary/floormod"); | 
					
						
							| 
									
										
										
										
											2022-12-04 15:17:36 +08:00
										 |  |  | MNNTestSuiteRegister(ModTestFloat, "op/binary/mod_float"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(ModTestInt, "op/binary/mod_int"); | 
					
						
							| 
									
										
										
										
											2020-02-27 23:01:24 +08:00
										 |  |  | MNNTestSuiteRegister(Atan2Test, "op/binary/atan2"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(LogicalOrTest, "op/binary/logicalor"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(NotEqualTest, "op/binary/notqual"); | 
					
						
							| 
									
										
										
										
											2020-05-28 15:10:53 +08:00
										 |  |  | MNNTestSuiteRegister(SubtractBroastTest, "op/binary/subtractBroastTest"); | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  | MNNTestSuiteRegister(BitwiseAndTest, "op/binary/bitwise_and"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(BitwiseOrTest, "op/binary/bitwise_or"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(BitwiseXorTest, "op/binary/bitwise_xor"); | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | MNNTestSuiteRegister(BinaryReluTest, "op/binary/fuse_relu"); | 
					
						
							|  |  |  | // Int8 OpTest.
 | 
					
						
							|  |  |  | MNNTestSuiteRegister(AddInt8Test, "op/binary/addInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(SubtractInt8Test, "op/binary/subtractInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(MultiplyInt8Test, "op/binary/multiplyInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(DivideInt8Test, "op/binary/divideInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(PowInt8Test, "op/binary/powInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(MinimumInt8Test, "op/binary/minimumInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(MaximumInt8Test, "op/binary/maximumInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(FloorDivInt8Test, "op/binary/floordivInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(FloorModInt8Test, "op/binary/floormodInt8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(Atan2Int8Test, "op/binary/atan2Int8"); | 
					
						
							|  |  |  | MNNTestSuiteRegister(SquaredDifferenceInt8Test, "op/binary/sqdInt8"); |