mirror of https://github.com/alibaba/MNN.git
77 lines
4.0 KiB
C++
77 lines
4.0 KiB
C++
//
|
|
// MatrixBandPartGradTest.cpp
|
|
// MNNTests
|
|
//
|
|
// Created by MNN on 2022/08/17.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#include <MNN/expr/Expr.hpp>
|
|
#include <MNN/expr/ExprCreator.hpp>
|
|
#include "MNNTestSuite.h"
|
|
#include "TestUtils.h"
|
|
#include "../tools/train/source/grad/OpGrad.hpp"
|
|
|
|
using namespace MNN;
|
|
using namespace MNN::Express;
|
|
|
|
class MatrixBandPartGradTest : public MNNTestCase {
|
|
public:
|
|
char name[20] = "MatrixBandPart";
|
|
virtual ~MatrixBandPartGradTest() = default;
|
|
|
|
virtual bool run(int precision) {
|
|
std::vector<int> shape = {3, 2, 3, 2};
|
|
const int len = shape[0] * shape[1] * shape[2] * shape[3];
|
|
auto input = _Input(shape, NCHW);
|
|
const float inpudata[] = { -0.94003415f, -0.03546342f, -0.01028545f, 1.2092209f , 1.5427123f ,
|
|
1.0838836f , 0.5939991f , 2.017224f , 0.2702435f , 1.1762271f ,
|
|
-0.95548075f, -0.12556452f, -1.4086435f , -0.13807571f, -0.23514274f,
|
|
-0.5025484f , 0.93871444f, -0.5169497f , -1.5226837f , -0.8545326f ,
|
|
0.340934f , 0.25505793f, 1.7961069f , -0.7955173f , -0.16109313f,
|
|
1.3417882f , 0.9252207f , -0.69964254f, -0.5392309f , 0.4769467f ,
|
|
-0.33865267f, -0.5568984f , -0.53030866f, -0.07422069f, -0.7438325f ,
|
|
-0.06075661f};
|
|
auto inputPtr = input->writeMap<float>();
|
|
memcpy(inputPtr, inpudata, len * sizeof(float));
|
|
|
|
auto lower = _Scalar<int>(0);
|
|
auto upper = _Scalar<int>(1);
|
|
auto output = _MatrixBandPart(input, lower, upper);
|
|
|
|
auto opExpr = output->expr().first;
|
|
auto grad = OpGrad::get(opExpr->get()->type());
|
|
const float outputDiff[] = { 0.92951214f, -1.3656238f , 0.9058341f , 0.21897921f, -0.5062561f ,
|
|
0.29703847f, -0.5324379f , 0.8826049f , -0.9250548f , 1.8164085f ,
|
|
-1.7761891f , 1.2291343f , 0.45859334f, 0.09624046f, -0.8051032f ,
|
|
0.446291f , 0.9178219f , -0.7392022f , 2.31639f , -0.8006644f ,
|
|
0.5834905f , 1.5046587f , -0.11566874f, -2.449344f , -1.2720072f ,
|
|
-0.55631214f, -0.12848197f, -1.2433224f , -0.46224716f, 0.57611173f,
|
|
-1.0455252f , 1.1562699f , -1.2612194f , 0.46669045f, 0.38025302f,
|
|
-0.70845205f};
|
|
auto inputGrad = grad->onGrad(opExpr, {_Const(outputDiff, shape)});
|
|
|
|
const std::vector<float> expectedOutput = { 0.92951214f, -1.3656238f , 0.f , 0.21897921f, -0.f ,
|
|
0.f , -0.5324379f , 0.8826049f , -0.f , 1.8164085f ,
|
|
-0.f , 0.f , 0.45859334f, 0.09624046f, -0.f ,
|
|
0.446291f , 0.f , -0.f , 2.31639f , -0.8006644f ,
|
|
0.f , 1.5046587f , -0.f , -0.f , -1.2720072f ,
|
|
-0.55631214f, -0.f , -1.2433224f , -0.f , 0.f ,
|
|
-1.0455252f , 1.1562699f , -0.f , 0.46669045f, 0.f ,
|
|
-0.f};
|
|
auto gotOutput = inputGrad[0]->readMap<float>();
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
|
|
if (diff > 0.0001) {
|
|
MNN_ERROR("%s grad test failed, expected: %f, but got: %f!\n", name, expectedOutput[i], gotOutput[i]);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
};
|
|
|
|
// Register MatrixBandPartGradTest with the MNN test suite under the name
// "grad/matrix_band_part" (presumably the key used to select it from the test runner).
MNNTestSuiteRegister(MatrixBandPartGradTest, "grad/matrix_band_part");
|