// Mirror of https://github.com/alibaba/MNN.git (594 lines, 47 KiB, C++)
//
// RoiAlignGradTest.cpp
// MNNTests
//
// Created by MNN on 2022/12/14.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <cmath>
#include <cstring>
#include <vector>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include "../tools/train/source/grad/OpGrad.hpp"

using namespace std;
using namespace MNN;
using namespace MNN::Express;

class RoiAlignGradTest : public MNNTestCase {
|
|
public:
|
|
char name[20] = "RoiAlign";
|
|
virtual ~RoiAlignGradTest() = default;
|
|
|
|
virtual bool run(int precision) {
|
|
vector<int> inputShape = {3, 9, 4, 4};
|
|
auto input = _Input(inputShape, NCHW);
|
|
vector<float> inputData = { 1.2587, 0.8945, -0.7543, 2.0012, 1.3856, 0.9340, -0.1376, 0.1837,
|
|
-0.8983, -0.6558, -0.3315, -0.1219, 0.3075, 1.4431, 0.4746, -0.5140,
|
|
-0.1245, -0.3308, 0.0673, 0.7772, 0.0639, -0.6047, 0.1636, -0.3155,
|
|
0.1523, -0.4973, 1.2428, -0.5529, -0.4270, 1.6364, 0.2426, 1.0559,
|
|
0.9193, -0.3176, 0.0892, -0.2509, -2.4380, -0.9005, 0.5781, 1.4008,
|
|
-0.5696, 0.7918, -2.0354, -1.2119, -1.3411, -0.3476, -0.0886, -0.0649,
|
|
-0.7115, -0.5848, -1.0234, -1.5510, 0.0550, -1.4380, 0.7675, -0.5726,
|
|
1.4670, -0.9947, -1.8784, -1.5924, -0.8265, 0.0891, -0.3028, -1.3168,
|
|
-0.4735, -1.4944, 0.3030, -0.1165, 0.2647, -1.3413, -0.9787, 0.2888,
|
|
2.8702, -1.2232, -0.8815, 1.9677, 0.5460, -0.8153, -0.5402, -0.7055,
|
|
1.0696, 0.3802, -0.0949, 0.9391, -1.0831, -1.3940, 1.7714, 1.0268,
|
|
1.0844, -1.3981, 0.4609, -0.7931, -0.3239, 1.2300, 1.4248, -1.2277,
|
|
-1.5531, -0.3628, 0.3534, 1.1957, -0.5323, -0.5895, -1.6513, 0.8463,
|
|
0.1348, 0.7655, 0.1805, -1.6148, 2.8097, -0.4605, -0.6506, 0.0297,
|
|
0.6477, -1.1414, -1.0395, 0.0904, 1.5177, -0.0325, 0.5897, 0.3167,
|
|
0.4292, 0.3140, -0.3639, -0.7091, 0.2055, 0.0503, 0.6292, -0.1367,
|
|
0.7991, -1.3695, -0.7060, -0.0840, 0.3023, 0.4616, -0.7059, -1.6423,
|
|
-1.3314, 1.2474, 0.5421, 1.4275, -0.8528, -0.6006, 0.2814, 0.6976,
|
|
-0.6811, -1.9291, 1.2983, 0.4801, -1.1602, 0.4394, -0.3520, 0.6311,
|
|
-0.3585, 0.1489, 0.9659, 0.5493, -0.9856, -1.1759, 0.9381, 0.5606,
|
|
2.0309, 0.5102, 1.7770, 1.2903, -1.4298, -1.1124, 0.3458, 1.8255,
|
|
0.5936, 1.3503, 0.9923, -0.9042, 0.0124, 1.0796, -0.2233, -1.0319,
|
|
-1.3835, 0.8602, 0.0651, -0.1098, -1.9900, -1.1028, -1.0592, 0.8511,
|
|
-0.2102, -0.7675, -0.6877, -0.4493, -0.5632, 0.3369, 0.5917, -0.6685,
|
|
-1.1458, -1.9596, 1.8387, -0.2642, 1.4898, -0.7788, 0.7117, -1.2234,
|
|
0.3939, -0.2793, -0.4268, -1.1598, 0.4164, 0.4359, 0.5211, 0.5965,
|
|
2.0014, 0.7337, 0.3770, -0.8599, 0.7286, -0.9268, 0.1724, -1.1386,
|
|
-0.1429, 0.7072, 0.5999, -0.2979, -0.8230, -0.8795, -0.5317, 0.0974,
|
|
1.0004, 0.6322, -1.9103, 0.8706, -0.2598, -1.2323, -0.3205, 1.3420,
|
|
0.3936, 2.0456, 1.7977, -1.1196, -0.5652, -1.3567, 0.9958, -2.0845,
|
|
-1.3749, -0.7130, 1.0244, 0.0593, 0.0636, 0.2393, -1.3413, -0.3329,
|
|
0.2147, -0.5064, -0.5119, -0.9965, -0.4002, -0.6242, -0.3976, -0.6084,
|
|
1.2526, 0.7067, -0.2353, -0.5699, 0.0824, 1.0667, -0.0329, -0.1180,
|
|
-1.6390, 0.7729, 0.7066, -2.2387, -0.7651, -0.4625, 1.9304, -0.1592,
|
|
0.4796, -0.6125, -0.8265, 0.0568, 2.5158, -0.3929, -0.3927, -0.1145,
|
|
0.4040, 0.2954, -0.7797, 0.0569, 0.3714, 0.5620, -0.6556, 0.0075,
|
|
-0.0251, -1.6895, -0.8571, 0.3759, 0.0106, 1.0075, 1.3647, -0.1915,
|
|
-0.1687, -1.9660, 0.7073, -1.0942, -0.1903, 1.2114, 0.6589, 0.7416,
|
|
0.1255, -0.2084, -1.7247, -0.6163, 2.5999, 0.5725, -0.1817, 1.2373,
|
|
1.6475, -0.3679, -0.7700, 0.5559, -0.0299, -0.9032, -1.6034, 0.8630,
|
|
-0.9992, -0.5817, 0.5362, 1.4626, -0.7890, -1.0981, 1.7217, 0.7581,
|
|
1.2861, 0.3955, 0.2466, 0.3384, 0.1506, -0.9613, -1.5495, -0.1552,
|
|
-0.1180, 1.6468, -1.7706, 0.0055, -0.4989, 0.1550, 1.6259, -0.0722,
|
|
0.7386, 0.0657, 0.5618, -0.4135, -1.3406, 0.7209, -0.1369, -0.4943,
|
|
-0.3508, -0.2657, 1.5009, 1.8255, 0.7049, 0.9854, 0.3529, 1.1112,
|
|
-0.9561, 0.7174, -1.1929, -1.2257, 0.1584, 0.7370, -0.7273, 0.8572,
|
|
0.0591, -1.8631, -0.1637, 1.9188, -0.9281, -1.3265, 0.3382, -0.5424,
|
|
1.4783, -0.0339, 0.8036, 0.1805, 1.2170, -2.1388, -0.4797, 1.1232,
|
|
0.4213, -0.2824, -0.0592, -0.3094, -0.9494, 0.0946, 1.3795, 0.4063,
|
|
0.1934, -0.2050, 1.1473, -1.8769, -1.3865, 0.0212, 0.5409, 0.8030,
|
|
-2.4675, 0.6231, 0.1214, -1.3949, -1.0724, -1.5440, -0.4761, 1.4920,
|
|
1.7146, 0.3414, 0.1379, -0.8818, -1.7559, 1.8605, -1.3545, 1.5024,
|
|
-2.1116, -0.4113, -1.4620, -0.2642, 0.3996, -0.0468, -1.4671, 0.4811,
|
|
0.0413, 0.4503, -0.2901, 1.9869, -0.3118, -1.3857, 1.3151, -0.7364};
|
|
auto inputPtr = input->writeMap<float>();
|
|
memcpy(inputPtr, inputData.data(), input->getInfo()->size * sizeof(float));
|
|
|
|
const float spatialScale = 1.0 / 16;
|
|
const int pooledHeight = 2;
|
|
const int pooledWidth = 3;
|
|
|
|
auto roiInput = _Input({2, 5}, NCHW);
|
|
vector<float> roiData = { 2, 1 / spatialScale, 2 / spatialScale, 3 / spatialScale, 3 / spatialScale,
|
|
0, 0 / spatialScale, 2 / spatialScale, 2 / spatialScale, 3 / spatialScale};
|
|
memcpy(roiInput->writeMap<float>(), roiData.data(), roiInput->getInfo()->size * sizeof(float));
|
|
|
|
// avg pool, align false
|
|
auto outputOri = _ROIAlign(_Convert(input, NC4HW4), _Convert(roiInput, NC4HW4), pooledHeight, pooledWidth, spatialScale, 2, false, PoolingMode::AVEPOOL);
|
|
auto output = _Convert(outputOri, NCHW);
|
|
auto outputPtr = output->readMap<float>();
|
|
|
|
vector<float> outputTorch = { -0.5494, 0.4288, -0.1918, 0.5017, 0.6175, 0.4121, -0.6606, -0.8044,
|
|
0.0957, -0.9779, -1.1611, 0.0591, -0.0299, -0.1461, 0.0759, -0.7816,
|
|
-0.9716, -0.3880, 0.2820, 0.3056, -0.1601, 0.3840, 0.0384, -0.3035,
|
|
0.1227, -0.8956, -0.8288, 0.2068, -0.6140, -0.0569, -0.2125, 0.3903,
|
|
0.4384, -1.1280, -0.1928, 0.5387, 0.2329, 0.7168, -0.4727, 0.2073,
|
|
0.5852, 0.3195, 0.4024, -0.1597, -0.2689, 0.6600, -0.6189, 0.2771,
|
|
0.0313, 0.2008, 0.9078, -0.3132, 0.6796, 0.2675, -0.4416, -0.1698,
|
|
-0.1303, 0.3102, 0.7886, 0.4882, 0.0170, 0.1135, 0.6739, 0.1795,
|
|
0.9367, 0.6961, -0.3393, 0.2299, -0.8635, -0.7864, -0.1959, -0.4044,
|
|
0.3545, -0.6524, -1.2309, -0.2294, -0.2307, -0.5251, 1.1524, -0.8099,
|
|
-0.9045, 0.4456, -0.7226, -0.7228, 0.2412, -0.4980, 0.2209, 0.2098,
|
|
0.5785, 0.9802, 0.6887, 0.4472, 0.1348, 1.3760, 0.0132, -0.3466,
|
|
0.3315, 0.2282, 0.0056, 0.2130, 0.1504, 0.2927, -0.5460, 0.5933,
|
|
0.5798, -0.6945, -0.1677, 0.1849};
|
|
|
|
for (int i = 0; i < outputTorch.size(); ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%s avg pool align false, output test failed, expected: %f, but got: %f!\n", name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// MNN_PRINT("\tavg pool align false, output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
auto opExpr = outputOri->expr().first;
|
|
|
|
auto grad = OpGrad::get(opExpr->get()->type());
|
|
vector<float> outputDiff = {0.4281, 0.5642, 1.0810, -1.0327, 0.5425, -0.7935, 0.3765, -0.0269,
|
|
-0.7886, 0.5253, -0.9186, -0.9041, 0.5890, 0.6958, -1.3026, 1.1327,
|
|
0.5041, -0.4669, -0.1207, -1.6695, 0.0044, 1.4214, -1.1095, 1.6068,
|
|
-0.4748, 1.7006, -0.1226, 1.7765, -0.2359, -1.4316, 0.1308, -0.3773,
|
|
0.2040, 2.4996, 0.5790, 1.0668, 1.0195, 0.9374, -1.5003, -0.6189,
|
|
-1.1948, 0.3161, 1.0680, -1.2055, -0.9482, -0.3838, -0.9000, 0.6312,
|
|
0.1845, 0.1908, 0.2646, -0.7158, 1.0276, -1.1181, 0.4016, 0.9310,
|
|
1.1894, 0.0410, 2.1430, 0.0430, 0.1397, 0.2587, -0.2221, -0.1704,
|
|
0.0979, 0.4072, -1.2300, -0.0112, 1.0392, 0.9954, -0.3105, 0.4925,
|
|
0.6148, -0.2098, 1.0337, 1.9675, -0.6049, -1.5234, 1.3445, -0.3587,
|
|
-0.0854, -0.2082, -0.5494, 1.3498, 1.2961, 0.9696, -0.8827, -1.1804,
|
|
0.6077, -1.9303, -1.9210, -0.0585, -0.4386, -0.9297, 1.7299, -0.0892,
|
|
-0.6664, -0.3320, -0.4892, 0.2264, 1.8798, 0.6208, -1.9061, -0.1891,
|
|
-0.5713, -0.4127, 0.0955, -0.5188};
|
|
|
|
auto inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {2, 9, 2, 3}, NCHW), NC4HW4)});
|
|
|
|
vector<float> expectedOutput = { 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.3105, 1.4331, 0.7047, 0.0000, 0.2408, 1.6869, 0.3731, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0597, 0.1812, -0.0250, 0.0000, -0.0504, 0.1674, 0.1781, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
-0.4563, 0.0046, 0.5945, 0.0000, 0.2731, 0.1597, 0.3998, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.6096, 0.1920, 0.2372, 0.0000, 1.0440, -0.1734, -0.6316, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.6037, 0.0713, 0.1484, 0.0000, 0.0782, -0.0278, 0.6189, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.5246, 0.5767, -0.6898, 0.0000, -0.3160, -0.1614, -1.0541, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
-1.0831, -0.3510, -0.2018, 0.0000, -0.6781, 0.6176, -0.0108, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
-0.2771, -0.0342, -0.1227, 0.0000, 0.1127, 1.2212, 0.3394, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
-1.0317, -0.7953, -0.3819, 0.0000, -0.5220, -0.4190, -0.3526, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0885, 0.6907, 0.4548, 0.0000, -0.3993, 0.1258, -0.1709,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.2550, -0.3428, -0.5658, 0.0000, 0.2674, -0.7088, -0.6415,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.5373, 0.4170, -0.6751, 0.0000, 0.7105, 0.5670, -0.4045,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0491, -1.0513, 0.1425, 0.0000, 0.5865, -0.2939, 0.7000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.1601, 0.8931, -0.1985, 0.0000, 0.8298, 0.2433, -0.7155,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.4705, 0.2657, 0.2683, 0.0000, 1.2999, 1.2028, 0.5957,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.4403, 0.1915, -0.6638, 0.0000, -0.1947, -0.6672, -0.1471,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.3759, -0.8904, -0.4630, 0.0000, -0.0953, -0.7418, 0.0762,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0063, 0.2928, -0.0207, 0.0000, -0.2590, 0.2610, -0.4468};
|
|
auto gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();
|
|
|
|
for (int i = 0; i < expectedOutput.size(); ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.001) {
|
|
MNN_ERROR("%s avg pool align false, grad test failed, expected: %f, but got: %f!\n", name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
// MNN_PRINT("%s avg pool align false, grad exact, %f <==> %f\n", name, expectedOutput[i], gotOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
// avg pool, align true
|
|
outputOri = _ROIAlign(_Convert(input, NC4HW4), _Convert(roiInput, NC4HW4), pooledHeight, pooledWidth, spatialScale, 2, true, PoolingMode::AVEPOOL);
|
|
output = _Convert(outputOri, NCHW);
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { -1.0395, -0.1755, 0.5816, -1.0054, -0.2382, 0.4735, 0.2036, -0.3779,
|
|
-0.3981, -0.2134, -0.7400, -0.7098, 0.1463, 0.3187, 0.5867, 0.2140,
|
|
-0.0731, -0.1329, 0.1449, 0.4579, 0.6352, 0.2277, 0.3083, 0.2503,
|
|
0.5635, -0.0110, -0.7790, 0.4890, -0.1771, -1.0146, -0.1514, 0.1651,
|
|
0.5727, -0.2313, -0.0387, 0.4717, -0.1238, 0.5376, 0.7868, -0.1573,
|
|
0.4236, 0.6286, 0.0613, -0.0728, -0.0611, 0.7421, 0.2430, -0.2436,
|
|
0.2935, -0.1292, -0.2185, -0.0151, 0.0513, 0.3103, -0.3273, -0.2928,
|
|
-0.2625, -0.5968, -0.3640, -0.1309, 0.1302, -0.1970, -0.2746, 0.0075,
|
|
0.0218, 0.1956, -1.0367, -0.3340, 0.0769, -0.7625, -0.1278, 0.1643,
|
|
1.1140, 0.0042, -1.1241, 0.8936, 0.0849, -0.8505, 2.2188, 0.4830,
|
|
-1.1949, 2.2891, 0.5840, -1.0670, 0.5425, -0.4273, -1.0328, 0.7323,
|
|
-0.0044, -0.5006, -0.0320, 0.1974, 0.3094, 0.8035, 0.6313, 0.3780,
|
|
0.7013, 0.4643, 0.1686, 0.3733, 0.3107, 0.1875, -0.9230, 0.0640,
|
|
0.9141, -1.2118, -0.2132, 0.7340};
|
|
|
|
for (int i = 0; i < outputTorch.size(); ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%s avg pool align true, output test failed, expected: %f, but got: %f!\n", name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// MNN_PRINT("\tavg pool align true, output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = outputOri->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {2, 9, 2, 3}, NCHW), NC4HW4)});
|
|
|
|
expectedOutput = { 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.1677e-01,
|
|
3.6417e-01, 4.9558e-02, 0.0000e+00, 1.4847e+00, 1.9230e+00,
|
|
1.5405e-01, 0.0000e+00, 2.7812e-01, 2.7683e-01, 1.7917e-03,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
6.7263e-02, -1.3933e-02, -9.2542e-03, 0.0000e+00, 1.1070e-01,
|
|
2.4941e-01, 2.3138e-02, 0.0000e+00, -3.0363e-02, 9.7071e-02,
|
|
1.6967e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, -3.0890e-01, 2.1510e-01, 4.3300e-02, 0.0000e+00,
|
|
-2.9659e-01, 8.3667e-01, 1.9146e-01, 0.0000e+00, 2.1004e-01,
|
|
6.3792e-02, 2.0521e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 1.2747e-01, 1.8913e-01, 4.3071e-02,
|
|
0.0000e+00, 1.6312e+00, -6.1158e-01, -6.1213e-02, 0.0000e+00,
|
|
4.1626e-01, -3.9299e-01, -6.3475e-02, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9129e-01, -6.2629e-02,
|
|
-3.5583e-03, 0.0000e+00, 5.1169e-01, 4.4971e-01, 1.5805e-01,
|
|
0.0000e+00, -1.2072e-01, 2.1253e-01, 5.6242e-02, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.4522e-01,
|
|
-6.2696e-02, -3.6779e-02, 0.0000e+00, 6.7826e-01, -1.1666e+00,
|
|
-3.5163e-01, 0.0000e+00, -2.1914e-01, -3.2618e-01, -8.0429e-02,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
-4.8756e-01, -9.8687e-02, -1.8275e-02, 0.0000e+00, -1.5112e+00,
|
|
2.9690e-01, -6.5975e-02, 0.0000e+00, -1.6188e-02, 1.9765e-01,
|
|
-3.7167e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, -2.0810e-01, -1.4342e-01, -2.0383e-02, 0.0000e+00,
|
|
2.5043e-01, 6.6268e-01, 1.6450e-02, 0.0000e+00, 2.9157e-01,
|
|
3.6431e-01, 2.5867e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, -5.0016e-01, -1.4266e-01, -2.3804e-02,
|
|
0.0000e+00, -1.7742e+00, -7.1641e-01, -1.3626e-01, 0.0000e+00,
|
|
-9.1238e-02, -9.6146e-02, -2.1617e-02, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 1.7837e-02, 1.5971e-01, 2.9573e-01,
|
|
4.5042e-02, -7.5575e-02, 3.7138e-02, 5.9470e-01, 3.5938e-02,
|
|
-4.3029e-02, -1.4733e-01, -9.7500e-02, -3.3063e-02, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 1.5687e-02, 7.5075e-02,
|
|
-1.6765e-01, -3.2858e-02, 1.1272e-01, 2.0906e-01, -1.4125e+00,
|
|
-2.1159e-01, 2.1887e-02, -5.3875e-03, -3.0318e-01, -3.7671e-02,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.4542e-02,
|
|
2.0968e-01, -1.8440e-01, -5.4275e-02, 2.1521e-01, 1.5260e+00,
|
|
-6.5597e-01, -2.2119e-01, 4.7196e-02, 2.9899e-01, -3.4258e-02,
|
|
-1.9454e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
-5.0292e-03, -2.3383e-01, -2.0777e-01, 1.8333e-04, 1.6259e-01,
|
|
-2.2919e-01, -3.5125e-02, 2.0140e-01, 5.9225e-02, 1.5744e-01,
|
|
1.9606e-01, 6.6950e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, -1.9783e-02, 1.1366e-01, 1.8703e-01, -5.1083e-03,
|
|
1.6271e-01, 1.3628e+00, -4.2211e-01, -1.9428e-01, 7.4021e-02,
|
|
3.4062e-01, -3.2774e-01, -5.9650e-02, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 5.4500e-03, -1.9912e-02, -4.6625e-03,
|
|
8.5000e-03, 3.2880e-01, 1.7196e+00, 8.6989e-01, 1.5885e-01,
|
|
1.0415e-01, 5.9312e-01, 2.9462e-01, 4.4450e-02, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 4.2479e-02, 3.2957e-01,
|
|
-1.9539e-01, -6.2513e-02, 5.0075e-02, 1.5385e-01, -8.3665e-01,
|
|
-1.4803e-01, -2.5787e-02, -2.7829e-01, -8.3496e-02, 1.3171e-02,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.4500e-02,
|
|
7.1813e-02, -3.4823e-01, -3.9508e-02, 8.5525e-02, -3.6194e-01,
|
|
-9.8769e-01, -3.9625e-02, -1.5992e-02, -1.9246e-01, 1.9000e-02,
|
|
2.6300e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
7.6875e-03, 6.2287e-02, 7.8975e-02, 1.1025e-02, -6.6412e-02,
|
|
1.2484e-01, -7.6537e-02, -1.0669e-01, -2.9825e-02, -2.0675e-02,
|
|
-1.0449e-01, -4.6588e-02};
|
|
gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();
|
|
|
|
for (int i = 0; i < expectedOutput.size(); ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.001) {
|
|
MNN_ERROR("%s avg pool align true, grad test failed, expected: %f, but got: %f!\n", name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
// MNN_PRINT("%s avg pool align true, grad exact, %f <==> %f\n", name, expectedOutput[i], gotOutput[i]);
|
|
}
|
|
}
|
|
|
|
#ifdef TEST_ROI_ALIGN_MAX_POOL
|
|
// max pool, align false
|
|
outputOri = _ROIAlign(_Convert(input, NC4HW4), _Convert(roiInput, NC4HW4), pooledHeight, pooledWidth, spatialScale, 2, false, PoolingMode::MAXPOOL);
|
|
output = _Convert(outputOri, NCHW);
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { -0.0427, 0.5067, 0.1417, 0.7893, 0.6898, 0.5885, -0.5080, -0.6294,
|
|
0.3788, -0.8005, -0.9511, 0.4373, 0.1919, 0.0645, 0.2343, -0.5231,
|
|
-0.7250, -0.1211, 0.3110, 0.4200, 0.0254, 0.5242, 0.1835, -0.1694,
|
|
0.4342, -0.7278, -0.5402, 0.4812, -0.5322, 0.3664, 0.1731, 0.5857,
|
|
0.4987, -0.6740, 0.1296, 0.7846, 0.4474, 0.8635, 0.0240, 0.3523,
|
|
0.6296, 0.5424, 0.6889, 0.0480, -0.0601, 1.1975, -0.4472, 0.8090,
|
|
0.1691, 0.4208, 1.3572, 0.0080, 0.8627, 0.4990, -0.1574, 0.1045,
|
|
0.0510, 0.6688, 1.0463, 0.7773, 0.1189, 0.3970, 0.8930, 0.5076,
|
|
1.2027, 0.8687, -0.0083, 0.4301, -0.4704, -0.4861, -0.0700, -0.2686,
|
|
0.8404, -0.3891, -0.9379, -0.0249, -0.1219, -0.2731, 1.9544, -0.5469,
|
|
-0.8063, 1.0199, -0.5706, -0.6301, 0.5787, -0.2067, 0.6165, 0.3768,
|
|
0.9686, 1.2372, 0.9992, 0.5884, 0.3444, 2.0116, 0.3005, -0.1698,
|
|
0.3812, 0.3011, 0.1118, 0.2660, 0.1726, 0.4348, -0.1276, 0.9319,
|
|
0.7630, -0.4699, 0.1402, 0.3314};
|
|
|
|
for (int i = 0; i < outputTorch.size(); ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%s max pool align false, output test failed, expected: %f, but got: %f!\n", name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// MNN_PRINT("\tmax pool align false, output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = outputOri->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {2, 9, 2, 3}, NCHW), NC4HW4)});
|
|
|
|
expectedOutput = { 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2806e-01, 1.2106e+00,
|
|
5.1600e-01, 0.0000e+00, 9.3238e-02, 2.1886e+00, 6.1253e-01,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.3006e-02,
|
|
1.7100e-01, -1.0751e-01, 0.0000e+00, -4.8356e-02, 2.7740e-01,
|
|
1.8546e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
-3.5312e-01, -1.0115e-01, 3.3612e-01, 0.0000e+00, 2.3395e-01,
|
|
4.8162e-01, 3.7797e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 1.0413e+00, 2.6816e-01, 2.1522e-01, 0.0000e+00,
|
|
1.0757e+00, -7.6155e-01, -5.6088e-01, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 8.2865e-01, -2.3095e-01, 9.6125e-02,
|
|
0.0000e+00, -3.3087e-02, -1.2568e-01, 9.5754e-01, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 8.7130e-01, 5.5138e-01,
|
|
-5.4715e-01, 0.0000e+00, -3.8141e-01, -7.9979e-02, -1.5341e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -9.9778e-01,
|
|
6.9852e-02, -2.0861e-01, 0.0000e+00, -1.0992e+00, 5.8396e-01,
|
|
-5.5287e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
-3.4610e-01, 1.2238e-01, -8.8208e-02, 0.0000e+00, 2.3740e-01,
|
|
9.5299e-01, 3.6094e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, -9.1130e-01, -1.3017e+00, -4.3368e-01, 0.0000e+00,
|
|
-2.4810e-01, -3.0781e-01, -2.9991e-01, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 3.7510e-02, 7.5085e-01, 3.4699e-01,
|
|
0.0000e+00, -5.9363e-01, 3.5708e-01, -1.0921e-01, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 4.3869e-01, -3.1992e-01,
|
|
-5.6624e-01, 0.0000e+00, 3.1281e-01, -5.9981e-01, -1.0019e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.8345e-01,
|
|
6.0243e-01, -9.6274e-01, 0.0000e+00, 6.5130e-01, 3.8952e-01,
|
|
-3.1186e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
-2.1756e-01, -1.2841e+00, 3.0320e-01, 0.0000e+00, 8.7854e-01,
|
|
-4.9627e-02, 5.0240e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 1.0999e-01, 8.0611e-01, -2.1298e-01, 0.0000e+00,
|
|
1.2189e+00, 3.7239e-01, -1.0822e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 5.2590e-01, 4.7520e-01, 1.9854e-01,
|
|
0.0000e+00, 7.8930e-01, 1.2199e+00, 8.9408e-01, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 3.9202e-01, 1.9105e-01,
|
|
-4.4909e-01, 0.0000e+00, -2.3462e-01, -7.9735e-01, -1.4301e-01,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.8422e-01,
|
|
-1.1426e+00, -4.2810e-01, 0.0000e+00, -6.4969e-02, -5.5076e-01,
|
|
1.6394e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
|
3.1875e-04, -7.1946e-02, 2.4577e-02, 0.0000e+00, -2.0447e-01,
|
|
2.4518e-01, -1.6006e-01};
|
|
gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();
|
|
|
|
for (int i = 0; i < expectedOutput.size(); ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.001) {
|
|
MNN_ERROR("%s max pool align false, grad test failed, expected: %f, but got: %f!\n", name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
MNN_PRINT("%s max pool align false, grad exact, %f <==> %f\n", name, expectedOutput[i], gotOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// max pool, align true
|
|
outputOri = _ROIAlign(_Convert(input, NC4HW4), _Convert(roiInput, NC4HW4), pooledHeight, pooledWidth, spatialScale, 2, true, PoolingMode::MAXPOOL);
|
|
output = _Convert(outputOri, NCHW);
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { -6.0111e-01, 3.5224e-01, 9.5382e-01, -5.7525e-01, 2.0128e-01,
|
|
7.0125e-01, 6.5805e-01, -1.9330e-01, -9.5783e-02, 1.8940e-01,
|
|
-5.8127e-01, -3.8469e-01, 4.8144e-01, 4.7831e-01, 7.9976e-01,
|
|
5.3199e-01, 1.5796e-01, 1.0696e-01, 2.4588e-01, 6.7362e-01,
|
|
9.6084e-01, 3.1140e-01, 3.6551e-01, 4.7446e-01, 8.1790e-01,
|
|
3.4086e-01, -5.2527e-01, 7.2475e-01, 1.4373e-01, -8.2707e-01,
|
|
2.6218e-01, 4.3179e-01, 7.4543e-01, 2.8387e-01, 3.2979e-01,
|
|
6.4319e-01, -9.2650e-02, 7.9203e-01, 1.2344e+00, -1.1918e-01,
|
|
6.5543e-01, 1.0715e+00, 5.2589e-01, 9.0867e-02, 6.1150e-02,
|
|
9.1106e-01, 4.6679e-01, -4.8650e-02, 3.8816e-01, 1.1303e-01,
|
|
3.0808e-01, 2.2080e-01, 1.2850e-01, 5.2979e-01, -4.1838e-02,
|
|
-4.7767e-02, -5.9625e-02, -4.4612e-01, -6.1183e-02, 1.3129e-01,
|
|
1.4125e-01, -7.6075e-02, 2.8817e-02, 7.9888e-02, 1.8025e-01,
|
|
4.9113e-01, -8.0315e-01, 1.1913e-01, 5.8026e-01, -6.6604e-01,
|
|
2.1090e-01, 6.4938e-01, 1.2905e+00, 5.1030e-01, -1.0501e+00,
|
|
1.1803e+00, 5.0047e-01, -5.8828e-01, 2.5445e+00, 1.2837e+00,
|
|
-1.1232e+00, 2.5797e+00, 1.3290e+00, -9.6466e-01, 8.1346e-01,
|
|
7.6446e-02, -6.1360e-01, 9.0836e-01, 2.4905e-01, -9.2080e-04,
|
|
5.1413e-02, 4.1455e-01, 5.9613e-01, 1.1379e+00, 8.6051e-01,
|
|
6.1225e-01, 8.3739e-01, 6.1961e-01, 2.7069e-01, 4.0124e-01,
|
|
3.6117e-01, 2.8104e-01, -7.1876e-01, 3.9556e-01, 1.1492e+00,
|
|
-1.1519e+00, 2.5374e-01, 1.0164e+00};
|
|
|
|
for (int i = 0; i < outputTorch.size(); ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%s max pool align true, output test failed, expected: %f, but got: %f!\n", name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// MNN_PRINT("\tmax pool align true, output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = outputOri->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {2, 9, 2, 3}, NCHW), NC4HW4)});
|
|
|
|
expectedOutput = { 0.0000, 0.0000, 0.0000, 0.0000, 0.3833, 0.5624, 0.0000, 0.0000,
|
|
1.1110, 1.8571, 0.0000, 0.0000, 0.2832, 0.5519, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0390, -0.0077, -0.0093, 0.0000,
|
|
0.1444, 0.1564, 0.0201, 0.0000, -0.0091, 0.1263, 0.0509, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, -0.1542, 0.1290, 0.0000, 0.0000,
|
|
-0.2991, 1.1526, 0.0000, 0.0000, 0.1115, 0.0357, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0594, 0.1205, 0.0000, 0.0000,
|
|
1.7843, -0.2853, 0.0000, 0.0000, 0.1955, -0.5965, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.1382, -0.0221, -0.0036, 0.0000,
|
|
0.4645, 0.2477, 0.2563, 0.0000, -0.0718, 0.3146, 0.1687, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.2428, -0.1803, -0.1103, 0.0000,
|
|
1.0213, -0.7120, -0.5860, 0.0000, -0.0969, -0.4573, -0.2413, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, -0.2426, -0.0597, 0.0000, 0.0000,
|
|
-1.5582, -0.1356, 0.0000, 0.0000, 0.0838, 0.2051, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, -0.3329, -0.1027, 0.0000, 0.0000,
|
|
0.7398, 0.5943, 0.0000, 0.0000, 0.1849, 0.1559, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, -0.7384, -0.1187, 0.0000, 0.0000,
|
|
-1.4608, -0.9769, 0.0000, 0.0000, -0.1508, -0.0569, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0535, 0.1775, 0.5464, 0.0000,
|
|
-0.1260, -0.0214, 0.4424, 0.0000, -0.1291, -0.1904, 0.0364, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0471, 0.0874, -0.2005, -0.0986,
|
|
0.2316, -0.0838, -1.1295, -0.4280, 0.0219, -0.0328, -0.1136, -0.0377,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0245, 0.1361, -0.3145, 0.0000,
|
|
0.5022, 1.4433, -0.6495, -0.1362, 0.0472, 0.1364, -0.0179, -0.0195,
|
|
0.0000, 0.0000, 0.0000, 0.0000, -0.0050, -0.2187, -0.4157, 0.0000,
|
|
-0.0352, 0.1466, 0.0659, 0.0000, 0.0000, 0.4868, 0.1084, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2471, 0.1819, -0.0153,
|
|
0.0000, 1.4239, -0.3424, -0.3238, 0.0000, 0.6072, -0.3874, -0.1790,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0054, -0.0048, -0.0059, 0.0000,
|
|
0.7672, 1.5932, 1.2296, 0.0000, 0.1041, 0.2324, 0.1816, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4995, -0.3283, 0.0000,
|
|
-0.1805, 0.1230, -0.9675, 0.0000, -0.0258, -0.1014, -0.0601, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0445, -0.0115, -0.1688, 0.0000,
|
|
0.3115, -0.6951, -0.8165, 0.0000, 0.0000, -0.3689, -0.0336, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0390, 0.0300, 0.0110,
|
|
0.0000, -0.1395, 0.1723, -0.1558, 0.0000, 0.0390, -0.0226, -0.1398};
|
|
gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();
|
|
|
|
for (int i = 0; i < expectedOutput.size(); ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.001) {
|
|
MNN_ERROR("%s max pool align true, grad test failed, expected: %f, but got: %f!\n", name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
MNN_PRINT("%s max pool align true, grad exact, %f <==> %f\n", name, expectedOutput[i], gotOutput[i]);
|
|
}
|
|
}
|
|
#endif
|
|
|
|
return true;
|
|
}
|
|
};
|
|
|
|
MNNTestSuiteRegister(RoiAlignGradTest, "grad/roi_align");
|