mirror of https://github.com/alibaba/MNN.git
201 lines
7.5 KiB
C++
201 lines
7.5 KiB
C++
//
// ReverseSequenceTest.cpp
// MNNTests
//
// Created by MNN on 2019/08/31.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/ExprCreator.hpp>
|
|
#include "MNNTestSuite.h"
|
|
|
|
using namespace MNN::Express;
|
|
|
|
class ReverseSequenceTest : public MNNTestCase {
|
|
public:
|
|
virtual bool run(int precision) {
|
|
// high dimension, batch_dim ahead
|
|
float threshold = 0.0001;
|
|
if (precision == 2) {
|
|
threshold = 0.01;
|
|
}
|
|
|
|
{
|
|
auto y = _Input({4}, NHWC, halide_type_of<int32_t>());
|
|
std::vector<int> seq = {7, 2, 3, 5};
|
|
auto yPtr = y->writeMap<int32_t>();
|
|
::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
|
|
auto x = _Input({6, 4, 7, 10, 8}, NHWC, halide_type_of<float>());
|
|
auto xPtr = x->writeMap<float>();
|
|
for (int o = 0; o < 6; ++o) {
|
|
for (int i = 0; i < 4; ++i) {
|
|
for (int m = 0; m < 7; ++m) {
|
|
for (int j = 0; j < 10; ++j) {
|
|
for (int k = 0; k < 8; ++k) {
|
|
xPtr[2240 * o + 560 * i + 80 * m + 8 * j + k] = 0.1 * o + i + m + j + 0.2 * k;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
auto ry = _ReverseSequence(x, y, 1, 3);
|
|
auto ryPtr = ry->readMap<float>();
|
|
|
|
auto func_equal = [threshold](float a, float b) -> bool {
|
|
if (a - b > threshold || a - b < -threshold) {
|
|
return false;
|
|
} else {
|
|
return true;
|
|
}
|
|
};
|
|
|
|
int count = 0;
|
|
for (int o = 0; o < 6; ++o) {
|
|
for (int i = 0; i < 4; ++i) {
|
|
auto req = seq[i];
|
|
for (int m = 0; m < 7; ++m) {
|
|
for (int j = 0; j < 10; ++j) {
|
|
for (int k = 0; k < 8; ++k) {
|
|
float compute = ryPtr[2240 * o + 560 * i + 80 * m + 8 * j + k];
|
|
float need = 0.1 * o + i + m + j + 0.2 * k;
|
|
if (j < req) {
|
|
need = 0.1 * o + i + m + (req - j - 1) + 0.2 * k;
|
|
}
|
|
|
|
if (!func_equal(need, compute)) {
|
|
MNN_PRINT("case 1 error\n");
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
{ // test SizeComputer::needInputContent
|
|
int dim0 = 1, dim1 = 6, dim2 = 7, dim3 = 10, dim4 = 8;
|
|
auto x = _Input({dim0, dim1, dim2, dim3, dim4}, NHWC, halide_type_of<float>());
|
|
auto x_transpose = _Transpose(x, {1, 0, 2, 3, 4});
|
|
auto x_shape = _Shape(x_transpose, NHWC);
|
|
int ii[]= {1};
|
|
auto x_gather = _Gather(x_shape, _Const(ii, {1}, NCHW, halide_type_of<int>()));
|
|
auto ry = _ReverseSequence(x_transpose, x_gather, 1, 3);
|
|
auto xPtr = x->writeMap<float>();
|
|
|
|
for (int i = 0; i < dim0 * dim1 * dim2 * dim3 * dim4; ++i) {
|
|
xPtr[i] = 1;
|
|
}
|
|
|
|
auto ryPtr = ry->readMap<float>();
|
|
|
|
if (ryPtr == nullptr) {
|
|
MNN_PRINT("case 2 error\n");
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// high dimension, seq_dim ahead
|
|
{
|
|
auto y = _Input({4}, NHWC, halide_type_of<int32_t>());
|
|
std::vector<int> seq = {7, 2, 3, 5};
|
|
auto yPtr = y->writeMap<int32_t>();
|
|
::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
|
|
auto x = _Input({6, 10, 7, 4, 8}, NHWC, halide_type_of<float>());
|
|
auto xPtr = x->writeMap<float>();
|
|
for (int o = 0; o < 6; ++o) {
|
|
for (int i = 0; i < 10; ++i) {
|
|
for (int m = 0; m < 7; ++m) {
|
|
for (int j = 0; j < 4; ++j) {
|
|
for (int k = 0; k < 8; ++k) {
|
|
xPtr[2240 * o + 224 * i + 32 * m + 8 * j + k] = 0.1 * o + i + m + j + 0.2 * k;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
auto ry = _ReverseSequence(x, y, 3, 1);
|
|
auto ryPtr = ry->readMap<float>();
|
|
|
|
auto func_equal = [threshold](float a, float b) -> bool {
|
|
if (a - b > threshold || a - b < (-1 * threshold)) {
|
|
return false;
|
|
} else {
|
|
return true;
|
|
}
|
|
};
|
|
|
|
int count = 0;
|
|
for (int o = 0; o < 6; ++o) {
|
|
for (int i = 0; i < 10; ++i) {
|
|
for (int m = 0; m < 7; ++m) {
|
|
for (int j = 0; j < 4; ++j) {
|
|
auto req = seq[j];
|
|
for (int k = 0; k < 8; ++k) {
|
|
auto compute = ryPtr[2240 * o + 224 * i + 32 * m + 8 * j + k];
|
|
auto need = 0.1 * o + i + m + j + 0.2 * k;
|
|
if (i < req) {
|
|
need = 0.1 * o + (req - i - 1) + m + j + 0.2 * k;
|
|
}
|
|
if (!func_equal(need, compute)) {
|
|
MNN_PRINT("case 3 error\n");
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 3 dimension
|
|
{
|
|
auto y = _Input({4}, NHWC, halide_type_of<int32_t>());
|
|
std::vector<int> seq = {7, 2, 3, 5};
|
|
auto yPtr = y->writeMap<int32_t>();
|
|
::memcpy(yPtr, seq.data(), seq.size() * sizeof(int32_t));
|
|
auto x = _Input({10, 4, 8}, NHWC, halide_type_of<float>());
|
|
auto xPtr = x->writeMap<float>();
|
|
for (int i = 0; i < 10; ++i) {
|
|
for (int j = 0; j < 4; ++j) {
|
|
for (int k = 0; k < 8; ++k) {
|
|
xPtr[32 * i + 8 * j + k] = 0.1 * i + j + k;
|
|
}
|
|
}
|
|
}
|
|
|
|
auto ry = _ReverseSequence(x, y, 1, 0);
|
|
auto ryPtr = ry->readMap<float>();
|
|
|
|
auto func_equal = [threshold](float a, float b) -> bool {
|
|
if (a - b > threshold || a - b < (-1 * threshold)) {
|
|
return false;
|
|
} else {
|
|
return true;
|
|
}
|
|
};
|
|
|
|
for (int i = 0; i < 10; ++i) {
|
|
for (int j = 0; j < 4; ++j) {
|
|
auto req = seq[j];
|
|
for (int k = 0; k < 8; ++k) {
|
|
auto compute = ryPtr[32 * i + 8 * j + k];
|
|
auto need = 0.1 * i + j + k;
|
|
if (i < req) {
|
|
need = 0.1 * (req - i - 1) + j + k;
|
|
}
|
|
if (!func_equal(need, compute)) {
|
|
MNN_PRINT("case 4 error\n");
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
};
|
|
// Associates ReverseSequenceTest with the suite name "expr/ReverseSequence".
// NOTE(review): presumably MNNTestSuiteRegister (from MNNTestSuite.h) performs static
// registration so the runner can look the case up by this name — confirm in that header.
MNNTestSuiteRegister(ReverseSequenceTest, "expr/ReverseSequence");
|