mirror of https://github.com/alibaba/MNN.git
78 lines
2.4 KiB
C++
78 lines
2.4 KiB
C++
//
//  RasterOutputTest.cpp
//  MNNTests
//
//  Created by MNN on 2020/12/29.
//  Copyright © 2018, Alibaba Group Holding Limited
//
|
|
|
|
#include <MNN/expr/ExprCreator.hpp>
|
|
#include "MNNTestSuite.h"
|
|
#include "MNN_generated.h"
|
|
#include <MNN/expr/Module.hpp>
|
|
#include "TestUtils.h"
|
|
|
|
using namespace MNN::Express;
|
|
using namespace MNN;
|
|
|
|
static std::shared_ptr<Module> _createModel() {
|
|
auto x = _Input({1, 3, 224, 224}, NCHW, halide_type_of<int>());
|
|
x->setName("Input");
|
|
auto y = _Transpose(x, {0, 1, 3, 2});
|
|
auto z = _Add(y, _Scalar<int>(1));
|
|
z->setName("Add");
|
|
auto q = _Negative(y);
|
|
auto p = _Transpose(q, {0, 3, 1, 2});
|
|
p->setName("Transpose");
|
|
std::unique_ptr<NetT> net(new NetT);
|
|
Variable::save({z, p}, net.get());
|
|
flatbuffers::FlatBufferBuilder builder;
|
|
auto len = MNN::Net::Pack(builder, net.get());
|
|
builder.Finish(len);
|
|
return std::shared_ptr<Module>(Module::load({"Input"}, {"Add", "Transpose"}, builder.GetBufferPointer(), builder.GetSize()));
|
|
}
|
|
|
|
class RasterOutputTest : public MNNTestCase {
|
|
public:
|
|
virtual bool run(int precision) {
|
|
auto executor = cloneCurrentExecutor();
|
|
ExecutorScope scope(executor);
|
|
auto net = _createModel();
|
|
auto x = _Input({1, 3, 224, 224}, NCHW, halide_type_of<int>());
|
|
auto y = _Transpose(x, {0, 1, 3, 2});
|
|
auto z = _Add(y, _Scalar<int>(1));
|
|
auto q = _Negative(y);
|
|
auto p = _Transpose(q, {0, 3, 1, 2});
|
|
{
|
|
auto xPtr = x->writeMap<int>();
|
|
for (int v = 0; v < 1 * 3 * 224 * 224; ++ v) {
|
|
xPtr[v] = v;
|
|
}
|
|
x->unMap();
|
|
}
|
|
auto outputs = net->onForward({x});
|
|
{
|
|
auto dPtr = outputs[0]->readMap<int>();
|
|
auto zPtr = z->readMap<int>();
|
|
auto size = z->getInfo()->size;
|
|
for (int v = 0; v < size; ++v) {
|
|
if (zPtr[v] != dPtr[v]) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
{
|
|
auto dPtr = outputs[1]->readMap<int>();
|
|
auto zPtr = p->readMap<int>();
|
|
auto size = p->getInfo()->size;
|
|
for (int v = 0; v < size; ++v) {
|
|
if (zPtr[v] != dPtr[v]) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
};
|
|
// Invokes the suite registration macro for RasterOutputTest with the name "expr/RasterOutput".
// NOTE(review): presumably this name is the key the test runner uses to select the case —
// confirm against MNNTestSuite.h.
MNNTestSuiteRegister(RasterOutputTest, "expr/RasterOutput");
|