MNN/express/Optimizer.cpp


//
// Optimizer.cpp
// MNN
//
// Created by MNN on 2019/08/20.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <MNN/expr/Optimizer.hpp>
#include "MergeOptimizer.hpp"
#include "core/Backend.hpp"
namespace MNN {
namespace Express {
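// Parameters is a simple owned float buffer: the constructor allocates n
// floats for an optimizer's tunable values and the destructor releases them.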
Optimizer::Parameters::Parameters(int n) {
    MNN_ASSERT(n > 0);
    mValue = new float[n];
    mSize  = n;
}
Optimizer::Parameters::~Parameters() {
    if (nullptr != mValue) {
        delete[] mValue;
    }
}
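// Create a MergeOptimizer for the requested backend. An explicit forwardType
// is honoured only if its backend creator is registered; otherwise the device
// hint selects the CPU backend directly or probes the GPU backends, and
// nullptr is returned when no suitable backend is available.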
std::shared_ptr<Optimizer> Optimizer::create(Config config) {
    const int numThread = config.numThread;
    auto forwardType    = config.forwardType;
    if (forwardType != MNN_FORWARD_ALL) {
        if (MNNGetExtraBackendCreator(forwardType) == nullptr) {
            return nullptr;
        }
        return std::shared_ptr<Optimizer>(new MergeOptimizer(config.forwardType, numThread, nullptr));
    }
    auto device = config.device;
    if (CPU == device) {
        return std::shared_ptr<Optimizer>(new MergeOptimizer(MNN_FORWARD_CPU, numThread, nullptr));
    }
    if (GPU == device) {
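        // Probe the GPU backends in a fixed preference order and take the
        // first one whose creator is registered in this build.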
        std::vector<MNNForwardType> types {MNN_FORWARD_METAL, MNN_FORWARD_OPENCL, MNN_FORWARD_VULKAN, MNN_FORWARD_OPENGL};
        for (auto type : types) {
            auto creator = MNNGetExtraBackendCreator(type);
            if (nullptr != creator) {
                return std::shared_ptr<Optimizer>(new MergeOptimizer(type, numThread, nullptr));
            }
        }
    }
    return nullptr;
}
} // namespace Express
} // namespace MNN
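
// Example usage: a minimal sketch of calling Optimizer::create. It assumes
// the Config fields and the CPU/GPU device constants referenced above are
// declared in MNN/expr/Optimizer.hpp with these spellings; the call on the
// returned optimizer is only indicative of the virtual interface there.
//
//   MNN::Express::Optimizer::Config config;
//   config.forwardType = MNN_FORWARD_ALL; // let the device hint choose the backend
//   config.device      = MNN::Express::Optimizer::GPU;
//   config.numThread   = 4;
//   auto optimizer = MNN::Express::Optimizer::create(config);
//   if (nullptr != optimizer) {
//       // run the optimizer on Express output variables via the interface
//       // declared in Optimizer.hpp
//   }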