MNN/test/grad/RoiPoolGradTest.cpp

229 lines
19 KiB
C++
Raw Permalink Normal View History

2023-03-20 11:32:29 +08:00
//
// RoiPoolGradTest.cpp
// MNNTests
//
// Created by MNN on 2022/11/23.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include <vector>
#include "../tools/train/source/grad/OpGrad.hpp"
using namespace std;
using namespace MNN;
using namespace MNN::Express;
class RoiPoolGradTest : public MNNTestCase {
public:
char name[20] = "RoiPool";
virtual ~RoiPoolGradTest() = default;
virtual bool run(int precision) {
vector<int> inputShape = {3, 9, 4, 4};
auto input = _Input(inputShape, NCHW);
vector<float> inputData = { 1.2587, 0.8945, -0.7543, 2.0012, 1.3856, 0.9340, -0.1376, 0.1837,
-0.8983, -0.6558, -0.3315, -0.1219, 0.3075, 1.4431, 0.4746, -0.5140,
-0.1245, -0.3308, 0.0673, 0.7772, 0.0639, -0.6047, 0.1636, -0.3155,
0.1523, -0.4973, 1.2428, -0.5529, -0.4270, 1.6364, 0.2426, 1.0559,
0.9193, -0.3176, 0.0892, -0.2509, -2.4380, -0.9005, 0.5781, 1.4008,
-0.5696, 0.7918, -2.0354, -1.2119, -1.3411, -0.3476, -0.0886, -0.0649,
-0.7115, -0.5848, -1.0234, -1.5510, 0.0550, -1.4380, 0.7675, -0.5726,
1.4670, -0.9947, -1.8784, -1.5924, -0.8265, 0.0891, -0.3028, -1.3168,
-0.4735, -1.4944, 0.3030, -0.1165, 0.2647, -1.3413, -0.9787, 0.2888,
2.8702, -1.2232, -0.8815, 1.9677, 0.5460, -0.8153, -0.5402, -0.7055,
1.0696, 0.3802, -0.0949, 0.9391, -1.0831, -1.3940, 1.7714, 1.0268,
1.0844, -1.3981, 0.4609, -0.7931, -0.3239, 1.2300, 1.4248, -1.2277,
-1.5531, -0.3628, 0.3534, 1.1957, -0.5323, -0.5895, -1.6513, 0.8463,
0.1348, 0.7655, 0.1805, -1.6148, 2.8097, -0.4605, -0.6506, 0.0297,
0.6477, -1.1414, -1.0395, 0.0904, 1.5177, -0.0325, 0.5897, 0.3167,
0.4292, 0.3140, -0.3639, -0.7091, 0.2055, 0.0503, 0.6292, -0.1367,
0.7991, -1.3695, -0.7060, -0.0840, 0.3023, 0.4616, -0.7059, -1.6423,
-1.3314, 1.2474, 0.5421, 1.4275, -0.8528, -0.6006, 0.2814, 0.6976,
-0.6811, -1.9291, 1.2983, 0.4801, -1.1602, 0.4394, -0.3520, 0.6311,
-0.3585, 0.1489, 0.9659, 0.5493, -0.9856, -1.1759, 0.9381, 0.5606,
2.0309, 0.5102, 1.7770, 1.2903, -1.4298, -1.1124, 0.3458, 1.8255,
0.5936, 1.3503, 0.9923, -0.9042, 0.0124, 1.0796, -0.2233, -1.0319,
-1.3835, 0.8602, 0.0651, -0.1098, -1.9900, -1.1028, -1.0592, 0.8511,
-0.2102, -0.7675, -0.6877, -0.4493, -0.5632, 0.3369, 0.5917, -0.6685,
-1.1458, -1.9596, 1.8387, -0.2642, 1.4898, -0.7788, 0.7117, -1.2234,
0.3939, -0.2793, -0.4268, -1.1598, 0.4164, 0.4359, 0.5211, 0.5965,
2.0014, 0.7337, 0.3770, -0.8599, 0.7286, -0.9268, 0.1724, -1.1386,
-0.1429, 0.7072, 0.5999, -0.2979, -0.8230, -0.8795, -0.5317, 0.0974,
1.0004, 0.6322, -1.9103, 0.8706, -0.2598, -1.2323, -0.3205, 1.3420,
0.3936, 2.0456, 1.7977, -1.1196, -0.5652, -1.3567, 0.9958, -2.0845,
-1.3749, -0.7130, 1.0244, 0.0593, 0.0636, 0.2393, -1.3413, -0.3329,
0.2147, -0.5064, -0.5119, -0.9965, -0.4002, -0.6242, -0.3976, -0.6084,
1.2526, 0.7067, -0.2353, -0.5699, 0.0824, 1.0667, -0.0329, -0.1180,
-1.6390, 0.7729, 0.7066, -2.2387, -0.7651, -0.4625, 1.9304, -0.1592,
0.4796, -0.6125, -0.8265, 0.0568, 2.5158, -0.3929, -0.3927, -0.1145,
0.4040, 0.2954, -0.7797, 0.0569, 0.3714, 0.5620, -0.6556, 0.0075,
-0.0251, -1.6895, -0.8571, 0.3759, 0.0106, 1.0075, 1.3647, -0.1915,
-0.1687, -1.9660, 0.7073, -1.0942, -0.1903, 1.2114, 0.6589, 0.7416,
0.1255, -0.2084, -1.7247, -0.6163, 2.5999, 0.5725, -0.1817, 1.2373,
1.6475, -0.3679, -0.7700, 0.5559, -0.0299, -0.9032, -1.6034, 0.8630,
-0.9992, -0.5817, 0.5362, 1.4626, -0.7890, -1.0981, 1.7217, 0.7581,
1.2861, 0.3955, 0.2466, 0.3384, 0.1506, -0.9613, -1.5495, -0.1552,
-0.1180, 1.6468, -1.7706, 0.0055, -0.4989, 0.1550, 1.6259, -0.0722,
0.7386, 0.0657, 0.5618, -0.4135, -1.3406, 0.7209, -0.1369, -0.4943,
-0.3508, -0.2657, 1.5009, 1.8255, 0.7049, 0.9854, 0.3529, 1.1112,
-0.9561, 0.7174, -1.1929, -1.2257, 0.1584, 0.7370, -0.7273, 0.8572,
0.0591, -1.8631, -0.1637, 1.9188, -0.9281, -1.3265, 0.3382, -0.5424,
1.4783, -0.0339, 0.8036, 0.1805, 1.2170, -2.1388, -0.4797, 1.1232,
0.4213, -0.2824, -0.0592, -0.3094, -0.9494, 0.0946, 1.3795, 0.4063,
0.1934, -0.2050, 1.1473, -1.8769, -1.3865, 0.0212, 0.5409, 0.8030,
-2.4675, 0.6231, 0.1214, -1.3949, -1.0724, -1.5440, -0.4761, 1.4920,
1.7146, 0.3414, 0.1379, -0.8818, -1.7559, 1.8605, -1.3545, 1.5024,
-2.1116, -0.4113, -1.4620, -0.2642, 0.3996, -0.0468, -1.4671, 0.4811,
0.0413, 0.4503, -0.2901, 1.9869, -0.3118, -1.3857, 1.3151, -0.7364};
auto inputPtr = input->writeMap<float>();
memcpy(inputPtr, inputData.data(), input->getInfo()->size * sizeof(float));
const float spatialScale = 1.0 / 16;
const int pooledHeight = 3;
const int pooledWidth = 3;
auto roiInput = _Input({2, 5}, NCHW);
vector<float> roiData = { 2, 1 / spatialScale, 2 / spatialScale, 3 / spatialScale, 3 / spatialScale,
0, 0 / spatialScale, 2 / spatialScale, 2 / spatialScale, 3 / spatialScale};
memcpy(roiInput->writeMap<float>(), roiData.data(), roiInput->getInfo()->size * sizeof(float));
auto outputOri = _ROIPooling(_Convert(input, NC4HW4), _Convert(roiInput, NC4HW4), pooledHeight, pooledWidth, spatialScale);
auto output = _Convert(outputOri, NCHW);
auto outputPtr = output->readMap<float>();
vector<float> outputTorch = { -1.9660, 0.7073, -1.0942, 1.2114, 0.7073, 0.7416, 1.2114, 0.6589,
0.7416, -0.3679, -0.7700, 0.5559, -0.3679, -0.7700, 0.8630, -0.9032,
-1.6034, 0.8630, 0.3955, 0.2466, 0.3384, 0.3955, 0.2466, 0.3384,
-0.9613, -1.5495, -0.1552, 0.0657, 0.5618, -0.4135, 0.7209, 0.5618,
-0.4135, 0.7209, -0.1369, -0.4943, 0.7174, -1.1929, -1.2257, 0.7370,
-0.7273, 0.8572, 0.7370, -0.7273, 0.8572, -0.0339, 0.8036, 0.1805,
-0.0339, 0.8036, 1.1232, -2.1388, -0.4797, 1.1232, -0.2050, 1.1473,
-1.8769, 0.0212, 1.1473, 0.8030, 0.0212, 0.5409, 0.8030, 0.3414,
0.1379, -0.8818, 1.8605, 0.1379, 1.5024, 1.8605, -1.3545, 1.5024,
0.4503, -0.2901, 1.9869, 0.4503, 1.3151, 1.9869, -1.3857, 1.3151,
-0.7364, -0.8983, -0.6558, -0.3315, 0.3075, 1.4431, 0.4746, 0.3075,
1.4431, 0.4746, 0.1523, -0.4973, 1.2428, 0.1523, 1.6364, 1.2428,
-0.4270, 1.6364, 0.2426, -0.5696, 0.7918, -2.0354, -0.5696, 0.7918,
-0.0886, -1.3411, -0.3476, -0.0886, 1.4670, -0.9947, -1.8784, 1.4670,
0.0891, -0.3028, -0.8265, 0.0891, -0.3028, 2.8702, -1.2232, -0.8815,
2.8702, -0.8153, -0.5402, 0.5460, -0.8153, -0.5402, 1.0844, -1.3981,
0.4609, 1.0844, 1.2300, 1.4248, -0.3239, 1.2300, 1.4248, 0.1348,
0.7655, 0.1805, 2.8097, 0.7655, 0.1805, 2.8097, -0.4605, -0.6506,
0.4292, 0.3140, -0.3639, 0.4292, 0.3140, 0.6292, 0.2055, 0.0503,
0.6292, -1.3314, 1.2474, 0.5421, -0.8528, 1.2474, 0.5421, -0.8528,
-0.6006, 0.2814};
for (int i = 0, count = 0; i < outputTorch.size(); ++i) {
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
if (diff > 0.0001) {
count++;
MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
return false;
} else {
// MNN_PRINT("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
}
}
auto opExpr = outputOri->expr().first;
auto grad = OpGrad::get(opExpr->get()->type());
vector<float> outputDiff = {0.2562, -0.0111, -1.5401, 0.5039, -0.6176, -0.1565, 1.8841, -0.3646,
-2.0870, 2.0135, -0.6354, -0.6129, 1.3251, 1.6232, -0.9059, -0.1318,
0.4667, 0.3912, 0.8568, -0.3556, 0.2248, -0.1303, -1.6850, 0.4877,
-0.1433, -1.3551, 0.8345, -1.1855, 1.4541, 0.4225, -1.0868, -1.0298,
0.8969, -0.3130, 0.5271, -0.8280, -0.6191, 0.5584, -1.8515, 0.6529,
-1.1239, 0.8073, 0.3257, 2.0378, -0.7919, 0.8637, 1.3289, 1.1278,
0.8832, 0.3839, -1.1428, 0.6202, -1.1006, -1.4295, -0.6698, 0.9958,
-0.4719, 1.7213, -0.1548, -0.0358, -0.1978, 0.4558, -1.3004, -0.1816,
0.4252, 0.9267, -0.9387, 1.0997, 1.2616, 1.7754, -0.5986, 0.4416,
-0.2952, -0.8717, 0.1005, -0.3586, 1.5658, -1.1852, 0.9115, -0.5239,
0.8183, 0.6974, 0.0715, -2.1861, -0.6542, -1.6065, -0.8234, 0.2259,
0.5781, -0.6618, -0.1676, 1.8451, 0.5430, -0.9335, -0.1344, -1.1820,
0.5422, -2.2710, 0.4764, 0.0155, 0.8077, 0.0861, -0.4085, 1.7200,
0.0790, -0.2339, -0.0539, 0.4019, -0.2817, -0.3598, -1.2706, 0.2367,
-0.8693, -1.2023, 1.0073, 1.4283, 0.0475, -2.9939, -0.6765, -0.9341,
-0.5517, -0.9149, 0.2808, -0.4714, 0.4733, 0.5395, 0.3451, 0.2129,
-1.2796, -0.2701, 2.2198, 1.2021, 0.5800, 0.5960, 1.4162, -1.1088,
1.0207, -2.8389, -2.0961, 0.6317, -0.7982, -0.6363, 0.4101, -1.3526,
0.9465, -0.1597, 0.9172, -1.0526, 1.3711, 0.7282, 1.0296, 0.3771,
0.3484, 0.8067, 1.4168, -1.1119, -1.6325, -0.8257, -0.0769, -1.3230,
1.1655, -0.4298};
auto inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {2, 9, 3, 3}, NCHW), NC4HW4)});
vector<float> expectedOutput = { 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.6974, 0.0715, -2.1861, 0.0000, -0.4283, -1.0284, -1.4852, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-1.1011, 1.8451, -0.6390, 0.0000, 0.5422, -2.4054, 0.4764, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-0.3930, 2.5277, 0.0861, 0.0000, -0.2339, -0.0539, 0.4809, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-0.0450, -0.3598, -1.2706, 0.0000, 1.0073, 0.5590, -1.1548, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-3.5456, -0.6765, -0.9341, 0.0000, -0.4714, -0.4416, 0.8203, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0750, 0.2129, -1.2796, 0.0000, 0.5800, 2.8158, 2.6183, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-1.1088, 1.6524, -3.6371, 0.0000, -2.7324, 0.4101, -1.3526, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-0.1061, 1.2114, 0.9172, 0.0000, 1.0296, 0.3771, 1.0766, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.8067, 0.5911, -1.1888, 0.0000, -2.9555, 1.1655, -0.4298, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.2562, -0.6287, -1.5401, 0.0000, 2.3880, -0.3646, -2.2435,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 3.3386, 0.9878, -0.6129, 0.0000, -0.1318, 0.4667, -0.5147,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.7265, -2.0406, 0.7125, 0.0000, -0.1433, -1.3551, 0.8345,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -1.1855, 0.4243, 1.3194, 0.0000, -1.3998, 0.5271, -0.8280,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.6191, 0.5584, -1.8515, 0.0000, 0.9786, 0.9139, 0.0154,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 1.7469, 1.7128, 1.1278, 0.0000, 0.6202, -1.1006, -2.5723,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.6698, 0.8410, -0.4719, 0.0000, 1.5235, 0.4558, -1.3362,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.1816, 1.5249, 0.9267, 0.0000, 0.8367, -0.5986, 1.7032,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.6538, -0.8717, -1.0847, 0.0000, 0.9115, 1.0419, 0.8183};
auto gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();
for (int i = 0; i < expectedOutput.size(); ++i) {
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
if (diff > 0.001) {
MNN_ERROR("%s grad test failed, expected: %f, but got: %f!\n", name, expectedOutput[i], gotOutput[i]);
return false;
} else {
// MNN_PRINT("%s grad exact, %f <==> %f\n", name, expectedOutput[i], gotOutput[i]);
}
}
return true;
}
};
// Registers RoiPoolGradTest with MNNTestSuite under the name "grad/roi_pool"
// (presumably the key the test runner uses to select this case).
MNNTestSuiteRegister(RoiPoolGradTest, "grad/roi_pool");