// MNN/test/expr/ReverseSequenceTest.cpp
//
// 201 lines
// 7.5 KiB
// C++

//
// ReverseSequenceTest.cpp
// MNNTests
//
// Created by MNN on 2019/08/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
using namespace MNN::Express;
class ReverseSequenceTest : public MNNTestCase {
public:
virtual bool run(int precision) {
// high dimension, batch_dim ahead
float threshold = 0.0001;
if (precision == 2) {
threshold = 0.01;
}
{
auto y = _Input({4}, NHWC, halide_type_of<int32_t>());
std::vector<int> seq = {7, 2, 3, 5};
auto yPtr = y->writeMap<int32_t>();
::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
auto x = _Input({6, 4, 7, 10, 8}, NHWC, halide_type_of<float>());
auto xPtr = x->writeMap<float>();
for (int o = 0; o < 6; ++o) {
for (int i = 0; i < 4; ++i) {
for (int m = 0; m < 7; ++m) {
for (int j = 0; j < 10; ++j) {
for (int k = 0; k < 8; ++k) {
xPtr[2240 * o + 560 * i + 80 * m + 8 * j + k] = 0.1 * o + i + m + j + 0.2 * k;
}
}
}
}
}
auto ry = _ReverseSequence(x, y, 1, 3);
auto ryPtr = ry->readMap<float>();
auto func_equal = [threshold](float a, float b) -> bool {
if (a - b > threshold || a - b < -threshold) {
return false;
} else {
return true;
}
};
int count = 0;
for (int o = 0; o < 6; ++o) {
for (int i = 0; i < 4; ++i) {
auto req = seq[i];
for (int m = 0; m < 7; ++m) {
for (int j = 0; j < 10; ++j) {
for (int k = 0; k < 8; ++k) {
float compute = ryPtr[2240 * o + 560 * i + 80 * m + 8 * j + k];
float need = 0.1 * o + i + m + j + 0.2 * k;
if (j < req) {
need = 0.1 * o + i + m + (req - j - 1) + 0.2 * k;
}
if (!func_equal(need, compute)) {
MNN_PRINT("case 1 error\n");
return false;
}
}
}
}
}
}
}
{ // test SizeComputer::needInputContent
int dim0 = 1, dim1 = 6, dim2 = 7, dim3 = 10, dim4 = 8;
auto x = _Input({dim0, dim1, dim2, dim3, dim4}, NHWC, halide_type_of<float>());
auto x_transpose = _Transpose(x, {1, 0, 2, 3, 4});
auto x_shape = _Shape(x_transpose, NHWC);
int ii[]= {1};
auto x_gather = _Gather(x_shape, _Const(ii, {1}, NCHW, halide_type_of<int>()));
auto ry = _ReverseSequence(x_transpose, x_gather, 1, 3);
auto xPtr = x->writeMap<float>();
for (int i = 0; i < dim0 * dim1 * dim2 * dim3 * dim4; ++i) {
xPtr[i] = 1;
}
auto ryPtr = ry->readMap<float>();
if (ryPtr == nullptr) {
MNN_PRINT("case 2 error\n");
return false;
}
}
// high dimension, seq_dim ahead
{
auto y = _Input({4}, NHWC, halide_type_of<int32_t>());
std::vector<int> seq = {7, 2, 3, 5};
auto yPtr = y->writeMap<int32_t>();
::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
auto x = _Input({6, 10, 7, 4, 8}, NHWC, halide_type_of<float>());
auto xPtr = x->writeMap<float>();
for (int o = 0; o < 6; ++o) {
for (int i = 0; i < 10; ++i) {
for (int m = 0; m < 7; ++m) {
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 8; ++k) {
xPtr[2240 * o + 224 * i + 32 * m + 8 * j + k] = 0.1 * o + i + m + j + 0.2 * k;
}
}
}
}
}
auto ry = _ReverseSequence(x, y, 3, 1);
auto ryPtr = ry->readMap<float>();
auto func_equal = [threshold](float a, float b) -> bool {
if (a - b > threshold || a - b < (-1 * threshold)) {
return false;
} else {
return true;
}
};
int count = 0;
for (int o = 0; o < 6; ++o) {
for (int i = 0; i < 10; ++i) {
for (int m = 0; m < 7; ++m) {
for (int j = 0; j < 4; ++j) {
auto req = seq[j];
for (int k = 0; k < 8; ++k) {
auto compute = ryPtr[2240 * o + 224 * i + 32 * m + 8 * j + k];
auto need = 0.1 * o + i + m + j + 0.2 * k;
if (i < req) {
need = 0.1 * o + (req - i - 1) + m + j + 0.2 * k;
}
if (!func_equal(need, compute)) {
MNN_PRINT("case 3 error\n");
return false;
}
}
}
}
}
}
}
// 3 dimension
{
auto y = _Input({4}, NHWC, halide_type_of<int32_t>());
std::vector<int> seq = {7, 2, 3, 5};
auto yPtr = y->writeMap<int32_t>();
::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
auto x = _Input({10, 4, 8}, NHWC, halide_type_of<float>());
auto xPtr = x->writeMap<float>();
for (int i = 0; i < 10; ++i) {
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 8; ++k) {
xPtr[32 * i + 8 * j + k] = 0.1 * i + j + k;
}
}
}
auto ry = _ReverseSequence(x, y, 1, 0);
auto ryPtr = ry->readMap<float>();
auto func_equal = [threshold](float a, float b) -> bool {
if (a - b > threshold || a - b < (-1 * threshold)) {
return false;
} else {
return true;
}
};
for (int i = 0; i < 10; ++i) {
for (int j = 0; j < 4; ++j) {
auto req = seq[j];
for (int k = 0; k < 8; ++k) {
auto compute = ryPtr[32 * i + 8 * j + k];
auto need = 0.1 * i + j + k;
if (i < req) {
need = 0.1 * (req - i - 1) + j + k;
}
if (!func_equal(need, compute)) {
MNN_PRINT("case 4 error\n");
return false;
}
}
}
}
}
return true;
}
};
// Presumably registers ReverseSequenceTest with the MNN test suite under the name
// "expr/ReverseSequence" (macro assumed to come from MNNTestSuite.h — verify).
MNNTestSuiteRegister(ReverseSequenceTest, "expr/ReverseSequence");