MNN/test/op/BinaryOPTest.cpp

735 lines
27 KiB
C++
Raw Normal View History

2019-04-17 10:49:11 +08:00
//
// BinaryOPTest.cpp
// MNNTests
//
// Created by MNN on 2019/01/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
2020-11-05 16:41:56 +08:00
2019-12-27 22:16:57 +08:00
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
2019-04-17 10:49:11 +08:00
#include "MNNTestSuite.h"
#include "TestUtils.h"
2019-12-27 22:16:57 +08:00
using namespace MNN::Express;
2020-02-26 09:57:17 +08:00
class BinaryBroadcastShapeTest : public MNNTestCase {
public:
virtual ~BinaryBroadcastShapeTest() = default;
virtual bool run() {
auto input_x = _Const(1, {4, 1, 2, 1}, NCHW);
auto input_y = _Const(1, {2, 1, 4}, NCHW);
input_x->setName("input_x");
input_y->setName("input_y");
2020-11-05 16:41:56 +08:00
auto output = _Add(input_x, input_y);
2020-02-26 09:57:17 +08:00
const std::vector<int> expectedOutputShape = {4, 2, 2, 4};
2020-11-05 16:41:56 +08:00
auto outputSize = output->getInfo()->dim.size();
2020-02-26 09:57:17 +08:00
if (outputSize != expectedOutputShape.size()) {
MNN_ERROR("BinaryBroadcastShapeTest shape compute error!\n");
return false;
}
for (int i = 0; i < outputSize; i++) {
if (output->getInfo()->dim[i] != expectedOutputShape[i]) {
MNN_ERROR("BinaryBroadcastShapeTest shape compute error!\n");
return false;
}
}
2020-11-05 16:41:56 +08:00
const std::vector<float> expectedOutput = {2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.};
auto outputPtr = output->readMap<float>();
2020-02-26 09:57:17 +08:00
if (!checkVector<float>(outputPtr, expectedOutput.data(), outputSize, 1e-6)) {
MNN_ERROR("BinaryBroadcastShapeTest compute error!\n");
return false;
}
return true;
}
};
2019-12-27 22:16:57 +08:00
class AddTest : public MNNTestCase {
public:
virtual ~AddTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input(
{
4,
},
NCHW);
auto input_y = _Input(
{
4,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0};
const float data_y[] = {1.0, 2.0, 3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 4 * sizeof(float));
memcpy(ptr_y, data_y, 4 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Add(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {0.0, 0.0, 0.0, 0.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 0.01)) {
MNN_ERROR("AddTest test failed!\n");
return false;
}
return true;
2019-04-17 10:49:11 +08:00
}
2019-12-27 22:16:57 +08:00
};
class SubtractTest : public MNNTestCase {
public:
virtual ~SubtractTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input(
{
4,
},
NCHW);
auto input_y = _Input(
{
4,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0};
const float data_y[] = {1.0, 2.0, 3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 4 * sizeof(float));
memcpy(ptr_y, data_y, 4 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Subtract(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {-2.0, -4.0, -6.0, -8.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 0.01)) {
MNN_ERROR("SubtractTest test failed!\n");
return false;
}
return true;
}
};
class MultiplyTest : public MNNTestCase {
public:
virtual ~MultiplyTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input(
{
4,
},
NCHW);
auto input_y = _Input(
{
4,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0};
const float data_y[] = {1.0, 2.0, 3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 4 * sizeof(float));
memcpy(ptr_y, data_y, 4 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Multiply(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {-1.0, -4.0, -9.0, -16.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 0.01)) {
MNN_ERROR("MultiplyTest test failed!\n");
return false;
}
return true;
}
};
class DivideTest : public MNNTestCase {
public:
virtual ~DivideTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input(
{
4,
},
NCHW);
auto input_y = _Input(
{
4,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0};
const float data_y[] = {2.0, 4.0, 6.0, 8.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 4 * sizeof(float));
memcpy(ptr_y, data_y, 4 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Divide(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {-0.5, -0.5, -0.5, -0.5};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 0.01)) {
MNN_ERROR("DivideTest test failed!\n");
return false;
}
return true;
}
};
class PowTest : public MNNTestCase {
public:
virtual ~PowTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input(
{
4,
},
NCHW);
auto input_y = _Input(
{
4,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0};
const float data_y[] = {2.0, 4.0, 6.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 4 * sizeof(float));
memcpy(ptr_y, data_y, 4 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Pow(input_x, input_y);
const std::vector<float> expectedOutput = {1.0, 16.0, 729.0, 256.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 0.01)) {
MNN_ERROR("PowTest test failed!\n");
return false;
}
return true;
}
};
class MinimumTest : public MNNTestCase {
public:
virtual ~MinimumTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input(
{
4,
},
NCHW);
auto input_y = _Input(
{
4,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0};
const float data_y[] = {2.0, 4.0, 6.0, 8.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 4 * sizeof(float));
memcpy(ptr_y, data_y, 4 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Minimum(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {-1.0, -2.0, -3.0, -4.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 0.01)) {
MNN_ERROR("MinimumTest test failed!\n");
return false;
}
return true;
}
};
class MaximumTest : public MNNTestCase {
public:
virtual ~MaximumTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input(
{
4,
},
NCHW);
auto input_y = _Input(
{
4,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0};
const float data_y[] = {2.0, 4.0, 6.0, 8.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 4 * sizeof(float));
memcpy(ptr_y, data_y, 4 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Maximum(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {2.0, 4.0, 6.0, 8.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 0.01)) {
MNN_ERROR("MaximumTest test failed!\n");
return false;
}
return true;
}
};
class BiasAddTest : public MNNTestCase {
public:
virtual ~BiasAddTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0};
const float data_y[] = {1.0, 2.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _BiasAdd(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {0.0, 0.0, -2.0, -2.0, -4.0, -4.0, -6.0, -6.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 8, 0.01)) {
MNN_ERROR("BiasAddTest test failed!\n");
return false;
}
return true;
}
};
class GreaterTest : public MNNTestCase {
public:
virtual ~GreaterTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Greater(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<int> expectedOutput = {0, 0, 0, 0, 1, 1, 1, 1};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<int>();
2019-12-27 22:16:57 +08:00
if (!checkVector<int>(gotOutput, expectedOutput.data(), 8, 0)) {
MNN_ERROR("GreaterTest test failed!\n");
return false;
}
return true;
}
};
class GreaterEqualTest : public MNNTestCase {
public:
virtual ~GreaterEqualTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _GreaterEqual(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<int> expectedOutput = {0, 0, 1, 1, 1, 1, 1, 1};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<int>();
2019-12-27 22:16:57 +08:00
if (!checkVector<int>(gotOutput, expectedOutput.data(), 8, 0)) {
MNN_ERROR("GreaterEqualTest test failed!\n");
return false;
}
return true;
}
};
class LessTest : public MNNTestCase {
public:
virtual ~LessTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Less(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<int> expectedOutput = {1, 1, 0, 0, 0, 0, 0, 0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<int>();
2019-12-27 22:16:57 +08:00
if (!checkVector<int>(gotOutput, expectedOutput.data(), 8, 0)) {
MNN_ERROR("LessTest test failed!\n");
return false;
}
return true;
}
};
class FloorDivTest : public MNNTestCase {
public:
virtual ~FloorDivTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
2020-11-05 16:41:56 +08:00
const float data_x[] = {-1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.001};
2019-12-27 22:16:57 +08:00
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
2020-11-05 16:41:56 +08:00
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _FloorDiv(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {-1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 2.0, 2.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 8, 0.01)) {
MNN_ERROR("FloorDivTest test failed!\n");
2020-11-05 16:41:56 +08:00
for (int i = 0; i < expectedOutput.size(); ++i) {
printf("%f - %f\n", expectedOutput[i], gotOutput[i]);
}
2019-12-27 22:16:57 +08:00
return false;
}
return true;
}
};
class SquaredDifferenceTest : public MNNTestCase {
public:
virtual ~SquaredDifferenceTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.0};
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
2020-11-05 16:41:56 +08:00
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _SquaredDifference(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<float> expectedOutput = {16.0, 36.0, 36.0, 64.0, 4.0, 4.0, 16.0, 16.0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 8, 0.01)) {
MNN_ERROR("SquaredDifferenceTest test failed!\n");
return false;
}
return true;
}
};
class EqualTest : public MNNTestCase {
public:
virtual ~EqualTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Equal(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<int> expectedOutput = {0, 0, 1, 1, 0, 0, 0, 0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<int>();
2019-12-27 22:16:57 +08:00
if (!checkVector<int>(gotOutput, expectedOutput.data(), 8, 0)) {
MNN_ERROR("EqualTest test failed!\n");
return false;
}
return true;
}
};
class LessEqualTest : public MNNTestCase {
public:
virtual ~LessEqualTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _LessEqual(input_x, input_y);
2019-12-27 22:16:57 +08:00
const std::vector<int> expectedOutput = {1, 1, 1, 1, 0, 0, 0, 0};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<int>();
2019-12-27 22:16:57 +08:00
if (!checkVector<int>(gotOutput, expectedOutput.data(), 8, 0)) {
MNN_ERROR("LessEqualTest test failed!\n");
return false;
}
return true;
}
};
class FloorModTest : public MNNTestCase {
public:
virtual ~FloorModTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
2019-12-27 22:16:57 +08:00
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
2020-11-05 16:41:56 +08:00
const float data_x[] = {-1.0f, -2.0f, -3.0f, -4.0f, 5.0f, 6.0f, 7.0f, 8.00001f};
const float data_y[] = {3.0f, 4.0f};
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
2019-12-27 22:16:57 +08:00
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _FloorMod(input_x, input_y);
const std::vector<float> expectedOutput = {2.0f, 2.0f, 0.0f, 0.0f, 2.0f, 2.0f, 1.0f, 0.0f};
auto gotOutput = output->readMap<float>();
2019-12-27 22:16:57 +08:00
if (!checkVector<float>(gotOutput, expectedOutput.data(), 8, 0.01)) {
MNN_ERROR("FloorMod test failed!\n");
2020-11-05 16:41:56 +08:00
for (int i = 0; i < expectedOutput.size(); ++i) {
printf("%f - %f\n", expectedOutput[i], gotOutput[i]);
}
2019-12-27 22:16:57 +08:00
return false;
2019-04-17 10:49:11 +08:00
}
return true;
2019-04-17 10:49:11 +08:00
}
};
class Atan2Test : public MNNTestCase {
public:
virtual ~Atan2Test() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW);
auto input_y = _Input(
{
2,
},
NCHW);
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const float data_x[] = {-1.0, -2.0, -3.0, -4.0, 5.0, 6.0, 7.0, 8.0};
const float data_y[] = {3.0, 4.0};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<float>();
auto ptr_y = input_y->writeMap<float>();
memcpy(ptr_x, data_x, 8 * sizeof(float));
memcpy(ptr_y, data_y, 2 * sizeof(float));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _Atan2(input_x, input_y);
const std::vector<float> expectedOutput = {-0.32175055, -0.4636476, -0.7853982, -0.7853982,
1.0303768, 0.98279375, 1.1659045, 1.1071488};
auto gotOutput = output->readMap<float>();
if (!checkVector<float>(gotOutput, expectedOutput.data(), 8, 0.01)) {
MNN_ERROR("Atan2Test test failed!\n");
return false;
}
return true;
}
};
class LogicalOrTest : public MNNTestCase {
public:
virtual ~LogicalOrTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW, halide_type_of<int>());
auto input_y = _Input(
{
2,
},
NCHW, halide_type_of<int>());
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const int data_x[] = {true, false, true, false, false, true, true, false};
const int data_y[] = {true, false};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<int>();
auto ptr_y = input_y->writeMap<int>();
memcpy(ptr_x, data_x, 8 * sizeof(int));
memcpy(ptr_y, data_y, 2 * sizeof(int));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _LogicalOr(input_x, input_y);
const std::vector<int> expectedOutput = {true, false, true, false, true, true, true, false};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<int>();
if (!checkVector<int>(gotOutput, expectedOutput.data(), 8, 0)) {
MNN_ERROR("LogicalOrTest test failed!\n");
return false;
}
return true;
}
};
class NotEqualTest : public MNNTestCase {
public:
virtual ~NotEqualTest() = default;
virtual bool run() {
2020-11-05 16:41:56 +08:00
auto input_x = _Input({4, 2}, NCHW, halide_type_of<int>());
auto input_y = _Input(
{
2,
},
NCHW, halide_type_of<int>());
input_x->setName("input_x");
input_y->setName("input_y");
// set input data
const int data_x[] = {true, false, true, false, false, true, true, false};
const int data_y[] = {true, false};
2020-11-05 16:41:56 +08:00
auto ptr_x = input_x->writeMap<int>();
auto ptr_y = input_y->writeMap<int>();
memcpy(ptr_x, data_x, 8 * sizeof(int));
memcpy(ptr_y, data_y, 2 * sizeof(int));
input_x->unMap();
input_y->unMap();
2020-11-05 16:41:56 +08:00
auto output = _NotEqual(input_x, input_y);
const std::vector<int> expectedOutput = {false, false, false, false, true, true, false, false};
2020-11-05 16:41:56 +08:00
auto gotOutput = output->readMap<int>();
if (!checkVector<int>(gotOutput, expectedOutput.data(), 8, 0)) {
MNN_ERROR("NotEqualTest test failed!\n");
return false;
}
return true;
}
};
2020-05-28 15:10:53 +08:00
class SubtractBroastTest : public MNNTestCase {
public:
virtual ~SubtractBroastTest() = default;
virtual bool run() {
auto input_x = _Input({560}, NCHW);
auto input_y = _Input({1, 20, 560}, NCHW);
input_x->setName("input_x");
input_y->setName("input_y");
std::vector<float> x0T(560);
2020-11-05 16:41:56 +08:00
std::vector<float> x1T(560 * 20);
2020-05-28 15:10:53 +08:00
auto x0 = input_x->writeMap<float>();
auto x1 = input_y->writeMap<float>();
2020-11-05 16:41:56 +08:00
for (int i = 0; i < 560; ++i) {
x0[i] = i / 1000.0f;
2020-05-28 15:10:53 +08:00
x0T[i] = x0[i];
}
2020-11-05 16:41:56 +08:00
for (int i = 0; i < 560 * 20; ++i) {
x1[i] = i / 1000.0f;
2020-05-28 15:10:53 +08:00
x1T[i] = x1[i];
}
auto output = _Subtract(input_x, input_y);
2020-11-05 16:41:56 +08:00
auto ptr = output->readMap<float>();
for (int i = 0; i < 20; ++i) {
for (int j = 0; j < 560; ++j) {
auto x0V = x0T[j];
auto x1V = x1T[j + i * 560];
auto y1V = ptr[j + i * 560];
2020-05-28 15:10:53 +08:00
auto target = x0V - x1V;
2020-11-05 16:41:56 +08:00
if (fabsf(target - y1V) > 0.01f) {
MNN_ERROR("SubtractTest broascast test failed: i:%d, j:%d, Right: %f - Compute: %f!\n", i, j, y1V,
target);
2020-05-28 15:10:53 +08:00
return false;
}
}
}
return true;
}
};
// Register every binary-op test with the MNN test suite under "op/binary/...".
// The registered strings are the names used to select individual tests, so they
// are kept exactly as originally published.
MNNTestSuiteRegister(BinaryBroadcastShapeTest, "op/binary/broadcastShapeTest");
MNNTestSuiteRegister(AddTest, "op/binary/add");
MNNTestSuiteRegister(SubtractTest, "op/binary/subtract");
MNNTestSuiteRegister(MultiplyTest, "op/binary/multiply");
MNNTestSuiteRegister(DivideTest, "op/binary/divide");
MNNTestSuiteRegister(PowTest, "op/binary/pow");
MNNTestSuiteRegister(MinimumTest, "op/binary/minimum");
MNNTestSuiteRegister(MaximumTest, "op/binary/maximum");
MNNTestSuiteRegister(BiasAddTest, "op/binary/biasadd");
MNNTestSuiteRegister(GreaterTest, "op/binary/greater");
MNNTestSuiteRegister(GreaterEqualTest, "op/binary/greaterequal");
MNNTestSuiteRegister(LessTest, "op/binary/less");
MNNTestSuiteRegister(FloorDivTest, "op/binary/floordiv");
MNNTestSuiteRegister(SquaredDifferenceTest, "op/binary/squareddifference");
MNNTestSuiteRegister(EqualTest, "op/binary/equal");
MNNTestSuiteRegister(LessEqualTest, "op/binary/lessequal");
MNNTestSuiteRegister(FloorModTest, "op/binary/floormod");
MNNTestSuiteRegister(Atan2Test, "op/binary/atan2");
MNNTestSuiteRegister(LogicalOrTest, "op/binary/logicalor");
// NOTE(review): "notqual" looks like a typo for "notequal"; left unchanged because
// scripts or CI may select this test by its registered name — confirm before renaming.
MNNTestSuiteRegister(NotEqualTest, "op/binary/notqual");
MNNTestSuiteRegister(SubtractBroastTest, "op/binary/subtractBroastTest");