mirror of https://github.com/alibaba/MNN.git
55 lines
1.8 KiB
C++
// ShapePlugin.cpp
// MNN
//
// Created by MNN on 2020/04/05.
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifdef MNN_WITH_PLUGIN

#include "core/SizeComputer.hpp"
#include "core/TensorUtils.hpp"

#include "MNN/plugin/PluginShapeInference.hpp"

namespace MNN {

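// Resolve the shape-inference kernel registered under `name`. A null result
// means no kernel has been registered for that plugin type.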
static std::shared_ptr<plugin::InferShapeKernel> getInferShapeKernel( // NOLINT
    const std::string& name) { // NOLINT
    return std::shared_ptr<plugin::InferShapeKernel>( // NOLINT
        plugin::InferShapeKernelRegister::get(name));
}

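// SizeComputer implementation for OpType_Plugin: instead of hard-coding a
// shape formula, it dispatches to whatever InferShapeKernel was registered
// under the plugin's type name.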
class PluginSizeComputer : public SizeComputer {
public:
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        // A plugin op should have inputs or outputs, or both.
        MNN_CHECK(inputs.size() > 0 || outputs.size() > 0, // NOLINT
                  "Plugin op should have inputs or outputs, or both of them.");

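        // The op's payload is a Plugin table carrying the kernel's type name
        // and a list of key/value attributes.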
        const Plugin* plugin_param = op->main_as<Plugin>();

        std::shared_ptr<plugin::InferShapeKernel> kernel = // NOLINT
            getInferShapeKernel(plugin_param->type()->str());
        MNN_CHECK(nullptr != kernel.get(), // NOLINT
                  "Shape inference kernel has not been registered for plugin op.");

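        // Hand the op's attributes to the kernel through the inference
        // context, then let the registered kernel fill in the output shapes.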
        plugin::InferShapeContext ctx(inputs, outputs);
        for (const Attribute* attr : *(plugin_param->attr())) {
            ctx.setAttr(attr->key()->str(), attr);
        }
        bool status = kernel->compute(&ctx);
        if (!status) {
            MNN_ERROR("Infer shape for plugin op failed: kernel returned false.");
        }
        return status;
    }
};

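// Register this computer so the shape pipeline routes every OpType_Plugin op
// through it.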
REGISTER_SHAPE(PluginSizeComputer, OpType_Plugin);

} // namespace MNN

#endif // MNN_WITH_PLUGIN
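
For context, the other half of this mechanism lives in plugin code: an author subclasses plugin::InferShapeKernel and registers it under the same type string that the converter stores in the op's Plugin table. Below is a minimal sketch of such a kernel that copies input 0's shape to output 0. It is illustrative only: the kernel name MyPlugin is hypothetical, and the ctx->input()/ctx->output() accessors and the REGISTER_PLUGIN_OP_SHAPE_INFERENCE macro are assumed from MNN's plugin demo; consult MNN/plugin/PluginShapeInference.hpp for the exact API.

#include <MNN/plugin/PluginShapeInference.hpp>

namespace MNN {
namespace plugin {

// Hypothetical kernel: output 0 takes the shape and type of input 0.
class MyPluginShapeKernel : public InferShapeKernel {
public:
    bool compute(InferShapeContext* ctx) override {
        // inputs()/outputs()/input()/output() are assumed accessors on
        // InferShapeContext.
        if (ctx->inputs().empty() || ctx->outputs().empty()) {
            return false; // Reported upstream by PluginSizeComputer.
        }
        Tensor* in  = ctx->input(0);
        Tensor* out = ctx->output(0);
        out->buffer().dimensions = in->buffer().dimensions;
        for (int i = 0; i < in->buffer().dimensions; ++i) {
            out->setLength(i, in->length(i));
        }
        out->buffer().type = in->buffer().type;
        return true;
    }
};

// The first argument must match the Plugin table's `type` field so that
// getInferShapeKernel() can resolve this kernel at inference time.
// Macro name assumed from MNN's plugin demo.
REGISTER_PLUGIN_OP_SHAPE_INFERENCE(MyPlugin, MyPluginShapeKernel);

} // namespace plugin
} // namespace MNN

This registry-based indirection is what keeps ShapePlugin.cpp generic: the runtime only ever sees the type string, so new plugin ops can ship their own shape inference without touching core shape code.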