MNN/source/shape/ShapePlugin.cpp

// ShapePlugin.cpp
// MNN
//
// Created by MNN on 2020/04/05.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "shape/SizeComputer.hpp"
#include "core/TensorUtils.hpp"
#ifdef MNN_WITH_PLUGIN
#include "MNN/plugin/PluginShapeInference.hpp"
#endif // MNN_WITH_PLUGIN
namespace MNN {
#ifdef MNN_WITH_PLUGIN
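// Looks up the shape-inference kernel registered under the given plugin type
// name and wraps the registry's raw pointer in a shared_ptr.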
static std::shared_ptr<plugin::InferShapeKernel> getInferShapeKernel( // NOLINT
    const std::string& name) { // NOLINT
    return std::shared_ptr<plugin::InferShapeKernel>( // NOLINT
        plugin::InferShapeKernelRegister::get(name));
}
#endif // MNN_WITH_PLUGIN
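// SizeComputer for OpType_Plugin: dispatches shape inference to the kernel
// registered under the op's plugin type name (when MNN_WITH_PLUGIN is enabled).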
class PluginSizeComputer : public SizeComputer {
public:
    virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
                               const std::vector<Tensor*>& outputs) const override {
        // A plugin op should have inputs, outputs, or both.
        MNN_CHECK(inputs.size() > 0 || outputs.size() > 0, // NOLINT
                  "Plugin op should have inputs or outputs, or both of them.");
#ifdef MNN_WITH_PLUGIN
        const Plugin* plugin_param = op->main_as<Plugin>();
        std::shared_ptr<plugin::InferShapeKernel> kernel = // NOLINT
            getInferShapeKernel(plugin_param->type()->str());
        MNN_CHECK(nullptr != kernel.get(), // NOLINT
                  "Shape inference kernel has not been registered for plugin op.");
        plugin::InferShapeContext ctx(inputs, outputs);
        // Forward all attributes of the plugin op to the inference context so
        // that the kernel can query them by key.
        for (const Attribute* attr : *(plugin_param->attr())) {
            ctx.setAttr(attr->key()->str(), attr);
        }
        bool status = kernel->compute(&ctx);
        if (!status) {
            MNN_ERROR("Plugin op shape inference failed: kernel returned false.");
        }
        return status;
#else
MNN_ERROR("Plugin is not supported. Please recompile with `MNN_WITH_PLUGIN` enabled.");
return false;
#endif // MNN_WITH_PLUGIN
}
};
REGISTER_SHAPE(PluginSizeComputer, OpType_Plugin);
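
// A minimal sketch of the plugin-author side, kept in a comment so it does not
// affect this translation unit. InferShapeKernel and InferShapeKernelRegister
// come from MNN/plugin/PluginShapeInference.hpp as used above; the context
// accessors (input()/output()) and the registration entry point shown below
// are assumptions, so check that header for the actual API before copying.
//
//   #include "MNN/plugin/PluginShapeInference.hpp"
//   #include "core/TensorUtils.hpp"
//
//   namespace MNN { namespace plugin {
//   class MyPluginShapeKernel : public InferShapeKernel {
//   public:
//       bool compute(InferShapeContext* ctx) override {
//           // Hypothetical accessors: give the first output the same shape
//           // as the first input, then report success.
//           TensorUtils::copyShape(ctx->input(0), ctx->output(0), true);
//           return true;
//       }
//   };
//   // Hypothetical registration, so that
//   // InferShapeKernelRegister::get("MyPlugin") can find the kernel.
//   // REGISTER_PLUGIN_OP("MyPlugin", MyPluginShapeKernel);
//   } } // namespace MNN::plugin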
} // namespace MNN