// Mirror of https://github.com/alibaba/MNN.git (listed by the mirror as C++, 456 lines, 29 KiB).
//
//  InterpGradTest.cpp
//  MNNTests
//
//  Created by MNN on 2022/08/18.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <cmath>
#include <cstring>
#include <string>
#include <vector>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include "../tools/train/source/grad/OpGrad.hpp"

using namespace MNN;
using namespace MNN::Express;

class InterpGradTest : public MNNTestCase {
|
|
public:
|
|
char name[20] = "Interp";
|
|
virtual ~InterpGradTest() = default;
|
|
|
|
virtual bool run(int precision) {
|
|
std::vector<int> shape = {2, 3, 2, 3};
|
|
const int len = shape[0] * shape[1] * shape[2] * shape[3];
|
|
auto input = _Input(shape, NCHW);
|
|
const float inpudata[] = { 0.5500, 0.6721, 0.4343, 0.8518, 0.9456, 0.6444, 0.5927, 0.4439, 0.9329,
|
|
0.1434, 0.6933, 0.0180, 0.3173, 0.2903, 0.4159, 0.8706, 0.1812, 0.5890,
|
|
0.3834, 0.0335, 0.9997, 0.7504, 0.5379, 0.9836, 0.3202, 0.4824, 0.9982,
|
|
0.8029, 0.2889, 0.8386, 0.2282, 0.6912, 0.2678, 0.9031, 0.7055, 0.9389};
|
|
auto inputPtr = input->writeMap<float>();
|
|
memcpy(inputPtr, inpudata, len * sizeof(float));
|
|
|
|
float wScale = 2.5;
|
|
float hScale = 2.5;
|
|
int outputW = int(floor(wScale * 3));
|
|
int outputH = int(floor(hScale * 2));
|
|
|
|
int mode = 1; // 1:near 2: bilinear 3: cubic 4: nearest_round
|
|
bool alignCorners = false;
|
|
float scales[] = {1.0, 1.0, hScale, wScale};
|
|
auto scaleVar = _Const((void*)scales, {4}, NCHW);
|
|
|
|
auto output = _Interp({input, scaleVar}, wScale, hScale, outputW, outputH, mode, alignCorners);
|
|
auto outputPtr = output->readMap<float>();
|
|
|
|
const int len2 = shape[0] * shape[1] * outputH * outputW;
|
|
|
|
std::vector<float> outputTorch = { 0.5500, 0.5500, 0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500,
|
|
0.5500, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500, 0.5500, 0.6721,
|
|
0.6721, 0.4343, 0.4343, 0.8518, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444,
|
|
0.6444, 0.8518, 0.8518, 0.8518, 0.9456, 0.9456, 0.6444, 0.6444, 0.5927,
|
|
0.5927, 0.5927, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.5927,
|
|
0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.5927, 0.4439, 0.4439,
|
|
0.9329, 0.9329, 0.1434, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180,
|
|
0.1434, 0.1434, 0.1434, 0.6933, 0.6933, 0.0180, 0.0180, 0.3173, 0.3173,
|
|
0.3173, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.3173, 0.2903,
|
|
0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.3173, 0.2903, 0.2903, 0.4159,
|
|
0.4159, 0.8706, 0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706,
|
|
0.8706, 0.8706, 0.1812, 0.1812, 0.5890, 0.5890, 0.3834, 0.3834, 0.3834,
|
|
0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.3834, 0.0335, 0.0335,
|
|
0.9997, 0.9997, 0.3834, 0.3834, 0.3834, 0.0335, 0.0335, 0.9997, 0.9997,
|
|
0.7504, 0.7504, 0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504,
|
|
0.7504, 0.5379, 0.5379, 0.9836, 0.9836, 0.3202, 0.3202, 0.3202, 0.4824,
|
|
0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982,
|
|
0.9982, 0.3202, 0.3202, 0.3202, 0.4824, 0.4824, 0.9982, 0.9982, 0.8029,
|
|
0.8029, 0.8029, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.8029,
|
|
0.2889, 0.2889, 0.8386, 0.8386, 0.2282, 0.2282, 0.2282, 0.6912, 0.6912,
|
|
0.2678, 0.2678, 0.2282, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678,
|
|
0.2282, 0.2282, 0.2282, 0.6912, 0.6912, 0.2678, 0.2678, 0.9031, 0.9031,
|
|
0.9031, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.9031, 0.7055,
|
|
0.7055, 0.9389, 0.9389};
|
|
|
|
for (int i = 0, count = 0; i < len2; ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
count++;
|
|
MNN_ERROR("%d: %s type 1 output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 1 output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
auto opExpr = output->expr().first;
|
|
auto grad = OpGrad::get(opExpr->get()->type());
|
|
|
|
float outputDiff[] = { 0.4334, 0.0341, 0.2641, 0.7900, 0.5033, 0.5753, 0.0180, 0.7617, 0.3299,
|
|
0.9388, 0.7645, 0.9430, 0.7312, 0.1058, 0.4154, 0.7537, 0.3959, 0.2058,
|
|
0.9084, 0.9178, 0.6203, 0.3946, 0.5286, 0.0053, 0.5959, 0.4805, 0.3953,
|
|
0.8146, 0.8543, 0.1426, 0.4022, 0.8199, 0.7822, 0.3160, 0.5057, 0.8435,
|
|
0.5449, 0.0964, 0.4719, 0.3557, 0.1786, 0.8186, 0.0859, 0.2833, 0.5462,
|
|
0.1870, 0.9203, 0.1523, 0.3556, 0.9206, 0.2185, 0.5502, 0.8321, 0.5941,
|
|
0.3160, 0.0663, 0.8522, 0.8215, 0.3595, 0.2714, 0.6255, 0.9103, 0.2248,
|
|
0.4765, 0.2330, 0.4213, 0.3474, 0.7129, 0.1307, 0.2414, 0.7421, 0.1453,
|
|
0.5165, 0.7668, 0.1646, 0.0379, 0.1988, 0.1783, 0.3200, 0.5802, 0.7501,
|
|
0.5057, 0.9157, 0.2080, 0.9982, 0.6694, 0.3964, 0.3710, 0.9381, 0.9157,
|
|
0.2548, 0.2127, 0.8212, 0.5140, 0.3528, 0.2028, 0.9128, 0.3492, 0.9882,
|
|
0.0330, 0.4107, 0.5150, 0.8750, 0.1118, 0.1271, 0.5068, 0.4232, 0.0709,
|
|
0.2711, 0.0418, 0.5329, 0.8123, 0.0119, 0.1818, 0.2264, 0.6342, 0.1863,
|
|
0.8303, 0.2253, 0.8572, 0.0520, 0.0255, 0.4094, 0.3164, 0.2758, 0.5764,
|
|
0.7998, 0.7261, 0.9420, 0.8043, 0.7131, 0.3567, 0.1961, 0.3868, 0.4668,
|
|
0.8830, 0.8475, 0.2829, 0.0681, 0.3372, 0.9303, 0.0397, 0.0962, 0.3651,
|
|
0.1226, 0.7876, 0.4374, 0.1730, 0.2058, 0.7499, 0.8105, 0.5794, 0.7401,
|
|
0.3478, 0.8476, 0.8795, 0.7856, 0.6042, 0.4180, 0.4664, 0.8128, 0.6839,
|
|
0.9811, 0.7328, 0.9305, 0.1411, 0.4011, 0.4810, 0.5414, 0.6038, 0.1644,
|
|
0.1686, 0.2125, 0.1554, 0.8285, 0.6496, 0.0667, 0.7326, 0.9510, 0.1087,
|
|
0.4501, 0.8744, 0.3976, 0.8691, 0.7303, 0.2784, 0.4464, 0.8928, 0.6532,
|
|
0.4175, 0.5971, 0.7475, 0.1091, 0.3149, 0.3717, 0.5579, 0.5649, 0.6624,
|
|
0.8024, 0.1316, 0.8202, 0.7971, 0.6213, 0.9040, 0.9452, 0.9925, 0.4661,
|
|
0.7995, 0.0764, 0.0370};
|
|
auto inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, {shape[0], shape[1], outputH, outputW}, NCHW)});
|
|
|
|
std::vector<float> expectedOutput = { 4.3270, 4.1150, 2.9684, 2.3276, 2.6785, 2.0316, 4.0895, 3.3611, 1.8874,
|
|
3.1640, 1.9572, 1.5072, 4.5464, 3.4963, 2.5309, 2.9798, 1.9456, 1.5009,
|
|
2.3557, 1.8592, 3.2530, 4.2045, 2.6478, 0.9581, 4.7076, 2.8998, 3.5921,
|
|
3.7074, 1.4527, 1.8660, 5.2080, 2.2085, 3.8001, 4.8714, 2.2174, 1.5318};
|
|
auto gotOutput = inputGrad[0]->readMap<float>();
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%d: %s type 1 grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 1 grad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
// TODO: inference of this mode is not aligned with pytorch
|
|
mode = 4;
|
|
output = _Interp({input, scaleVar}, wScale, hScale, outputW, outputH, mode, alignCorners);
|
|
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { 0.5500, 0.5500, 0.6721, 0.6721, 0.6721, 0.4343, 0.4343, 0.5500, 0.5500,
|
|
0.6721, 0.6721, 0.6721, 0.4343, 0.4343, 0.8518, 0.8518, 0.9456, 0.9456,
|
|
0.9456, 0.6444, 0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.9456, 0.6444,
|
|
0.6444, 0.8518, 0.8518, 0.9456, 0.9456, 0.9456, 0.6444, 0.6444, 0.5927,
|
|
0.5927, 0.4439, 0.4439, 0.4439, 0.9329, 0.9329, 0.5927, 0.5927, 0.4439,
|
|
0.4439, 0.4439, 0.9329, 0.9329, 0.1434, 0.1434, 0.6933, 0.6933, 0.6933,
|
|
0.0180, 0.0180, 0.1434, 0.1434, 0.6933, 0.6933, 0.6933, 0.0180, 0.0180,
|
|
0.1434, 0.1434, 0.6933, 0.6933, 0.6933, 0.0180, 0.0180, 0.3173, 0.3173,
|
|
0.2903, 0.2903, 0.2903, 0.4159, 0.4159, 0.3173, 0.3173, 0.2903, 0.2903,
|
|
0.2903, 0.4159, 0.4159, 0.8706, 0.8706, 0.1812, 0.1812, 0.1812, 0.5890,
|
|
0.5890, 0.8706, 0.8706, 0.1812, 0.1812, 0.1812, 0.5890, 0.5890, 0.8706,
|
|
0.8706, 0.1812, 0.1812, 0.1812, 0.5890, 0.5890, 0.3834, 0.3834, 0.0335,
|
|
0.0335, 0.0335, 0.9997, 0.9997, 0.3834, 0.3834, 0.0335, 0.0335, 0.0335,
|
|
0.9997, 0.9997, 0.7504, 0.7504, 0.5379, 0.5379, 0.5379, 0.9836, 0.9836,
|
|
0.7504, 0.7504, 0.5379, 0.5379, 0.5379, 0.9836, 0.9836, 0.7504, 0.7504,
|
|
0.5379, 0.5379, 0.5379, 0.9836, 0.9836, 0.3202, 0.3202, 0.4824, 0.4824,
|
|
0.4824, 0.9982, 0.9982, 0.3202, 0.3202, 0.4824, 0.4824, 0.4824, 0.9982,
|
|
0.9982, 0.8029, 0.8029, 0.2889, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029,
|
|
0.8029, 0.2889, 0.2889, 0.2889, 0.8386, 0.8386, 0.8029, 0.8029, 0.2889,
|
|
0.2889, 0.2889, 0.8386, 0.8386, 0.2282, 0.2282, 0.6912, 0.6912, 0.6912,
|
|
0.2678, 0.2678, 0.2282, 0.2282, 0.6912, 0.6912, 0.6912, 0.2678, 0.2678,
|
|
0.9031, 0.9031, 0.7055, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031,
|
|
0.7055, 0.7055, 0.7055, 0.9389, 0.9389, 0.9031, 0.9031, 0.7055, 0.7055,
|
|
0.7055, 0.9389, 0.9389};
|
|
|
|
for (int i = 0, count = 0; i < len2; ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
count++;
|
|
MNN_ERROR("%d: %s type 4 output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
|
|
// return false;
|
|
} else {
|
|
// printf("\ttype 4 output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = output->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
|
|
inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, {shape[0], shape[1], outputH, outputW}, NCHW)});
|
|
|
|
expectedOutput = { 1.5591, 4.2037, 1.4303, 3.0892, 4.5961, 3.5697, 1.7576, 2.5775, 1.5051,
|
|
3.5223, 4.7144, 1.8895, 1.3857, 3.2839, 1.3604, 3.7227, 4.5758, 2.6714,
|
|
1.1237, 1.4307, 2.4008, 3.2887, 5.2241, 1.8103, 1.3488, 2.7237, 2.3129,
|
|
4.5373, 4.1577, 3.1452, 1.9830, 3.2474, 2.8705, 4.0911, 5.1838, 2.4614};
|
|
gotOutput = inputGrad[0]->readMap<float>();
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%d: %s type 4 grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 4 grad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
mode = 2;
|
|
alignCorners = false;
|
|
output = _Interp({input, scaleVar}, wScale, hScale, outputW, outputH, mode, alignCorners);
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { 0.5500, 0.5622, 0.6111, 0.6599, 0.6008, 0.5056, 0.4343, 0.5802, 0.5921,
|
|
0.6398, 0.6875, 0.6262, 0.5286, 0.4553, 0.7009, 0.7117, 0.7549, 0.7981,
|
|
0.7280, 0.6202, 0.5394, 0.8216, 0.8313, 0.8699, 0.9086, 0.8298, 0.7118,
|
|
0.6234, 0.8518, 0.8612, 0.8987, 0.9362, 0.8552, 0.7348, 0.6444, 0.5927,
|
|
0.5778, 0.5183, 0.4588, 0.5906, 0.7862, 0.9329, 0.5478, 0.5399, 0.5083,
|
|
0.4767, 0.5806, 0.7296, 0.8414, 0.3681, 0.3881, 0.4683, 0.5485, 0.5407,
|
|
0.5034, 0.4755, 0.1883, 0.2363, 0.4283, 0.6204, 0.5007, 0.2772, 0.1095,
|
|
0.1434, 0.1984, 0.4184, 0.6383, 0.4907, 0.2206, 0.0180, 0.3173, 0.3146,
|
|
0.3038, 0.2930, 0.3280, 0.3782, 0.4159, 0.3726, 0.3633, 0.3260, 0.2887,
|
|
0.3255, 0.3871, 0.4332, 0.5939, 0.5581, 0.4148, 0.2716, 0.3158, 0.4224,
|
|
0.5024, 0.8153, 0.7530, 0.5037, 0.2544, 0.3060, 0.4578, 0.5717, 0.8706,
|
|
0.8017, 0.5259, 0.2501, 0.3035, 0.4667, 0.5890, 0.3834, 0.3484, 0.2084,
|
|
0.0685, 0.3234, 0.7098, 0.9997, 0.4201, 0.3865, 0.2520, 0.1176, 0.3582,
|
|
0.7238, 0.9981, 0.5669, 0.5388, 0.4263, 0.3138, 0.4975, 0.7799, 0.9916,
|
|
0.7137, 0.6911, 0.6006, 0.5101, 0.6368, 0.8359, 0.9852, 0.7504, 0.7291,
|
|
0.6442, 0.5591, 0.6716, 0.8499, 0.9836, 0.3202, 0.3364, 0.4013, 0.4662,
|
|
0.6371, 0.8435, 0.9982, 0.3685, 0.3779, 0.4158, 0.4536, 0.6188, 0.8265,
|
|
0.9822, 0.5616, 0.5440, 0.4736, 0.4032, 0.5455, 0.7586, 0.9184, 0.7546,
|
|
0.7100, 0.5314, 0.3529, 0.4721, 0.6907, 0.8546, 0.8029, 0.7515, 0.5459,
|
|
0.3403, 0.4538, 0.6737, 0.8386, 0.2282, 0.2745, 0.4597, 0.6449, 0.5642,
|
|
0.3948, 0.2678, 0.2957, 0.3354, 0.4942, 0.6529, 0.5853, 0.4422, 0.3349,
|
|
0.5656, 0.5789, 0.6320, 0.6851, 0.6698, 0.6319, 0.6033, 0.8356, 0.8225,
|
|
0.7698, 0.7172, 0.7544, 0.8215, 0.8718, 0.9031, 0.8833, 0.8043, 0.7253,
|
|
0.7755, 0.8689, 0.9389};
|
|
|
|
for (int i = 0, count = 0; i < len2; ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
count++;
|
|
MNN_ERROR("%d: %s type 2 alignCorners false output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 2 alignCorners false output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = output->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
|
|
inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, {shape[0], shape[1], outputH, outputW}, NCHW)});
|
|
|
|
expectedOutput = { 2.8685, 4.0238, 2.2734, 2.9216, 3.4296, 2.9312, 2.9163, 2.7736, 2.0526,
|
|
3.3512, 3.0745, 1.7981, 2.9985, 3.4252, 1.8688, 3.3955, 3.2112, 2.1006,
|
|
1.9465, 2.0171, 2.4722, 3.6515, 3.6092, 1.5819, 3.0323, 3.1607, 2.6740,
|
|
4.1352, 2.7624, 2.4610, 3.4138, 3.1663, 3.0918, 4.5425, 3.4122, 2.2106};
|
|
gotOutput = inputGrad[0]->readMap<float>();
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%d: %s type 2 alignCorners false grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 2 alignCorners false grad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
mode = 2;
|
|
alignCorners = true;
|
|
output = _Interp({input, scaleVar}, wScale, hScale, outputW, outputH, mode, alignCorners);
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { 0.5500, 0.5907, 0.6314, 0.6721, 0.5928, 0.5136, 0.4343, 0.6255, 0.6638,
|
|
0.7021, 0.7405, 0.6559, 0.5714, 0.4868, 0.7009, 0.7369, 0.7729, 0.8088,
|
|
0.7190, 0.6292, 0.5394, 0.7764, 0.8100, 0.8436, 0.8772, 0.7821, 0.6870,
|
|
0.5919, 0.8518, 0.8831, 0.9143, 0.9456, 0.8452, 0.7448, 0.6444, 0.5927,
|
|
0.5431, 0.4935, 0.4439, 0.6069, 0.7699, 0.9329, 0.4804, 0.4890, 0.4976,
|
|
0.5063, 0.5722, 0.6382, 0.7042, 0.3681, 0.4349, 0.5017, 0.5686, 0.5375,
|
|
0.5065, 0.4755, 0.2557, 0.3808, 0.5059, 0.6309, 0.5029, 0.3748, 0.2467,
|
|
0.1434, 0.3267, 0.5100, 0.6933, 0.4682, 0.2431, 0.0180, 0.3173, 0.3083,
|
|
0.2993, 0.2903, 0.3322, 0.3740, 0.4159, 0.4556, 0.3914, 0.3272, 0.2630,
|
|
0.3284, 0.3938, 0.4592, 0.5939, 0.4745, 0.3551, 0.2358, 0.3246, 0.4136,
|
|
0.5024, 0.7323, 0.5577, 0.3831, 0.2085, 0.3209, 0.4333, 0.5457, 0.8706,
|
|
0.6408, 0.4110, 0.1812, 0.3171, 0.4531, 0.5890, 0.3834, 0.2668, 0.1501,
|
|
0.0335, 0.3556, 0.6776, 0.9997, 0.4751, 0.3700, 0.2648, 0.1596, 0.4383,
|
|
0.7170, 0.9957, 0.5669, 0.4732, 0.3794, 0.2857, 0.5210, 0.7563, 0.9916,
|
|
0.6586, 0.5764, 0.4941, 0.4118, 0.6037, 0.7957, 0.9876, 0.7504, 0.6796,
|
|
0.6087, 0.5379, 0.6865, 0.8350, 0.9836, 0.3202, 0.3743, 0.4283, 0.4824,
|
|
0.6543, 0.8263, 0.9982, 0.4409, 0.4386, 0.4363, 0.4340, 0.6088, 0.7835,
|
|
0.9583, 0.5616, 0.5029, 0.4443, 0.3856, 0.5632, 0.7408, 0.9184, 0.6822,
|
|
0.5672, 0.4523, 0.3373, 0.5177, 0.6981, 0.8785, 0.8029, 0.6316, 0.4602,
|
|
0.2889, 0.4721, 0.6554, 0.8386, 0.2282, 0.3825, 0.5369, 0.6912, 0.5501,
|
|
0.4089, 0.2678, 0.3969, 0.4962, 0.5955, 0.6948, 0.6084, 0.5220, 0.4356,
|
|
0.5656, 0.6099, 0.6541, 0.6984, 0.6667, 0.6350, 0.6033, 0.7344, 0.7236,
|
|
0.7127, 0.7019, 0.7250, 0.7481, 0.7711, 0.9031, 0.8372, 0.7714, 0.7055,
|
|
0.7833, 0.8611, 0.9389};
|
|
|
|
for (int i = 0, count = 0; i < len2; ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
count++;
|
|
MNN_ERROR("%d: %s type 2 alignCorners true output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 2 alignCorners true output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = output->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
|
|
inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, {shape[0], shape[1], outputH, outputW}, NCHW)});
|
|
|
|
expectedOutput = { 2.2272, 4.4075, 2.3271, 2.4936, 4.0926, 2.9002, 2.5863, 3.2598, 2.1267,
|
|
2.6511, 3.5678, 1.7748, 2.4711, 3.9431, 1.8645, 2.7803, 3.8429, 2.0980,
|
|
1.8186, 2.5404, 2.4130, 2.6931, 4.1895, 1.6237, 2.6524, 3.7168, 2.6095,
|
|
3.1733, 3.5839, 2.4896, 2.7044, 3.9169, 3.0705, 3.6579, 4.2589, 2.2286};
|
|
gotOutput = inputGrad[0]->readMap<float>();
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%d: %s type 2 alignCorners true grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 2 alignCorners true grad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
mode = 3;
|
|
alignCorners = true;
|
|
output = _Interp({input, scaleVar}, wScale, hScale, outputW, outputH, mode, alignCorners);
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { 0.5500, 0.6016, 0.6601, 0.6721, 0.6108, 0.5159, 0.4343, 0.6184, 0.6688,
|
|
0.7257, 0.7341, 0.6675, 0.5677, 0.4819, 0.7009, 0.7499, 0.8048, 0.8088,
|
|
0.7360, 0.6302, 0.5393, 0.7834, 0.8309, 0.8840, 0.8836, 0.8045, 0.6927,
|
|
0.5968, 0.8518, 0.8981, 0.9495, 0.9456, 0.8612, 0.7444, 0.6444, 0.5927,
|
|
0.5187, 0.4364, 0.4439, 0.5813, 0.7707, 0.9329, 0.4909, 0.4814, 0.4724,
|
|
0.5004, 0.5724, 0.6552, 0.7256, 0.3681, 0.4364, 0.5158, 0.5686, 0.5616,
|
|
0.5159, 0.4754, 0.2452, 0.3913, 0.5592, 0.6368, 0.5508, 0.3766, 0.2253,
|
|
0.1434, 0.3540, 0.5952, 0.6933, 0.5418, 0.2611, 0.0180, 0.3173, 0.3018,
|
|
0.2848, 0.2903, 0.3268, 0.3749, 0.4159, 0.4427, 0.3764, 0.3003, 0.2656,
|
|
0.3056, 0.3856, 0.4551, 0.5940, 0.4664, 0.3189, 0.2357, 0.2799, 0.3986,
|
|
0.5024, 0.7452, 0.5564, 0.3375, 0.2059, 0.2542, 0.4116, 0.5498, 0.8706,
|
|
0.6309, 0.3529, 0.1812, 0.2330, 0.4223, 0.5890, 0.3834, 0.2196, 0.0363,
|
|
0.0335, 0.2988, 0.6761, 0.9997, 0.4665, 0.3191, 0.1539, 0.1478, 0.3794,
|
|
0.7113, 0.9961, 0.5669, 0.4392, 0.2958, 0.2857, 0.4767, 0.7538, 0.9917,
|
|
0.6673, 0.5592, 0.4377, 0.4236, 0.5740, 0.7963, 0.9872, 0.7504, 0.6587,
|
|
0.5553, 0.5379, 0.6546, 0.8315, 0.9836, 0.3202, 0.3426, 0.3740, 0.4824,
|
|
0.6628, 0.8448, 0.9982, 0.4296, 0.4033, 0.3776, 0.4386, 0.6044, 0.7977,
|
|
0.9620, 0.5615, 0.4766, 0.3818, 0.3856, 0.5338, 0.7409, 0.9184, 0.6935,
|
|
0.5498, 0.3861, 0.3327, 0.4633, 0.6841, 0.8748, 0.8029, 0.6105, 0.3896,
|
|
0.2889, 0.4048, 0.6370, 0.8386, 0.2282, 0.3975, 0.5925, 0.6912, 0.6094,
|
|
0.4268, 0.2678, 0.3811, 0.4950, 0.6263, 0.6944, 0.6428, 0.5237, 0.4198,
|
|
0.5656, 0.6127, 0.6671, 0.6984, 0.6832, 0.6406, 0.6033, 0.7502, 0.7304,
|
|
0.7080, 0.7023, 0.7236, 0.7576, 0.7869, 0.9031, 0.8279, 0.7418, 0.7055,
|
|
0.7570, 0.8544, 0.9389};
|
|
|
|
for (int i = 0, count = 0; i < len2; ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
count++;
|
|
MNN_ERROR("%d: %s type 3 alignCorners true output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 3 alignCorners true output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = output->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
|
|
inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, {shape[0], shape[1], outputH, outputW}, NCHW)});
|
|
|
|
expectedOutput = { 1.9391, 4.9129, 2.1415, 2.2126, 4.4911, 2.7508, 2.3607, 3.6276, 1.9484,
|
|
2.4287, 4.0120, 1.5888, 2.2588, 4.3451, 1.6770, 2.5254, 4.2834, 1.9103,
|
|
1.6577, 2.7228, 2.3390, 2.5627, 4.6302, 1.3659, 2.4378, 4.0773, 2.4462,
|
|
3.0859, 3.9371, 2.2413, 2.5186, 4.3590, 2.8111, 3.4074, 4.8708, 1.8704};
|
|
gotOutput = inputGrad[0]->readMap<float>();
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%d: %s type 3 alignCorners true grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
} else {
|
|
// printf("\ttype 3 alignCorners true grad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
// TODO: inference of these arguments combination is wrong
|
|
mode = 3;
|
|
alignCorners = false;
|
|
output = _Interp({input, scaleVar}, wScale, hScale, outputW, outputH, mode, alignCorners);
|
|
outputPtr = output->readMap<float>();
|
|
|
|
outputTorch = { 0.5029, 0.5286, 0.6010, 0.6457, 0.5914, 0.4814, 0.3971, 0.5615,
|
|
0.5862, 0.6574, 0.6995, 0.6399, 0.5250, 0.4368, 0.6890, 0.7116,
|
|
0.7801, 0.8164, 0.7456, 0.6196, 0.5230, 0.8165, 0.8369, 0.9029,
|
|
0.9333, 0.8512, 0.7143, 0.6092, 0.8751, 0.8946, 0.9593, 0.9870,
|
|
0.8998, 0.7578, 0.6488, 0.6671, 0.6195, 0.4714, 0.3974, 0.5638,
|
|
0.8509, 1.0713, 0.5659, 0.5457, 0.4732, 0.4479, 0.5641, 0.7438,
|
|
0.8818, 0.3459, 0.3851, 0.4771, 0.5578, 0.5647, 0.5109, 0.4698,
|
|
0.1259, 0.2246, 0.4809, 0.6677, 0.5654, 0.2781, 0.0578, 0.0247,
|
|
0.1507, 0.4827, 0.7182, 0.5657, 0.1710, -0.1317, 0.2512, 0.2594,
|
|
0.2705, 0.2928, 0.3338, 0.3726, 0.4026, 0.3716, 0.3550, 0.3081,
|
|
0.2790, 0.3139, 0.3848, 0.4391, 0.6334, 0.5628, 0.3898, 0.2489,
|
|
0.2707, 0.4111, 0.5187, 0.8952, 0.7706, 0.4716, 0.2189, 0.2274,
|
|
0.4375, 0.5982, 1.0157, 0.8661, 0.5092, 0.2050, 0.2076, 0.4496,
|
|
0.6347, 0.3832, 0.3061, 0.0645, -0.0544, 0.2232, 0.6986, 1.0637,
|
|
0.4508, 0.3795, 0.1576, 0.0465, 0.2952, 0.7247, 1.0545, 0.5979,
|
|
0.5391, 0.3601, 0.2659, 0.4517, 0.7814, 1.0345, 0.7450, 0.6987,
|
|
0.5626, 0.4852, 0.6081, 0.8381, 1.0146, 0.8126, 0.7721, 0.6558,
|
|
0.5861, 0.6801, 0.8642, 1.0054, 0.2409, 0.2829, 0.3374, 0.4532,
|
|
0.6727, 0.8841, 1.0469, 0.3480, 0.3650, 0.3645, 0.4263, 0.6230,
|
|
0.8455, 1.0166, 0.5809, 0.5435, 0.4237, 0.3677, 0.5149, 0.7615,
|
|
0.9508, 0.8139, 0.7220, 0.4828, 0.3091, 0.4068, 0.6774, 0.8849,
|
|
0.9210, 0.8041, 0.5100, 0.2822, 0.3571, 0.6388, 0.8546, 0.0947,
|
|
0.2011, 0.4682, 0.6758, 0.6104, 0.3575, 0.1637, 0.2385, 0.3196,
|
|
0.5226, 0.6813, 0.6343, 0.4452, 0.3004, 0.5510, 0.5772, 0.6409,
|
|
0.6932, 0.6865, 0.6361, 0.5976, 0.8636, 0.8348, 0.7592, 0.7052,
|
|
0.7386, 0.8270, 0.8948, 1.0073, 0.9533, 0.8136, 0.7107, 0.7626,
|
|
0.9148, 1.0315};
|
|
|
|
for (int i = 0, count = 0; i < len2; ++i) {
|
|
auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
|
|
if (diff > 0.0001) {
|
|
count++;
|
|
MNN_ERROR("%d: %s type 3 alignCorners false output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
|
|
// return false;
|
|
} else {
|
|
// printf("\ttype 3 alignCorners false output exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
|
|
}
|
|
}
|
|
|
|
opExpr = output->expr().first;
|
|
grad = OpGrad::get(opExpr->get()->type());
|
|
|
|
inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, {shape[0], shape[1], outputH, outputW}, NCHW)});
|
|
|
|
expectedOutput = { 2.6602, 4.3750, 2.0222, 2.8827, 3.6077, 2.9003, 2.9207, 2.8721, 2.0044,
|
|
3.2637, 3.3148, 1.5907, 2.9417, 3.6414, 1.6572, 3.3495, 3.4318, 1.9783,
|
|
1.8646, 1.8953, 2.5679, 3.7284, 3.9885, 1.2336, 2.9781, 3.3275, 2.5594,
|
|
4.2535, 2.7625, 2.3446, 3.2755, 3.2857, 3.0657, 4.6292, 3.6900, 1.8911};
|
|
gotOutput = inputGrad[0]->readMap<float>();
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%d: %s type 3 alignCorners false grad test failed, expected: %f, but got: %f!\n", i, name, expectedOutput[i], gotOutput[i]);
|
|
// return false;
|
|
} else {
|
|
// printf("\ttype 3 alignCorners false grad exact: %d, %f <==> %f\n", i, gotOutput[i], expectedOutput[i]);
|
|
}
|
|
}
|
|
|
|
|
|
return true;
|
|
}
|
|
};
|
|
|
|
// Registers InterpGradTest with MNNTestSuite under the key "grad/interp".
MNNTestSuiteRegister(InterpGradTest, "grad/interp");
|