MNN/source/backend/vulkan/buffer/execution/VulkanFuse.cpp

157 lines
6.1 KiB
C++
Raw Normal View History

2023-12-04 11:12:20 +08:00
//
// VulkanFuse.cpp
// MNN
//
// Created by MNN on 2023/07/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <stdio.h>
#include "VulkanBasicExecution.hpp"
#include "core/OpCommonUtils.hpp"
namespace MNN {
// Constant block holding a single integer 4-vector shape.
// Per the original author's note, the first lanes carry inW and inH; the meaning of the other
// lanes is not stated here. NOTE(review): not referenced by the code visible in this file.
struct ConstBuffer { ivec4 inShape; };
class VulkanFuse : public VulkanBasicExecution {
public:
VulkanFuse(const Extra* extra, Backend* bn, int inputSize, int outputSize) : VulkanBasicExecution(bn) {
auto vkBn = static_cast<VulkanBackend*>(bn);
auto factory = vkBn->getPipelineFactory();
mOutputBinding.resize(outputSize);
mInputBinding.resize(inputSize);
mGroupSize.resize(3);
// Find shader
const uint8_t* data = nullptr;
size_t dataSize = 0;
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "spirv") {
data = (uint8_t*)attr->tensor()->int8s()->data();
dataSize = attr->tensor()->int8s()->size();
break;
}
}
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "group_size") {
auto ptr = attr->tensor()->int32s()->data();
mGroupSize[0] = ptr[0];
mGroupSize[1] = ptr[1];
mGroupSize[2] = ptr[2];
}
}
std::vector<VkDescriptorType> types;
int maxIndex = -1;
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "input") {
maxIndex = ALIMAX(maxIndex, attr->i());
} else if (attr->key()->str() == "const") {
maxIndex = ALIMAX(maxIndex, attr->i());
}
}
types.resize(maxIndex+1);
for (int i=0; i<extra->attr()->size(); ++i) {
auto attr = extra->attr()->GetAs<Attribute>(i);
if (attr->key()->str() == "input") {
auto list = attr->list()->i()->data();
if (0 == list[0]) {
mInputBinding[list[1]] = attr->i();
} else {
mOutputBinding[list[1]] = attr->i();
}
if (attr->b()) {
types[attr->i()] = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
} else {
types[attr->i()] = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
}
continue;
}
if (attr->key()->str() == "const") {
auto usageBit = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
if (attr->b()) {
types[attr->i()] = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
} else {
usageBit = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
types[attr->i()] = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
}
auto b = attr->tensor();
void* result = nullptr;
size_t bufferSize = 0;
switch (b->dataType()) {
case DataType_DT_FLOAT:
result = (void*)b->float32s()->Data();
bufferSize = b->float32s()->size() * sizeof(float);
break;
case DataType_DT_INT32:
result = (void*)b->int32s()->Data();
bufferSize = b->int32s()->size() * sizeof(float);
break;
default:
MNN_ASSERT(false);
break;
}
std::shared_ptr<VulkanBuffer> vkBuffer(new VulkanBuffer(vkBn->getMemoryPool(), false, bufferSize, result, usageBit));
mConstIndides.emplace_back(std::make_pair(attr->i(), vkBuffer));
continue;
}
}
mPipeline = factory->createComputePipeline(data, dataSize, types, std::vector<uint32_t>{});
mDescriptorSet.reset(mPipeline->createSet());
}
virtual ~VulkanFuse() {
// Remove set firstly before destroy pipeline
mDescriptorSet = nullptr;
}
virtual ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const VulkanCommandPool::Buffer* cmdBuffer) override {
auto vkBn = static_cast<VulkanBackend*>(backend());
for (int i=0; i<inputs.size(); ++i) {
mDescriptorSet->writeBuffer(vkBn->getBuffer(inputs[i]), mInputBinding[i]);
}
for (int i=0; i<outputs.size(); ++i) {
mDescriptorSet->writeBuffer(vkBn->getBuffer(outputs[i]), mOutputBinding[i]);
}
for (auto& iter : mConstIndides) {
mDescriptorSet->writeBuffer(iter.second->buffer(), iter.first, iter.second->size());
}
mPipeline->bind(cmdBuffer->get(), mDescriptorSet->get());
vkCmdDispatch(cmdBuffer->get(), mGroupSize[0], mGroupSize[1], mGroupSize[2]);
return NO_ERROR;
}
private:
std::vector<int> mGroupSize;
std::vector<int> mInputBinding;
std::vector<int> mOutputBinding;
std::vector<std::pair<int, VulkanBuffer*>> mInputUniforms;
std::vector<std::pair<int, std::shared_ptr<VulkanBuffer>>> mConstIndides;
SharedPtr<VulkanPipeline> mPipeline;
std::shared_ptr<VulkanLayout::DescriptorSet> mDescriptorSet;
};
class VulkanFuseCreator : public VulkanBackend::Creator {
public:
    /// Builds a VulkanFuse execution for an Extra op.
    /// @return nullptr when the op has no Extra payload or the payload has no attribute list;
    ///         otherwise a newly allocated VulkanFuse sized for the given inputs/outputs.
    virtual VulkanBasicExecution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op,
                                           Backend* backend) const override {
        const auto* payload = op->main_as_Extra();
        const bool usable   = (nullptr != payload) && (nullptr != payload->attr());
        if (!usable) {
            return nullptr;
        }
        const int inputCount  = static_cast<int>(inputs.size());
        const int outputCount = static_cast<int>(outputs.size());
        return new VulkanFuse(payload, backend, inputCount, outputCount);
    }
};
// Registers VulkanFuseCreator for OpType_Extra during static initialization; the boolean only
// exists so the registering lambda runs once when this translation unit is initialised.
static bool gResistor = [] {
    auto* creator = new VulkanFuseCreator;
    VulkanBackend::addCreator(OpType_Extra, creator);
    return true;
}();
} // namespace MNN