mirror of https://github.com/alibaba/MNN.git
101 lines
3.3 KiB
C++
101 lines
3.3 KiB
C++
//
//  VulkanOneHot.cpp
//  MNN
//
//  Created by MNN on 2020/07/20.
//  Copyright © 2018, Alibaba Group Holding Limited
//
|
|
|
|
#include "VulkanOneHot.hpp"
|
|
#include "core/Macro.h"
|
|
namespace MNN {
|
|
struct ConstBuffer {
|
|
ivec4 inputSize;
|
|
};
|
|
|
|
// Sets up the execution: stores the axis taken from the op's OneHotParam, obtains a
// uniform buffer from the Vulkan backend, and creates the "glsl_onehot_comp" pipeline
// together with one descriptor set for it.
VulkanOneHot::VulkanOneHot(const Op* op, Backend* bn) : VulkanBasicExecution(bn) {
    auto vkBackend = static_cast<VulkanBackend*>(bn);

    const auto oneHotParam = op->main_as_OneHotParam();
    mAxis                  = oneHotParam->axis();

    // Released again in the destructor via recycleUniform().
    mConstBuffer = vkBackend->allocUniform();

    // Descriptor layout: four storage buffers followed by one uniform buffer.
    // onEncode binds them as 0 = output, 1 = indices, 2 = on value, 3 = off value, 4 = constants.
    mPipeline = vkBackend->getPipeline("glsl_onehot_comp", {
        VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
        VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
        VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
        VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
        VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
    });

    mDescriptorSet.reset(mPipeline->createSet());
}
|
|
// Hands the uniform buffer acquired in the constructor back to the Vulkan backend.
VulkanOneHot::~VulkanOneHot() {
    static_cast<VulkanBackend*>(backend())->recycleUniform(mConstBuffer);
}
|
|
|
|
/**
 * Records the one-hot compute dispatch into cmdBuffer.
 *
 * @param inputs    [0] indices, [1] depth (not read here; the depth is taken from the
 *                  output shape), [2] on value, [3] off value.
 * @param outputs   [0] the tensor that receives the one-hot result.
 * @param cmdBuffer command buffer the barrier, pipeline bind and dispatch are recorded into.
 * @return NO_ERROR.
 */
ErrorCode VulkanOneHot::onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                 const VulkanCommandPool::Buffer* cmdBuffer) {
    // inputs[0..3] and outputs[0] are indexed below; check the arity up front.
    MNN_ASSERT(inputs.size() >= 4);
    MNN_ASSERT(!outputs.empty());

    auto indices        = inputs[0];
    auto onValueTensor  = inputs[2];
    auto offValueTensor = inputs[3];

    // A negative axis counts from the end of the output shape.
    int axis = mAxis;
    if (axis < 0) {
        axis += outputs[0]->dimensions();
    }

    // outerSize: product of the indices dimensions that precede the one-hot axis.
    int outerSize = 1;
    for (int i = 0; i < axis; ++i) {
        outerSize *= indices->length(i);
    }
    const int depth = outputs[0]->length(axis);
    // Guard the division: a zero-length leading dimension makes outerSize 0.
    const int innerSize = (outerSize != 0) ? indices->elementSize() / outerSize : 0;
    auto total          = outputs[0]->elementSize();

    auto vkBn = static_cast<VulkanBackend*>(backend());
    // The on and off values must share one element type.
    MNN_ASSERT(onValueTensor->getType() == offValueTensor->getType());

    // Set Const Buffer
    {
        auto pool = (ConstBuffer*)mConstBuffer->map();
        ::memset(pool, 0, sizeof(ConstBuffer));
        pool->inputSize[0] = outerSize;
        pool->inputSize[1] = depth;
        pool->inputSize[2] = innerSize;
        pool->inputSize[3] = total;
        mConstBuffer->unmap();
    }

    // Set Command Buffer
    {
        auto outputT = vkBn->getBuffer(outputs[0]);
        auto inputT  = vkBn->getBuffer(indices);
        auto onT     = vkBn->getBuffer(onValueTensor);
        auto offT    = vkBn->getBuffer(offValueTensor);

        // Bindings: 0 = output, 1 = indices, 2 = on value, 3 = off value, 4 = constants.
        mDescriptorSet->writeBuffer(outputT, 0);
        mDescriptorSet->writeBuffer(inputT, 1);
        // NOTE(review): only the indices buffer gets a barrier here; confirm the on/off
        // value buffers never need one (e.g. when produced by a preceding GPU op).
        cmdBuffer->barrierSource(inputT);
        mDescriptorSet->writeBuffer(onT, 2);
        mDescriptorSet->writeBuffer(offT, 3);
        mDescriptorSet->writeBuffer(mConstBuffer->buffer(), 4, mConstBuffer->size());
        mPipeline->bind(cmdBuffer->get(), mDescriptorSet->get());

        // One invocation per output element, in groups of 256
        // (presumably the shader's local size - TODO confirm against the GLSL source).
        vkCmdDispatch(cmdBuffer->get(), UP_DIV(total, 256), 1, 1);
    }
    return NO_ERROR;
}
|
|
|
|
// Creator registered with the Vulkan backend; builds a VulkanOneHot execution for an op.
class VulkanOneHotCreator : public VulkanBackend::Creator {
public:
    // inputs/outputs are not inspected; only the op and the backend are forwarded.
    VulkanBasicExecution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                   const MNN::Op* op, Backend* backend) const override {
        VulkanBasicExecution* execution = new VulkanOneHot(op, backend);
        return execution;
    }
};
|
|
|
|
// Registers the OneHot creator with the Vulkan backend while this file's statics are
// initialized; the boolean only exists to trigger the lambda.
static bool gResistor = [] {
    auto creator = new VulkanOneHotCreator;
    VulkanBackend::addCreator(OpType_OneHot, creator);
    return true;
}();
|
|
|
|
} // namespace MNN
|