mirror of https://github.com/alibaba/MNN.git
92 lines
3.2 KiB
C++
92 lines
3.2 KiB
C++
//
//  CPUPlugin.cpp
//  MNN
//
//  Created by MNN on 2020/04/07.
//  Copyright © 2018, Alibaba Group Holding Limited
//
|
#include "backend/cpu/CPUBackend.hpp"
|
|
#include "core/AutoStorage.h"
|
|
#include "core/Execution.hpp"
|
|
|
|
#ifdef MNN_WITH_PLUGIN
|
|
#include "MNN/plugin/PluginContext.hpp"
|
|
#include "MNN/plugin/PluginKernel.hpp"
|
|
#endif // MNN_WITH_PLUGIN
|
|
|
|
namespace MNN {
|
|
|
|
#ifdef MNN_WITH_PLUGIN
|
|
static std::shared_ptr<plugin::CPUComputeKernel> getCPUComputeKernel( // NOLINT
|
|
const std::string& name) { // NOLINT
|
|
return std::shared_ptr<plugin::CPUComputeKernel>( // NOLINT
|
|
plugin::ComputeKernelRegistry<plugin::CPUComputeKernel>::get(name));
|
|
}
|
|
|
|
class CPUPlugin : public Execution {
|
|
public:
|
|
CPUPlugin(std::unique_ptr<plugin::CPUKernelContext> ctx) // NOLINT
|
|
: Execution(ctx->backend()), ctx_(std::move(ctx)) {
|
|
kernel_ = getCPUComputeKernel(ctx_->op_type());
|
|
MNN_CHECK(nullptr != kernel_.get(), // NOLINT
|
|
"CPU compute kernel has not been registered for plugin op.");
|
|
kernel_->init(ctx_.get());
|
|
}
|
|
virtual ~CPUPlugin() = default;
|
|
|
|
virtual ErrorCode onExecute(const std::vector<Tensor*>& inputs, // NOLINT
|
|
const std::vector<Tensor*>& outputs) override;
|
|
|
|
private:
|
|
std::unique_ptr<plugin::CPUKernelContext> ctx_;
|
|
std::shared_ptr<plugin::CPUComputeKernel> kernel_;
|
|
};
|
|
|
|
ErrorCode CPUPlugin::onExecute(const std::vector<Tensor*>& inputs, // NOLINT
                               const std::vector<Tensor*>& outputs) {
    // Build a per-call context carrying this invocation's tensors while
    // reusing the op type, backend, and attributes captured at creation time.
    plugin::CPUKernelContext ctx( // NOLINT
        ctx_->op_type(), ctx_->backend(), inputs, outputs);
    ctx.setAttrs(ctx_->getAttrs());
    // The kernel reports failure by returning false.
    if (!kernel_->compute(&ctx)) {
        MNN_ERROR("Plugin kernel compute failed with false returned.");
        return INVALID_VALUE;
    }
    return NO_ERROR;
}
|
|
#endif // MNN_WITH_PLUGIN
|
|
|
|
// Creator that builds a CPUPlugin execution for OpType_Plugin ops. When the
// library is compiled without MNN_WITH_PLUGIN, creation always fails with an
// error message and a nullptr return.
class CPUPluginCreator : public CPUBackend::Creator {
public:
    // Builds the kernel context (op type + attributes from the serialized
    // Plugin parameter) and hands it to a new CPUPlugin. Returns nullptr when
    // plugin support is compiled out. Caller owns the returned Execution.
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, // NOLINT
                                const std::vector<Tensor*>& outputs, // NOLINT
                                const MNN::Op* op, Backend* backend) const override {
#ifdef MNN_WITH_PLUGIN
        MNN_ASSERT(op->type() == OpType_Plugin);
        // Plugin op should has inputs or outputs, or both of them.
        MNN_CHECK(inputs.size() > 0 || outputs.size() > 0, // NOLINT
                  "Plugin op should has inputs or outputs, or both of them.");

        const Plugin* plugin_param = op->main_as<Plugin>();

        const std::string& op_type = plugin_param->type()->str();
        std::unique_ptr<plugin::CPUKernelContext> ctx( // NOLINT
            new plugin::CPUKernelContext(op_type, backend, inputs, outputs));

        // `attr` is an optional flatbuffer vector: it is nullptr when the op
        // carries no attributes, so guard before iterating.
        if (plugin_param->attr() != nullptr) {
            for (const Attribute* attr : *(plugin_param->attr())) {
                ctx->setAttr(attr->key()->str(), attr);
            }
        }
        return new CPUPlugin(std::move(ctx));
#else
        MNN_ERROR("Plugin is not supported. Please recompile with `MNN_WITH_PLUGIN` enabled.");
        return nullptr;
#endif // MNN_WITH_PLUGIN
    }
};
|
|
|
|
// Register the creator with the CPU backend so OpType_Plugin ops are
// dispatched to CPUPluginCreator::onCreate.
REGISTER_CPU_OP_CREATOR(CPUPluginCreator, OpType_Plugin);
|
|
|
|
} // namespace MNN
|