//
//  ShapePlugin.cpp
//  MNN
//
//  Created by MNN on 2020/04/05.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "shape/SizeComputer.hpp"
|
|
#include "core/TensorUtils.hpp"
|
|
|
|
#ifdef MNN_WITH_PLUGIN
|
|
#include "MNN/plugin/PluginShapeInference.hpp"
|
|
#endif // MNN_WITH_PLUGIN
|
|
|
|
namespace MNN {

#ifdef MNN_WITH_PLUGIN
// Look up the shape-inference kernel registered under `name`.
// Returns an empty pointer if no kernel has been registered for that name.
static std::shared_ptr<plugin::InferShapeKernel> getInferShapeKernel( // NOLINT
    const std::string& name) { // NOLINT
    return std::shared_ptr<plugin::InferShapeKernel>( // NOLINT
        plugin::InferShapeKernelRegister::get(name));
}
#endif // MNN_WITH_PLUGIN

class PluginSizeComputer : public SizeComputer {
public:
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        // A plugin op should have inputs, outputs, or both.
        MNN_CHECK(inputs.size() > 0 || outputs.size() > 0, // NOLINT
                  "Plugin op should have inputs or outputs, or both of them.");

#ifdef MNN_WITH_PLUGIN
        // Resolve the user-registered shape-inference kernel by the plugin's type name.
        const Plugin* plugin_param = op->main_as<Plugin>();
        std::shared_ptr<plugin::InferShapeKernel> kernel = // NOLINT
            getInferShapeKernel(plugin_param->type()->str());
        MNN_CHECK(nullptr != kernel.get(), // NOLINT
                  "Shape inference kernel has not been registered for plugin op.");

        // Forward the op's attributes to the kernel through the inference context,
        // then let the kernel fill in the output shapes.
        plugin::InferShapeContext ctx(inputs, outputs);
        for (const Attribute* attr : *(plugin_param->attr())) {
            ctx.setAttr(attr->key()->str(), attr);
        }
        bool status = kernel->compute(&ctx);
        if (!status) {
            MNN_ERROR("Shape inference for plugin op failed: kernel returned false.");
        }
        return status;
#else
        MNN_ERROR("Plugin is not supported. Please recompile with `MNN_WITH_PLUGIN` enabled.");
        return false;
#endif // MNN_WITH_PLUGIN
    }
};

REGISTER_SHAPE(PluginSizeComputer, OpType_Plugin);
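
// For reference, a minimal sketch of the plugin-author side this computer
// expects: a kernel derived from plugin::InferShapeKernel whose compute()
// sets the output shapes and returns true on success. The class name and
// the shape-copying logic below are illustrative assumptions, not APIs
// confirmed by this file; the actual registration helpers live in
// MNN/plugin/PluginShapeInference.hpp.
//
//   class MyShapeKernel : public MNN::plugin::InferShapeKernel {
//   public:
//       bool compute(MNN::plugin::InferShapeContext* ctx) override {
//           // Hypothetical: make the output shape match the input shape,
//           // so the plugin op is an identity for shape-inference purposes.
//           // ... read shapes from ctx's inputs, write them to ctx's outputs ...
//           return true; // returning false triggers the MNN_ERROR above
//       }
//   };
//
// The kernel must be registered under the same name that Plugin::type()
// carries in the model, so that getInferShapeKernel() above can find it via
// InferShapeKernelRegister::get().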

} // namespace MNN