// MNN/test/op/SvdTest.cpp (157 lines, 8.4 KiB, C++)

//
// SvdTest.cpp
// MNNTests
//
// Created by MNN on 2022/07/14.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include <MNN/AutoTime.hpp>
using namespace MNN::Express;
class SvdTest : public MNNTestCase {
public:
virtual ~SvdTest() = default;
virtual bool run(int precision) {
{
{
const float inpudata[] = {-4.1639902e-02, -1.3293471e-03, -9.3426127e-03,
-5.4161578e-02, 7.3445633e-02, 2.8240314e-01,
-2.8891024e-01, 9.4372004e-02, 1.6786221e+00};
auto input = _Const(inpudata, {3, 3}, NHWC, halide_type_of<float>());
const float expected_w[] = {1.7306069 , 0.05702754, 0.04242307};
const float expected_u[] = {-0.00127122, -0.11249565, -0.9936514,
0.16840328, -0.9794852 , 0.11067636,
0.9857174 , 0.16719346, -0.02018979};
const float expected_vt[] = {-0.1697972 , 0.06090027, 0.98359585,
0.16537347, -0.9821745 , 0.08936048,
0.97150445, 0.1778338 , 0.15669909};
auto outputs = _Svd(input);
auto w = outputs[0];
auto u = outputs[1];
auto vt = outputs[2];
{
MNN::AutoTime ___t(__LINE__, __func__);
w->readMap<float>();
}
if (!checkVector<float>(w->readMap<float>(), expected_w, 3, 1e-3)) {
MNN_ERROR("SvdTest w is wrong!\n");
return false;
}
if (!checkVector<float>(u->readMap<float>(), expected_u, 9, 1e-3)) {
MNN_ERROR("SvdTest u is wrong!\n");
return false;
}
if (!checkVector<float>(vt->readMap<float>(), expected_vt, 9, 1e-3)) {
MNN_ERROR("SvdTest vt is wrong!\n");
return false;
}
}
}
{
const float inpudata[] = {
1.0657064e+05, -6.6985586e+03, 4.1599282e+03, -1.0666493e+03,
7.3605914e+04, -5.8630166e+03, 1.6793969e+04, -1.0292224e+04,
2.0765317e+03, -6.6985571e+03, 1.9655250e+05, -1.4522734e+04,
-1.8006883e+02, 1.2425942e+04, -9.8978143e+02, -6.7625410e+03,
3.0371273e+04, -1.9593423e+03, 4.1599297e+03, -1.4522729e+04,
1.4451375e+04, 1.1182600e+02, -7.7167476e+03, 6.1466479e+02,
2.1863223e+03, -1.5707598e+03, 2.7005537e+03, -1.0666493e+03,
-1.8006879e+02, 1.1182603e+02, 1.4622133e+05, 1.9786536e+03,
-1.5760797e+02, -1.7559078e+03, -2.3280170e+02, 5.4449776e+01,
7.3605914e+04, 1.2425944e+04, -7.7167441e+03, 1.9786536e+03,
6.1143047e+04, -4.3490020e+03, 1.0106666e+04, -5.7797246e+03,
-5.0167773e+02, -5.8630161e+03, -9.8978302e+02, 6.1466534e+02,
-1.5760800e+02, -4.3490000e+03, 1.4021176e+04, -7.7215833e+02,
-2.9290234e+02, -6.2441797e+02, 1.6793969e+04, -6.7625400e+03,
2.1863220e+03, -1.7559078e+03, 1.0106666e+04, -7.7215778e+02,
3.5546367e+03, -2.3252129e+03, 5.8544067e+02, -1.0292224e+04,
3.0371273e+04, -1.5707607e+03, -2.3280170e+02, -5.7797227e+03,
-2.9290234e+02, -2.3252129e+03, 6.8266089e+03, -4.3708252e+02,
2.0765312e+03, -1.9593418e+03, 2.7005518e+03, 5.4449791e+01,
-5.0167871e+02, -6.2441797e+02, 5.8544067e+02, -4.3708228e+02,
5.9221118e+02
};
auto input = _Const(inpudata, {9, 9}, NHWC, halide_type_of<float>());
const float expected_w[] = {
2.0382961e+05, 1.6452105e+05, 1.4627997e+05, 1.8760078e+04,
1.3720105e+04, 1.4239045e+03, 8.2915063e+02, 5.6911682e+02,
5.2337189e+00
};
const float expected_u[] = {
-5.30059375e-02, 7.95129955e-01, -1.37356063e-02, 2.99445540e-01,
1.09734191e-02, -3.34383070e-01, -2.79517472e-01, -2.90623546e-01,
-2.40407158e-02, 9.81155992e-01, 2.55002081e-02, -1.73566607e-03,
1.03179410e-01, 9.52578988e-03, 9.60456207e-02, -1.25066742e-01,
9.93702002e-03, -3.12699117e-02, -8.04865360e-02, -8.20002332e-03,
-1.65763253e-04, 8.26010704e-01, -2.06464007e-02, 3.68608236e-01,
3.77096266e-01, 6.95602000e-02, -1.66800454e-01, 1.31534311e-04,
6.72200741e-03, 9.99769986e-01, 1.01281721e-02, 3.49767128e-04,
-1.18328175e-02, -1.14854360e-02, 6.28872868e-03, -7.67501653e-04,
5.37704267e-02, 5.87597191e-01, 9.68683884e-03, -4.13454801e-01,
6.00379966e-02, 4.02822673e-01, 5.20126045e-01, 2.06396833e-01,
4.24040370e-02, -5.01791388e-03, -4.86786626e-02, -8.10053956e-04,
4.96227853e-02, 9.94751930e-01, -4.80379835e-02, 1.62436962e-02,
-8.39837827e-03, 5.44850230e-02, -3.75390835e-02, 1.19874418e-01,
-1.31259384e-02, 1.21245451e-01, 2.13273242e-03, -2.00914875e-01,
-2.48641193e-01, 9.31195736e-01, -5.39436471e-03, 1.53572425e-01,
-7.01504499e-02, -1.21269596e-03, 6.50331154e-02, -5.97846434e-02,
-7.15259433e-01, 6.51730597e-01, 2.75142882e-02, 1.63044885e-01,
-1.16258143e-02, 8.63685086e-03, 1.19157114e-04, 1.57965228e-01,
-5.14553860e-02, 1.62525818e-01, -8.07750002e-02, -2.91586504e-03,
9.69144881e-01
};
const float expected_vt[] = {
-5.3006630e-02, 9.8116404e-01, -8.0487266e-02, 1.3153395e-04,
5.3771134e-02, -5.0179465e-03, -3.7539378e-02, 1.5357375e-01,
-1.1625915e-02, 7.9513741e-01, 2.5500454e-02, -8.2001043e-03,
6.7220661e-03, 5.8760232e-01, -4.8679110e-02, 1.1987541e-01,
-7.0151038e-02, 8.6369254e-03, -1.3735750e-02, -1.7356840e-03,
-1.6576078e-04, 9.9977797e-01, 9.6869105e-03, -8.1005890e-04,
-1.3126039e-02, -1.2127035e-03, 1.1915851e-04, 2.9944834e-01,
1.0318064e-01, 8.2601863e-01, 1.0128255e-02, -4.1345882e-01,
4.9623299e-02, 1.2124645e-01, 6.5033823e-02, 1.5796673e-01,
1.0973417e-02, 9.5258495e-03, -2.0646475e-02, 3.4976407e-04,
6.0038477e-02, 9.9475998e-01, 2.1326796e-03, -5.9785094e-02,
-5.1456086e-02, -3.3438501e-01, 9.6046828e-02, 3.6861011e-01,
-1.1832938e-02, 4.0282562e-01, -4.8038427e-02, -2.0091695e-01,
-7.1526390e-01, 1.6252735e-01, -2.7952069e-01, -1.2506786e-01,
3.7710068e-01, -1.1485523e-02, 5.2013028e-01, 1.6243761e-02,
-2.4864188e-01, 6.5173382e-01, -8.0774382e-02, -2.9062533e-01,
9.9372808e-03, 6.9560230e-02, 6.2887804e-03, 2.0639756e-01,
-8.3984155e-03, 9.3120378e-01, 2.7513305e-02, -2.9159351e-03,
-2.4040809e-02, -3.1270202e-02, -1.6680241e-01, -7.6749898e-04,
4.2403858e-02, 5.4485701e-02, -5.3939600e-03, 1.6304553e-01,
9.6915299e-01
};
auto outputs = _Svd(input);
auto w = outputs[0];
auto u = outputs[1];
auto vt = outputs[2];
w->getInfo();
{
MNN::AutoTime ___t(__LINE__, __func__);
w->readMap<float>();
}
if (!checkVector<float>(w->readMap<float>(), expected_w, 9, 5)) {
MNN_ERROR("SvdTest w is wrong!\n");
return false;
}
if (!checkVector<float>(u->readMap<float>(), expected_u, 81, 1e-3)) {
MNN_ERROR("SvdTest u is wrong!\n");
return false;
}
if (!checkVector<float>(vt->readMap<float>(), expected_vt, 81, 1e-3)) {
MNN_ERROR("SvdTest vt is wrong!\n");
return false;
}
}
return true;
}
};
MNNTestSuiteRegister(SvdTest, "op/svd");