MNN/source/backend/vulkan/buffer/execution/VulkanMatMul.cpp

//
// VulkanMatMul.cpp
// MNN
//
// Created by MNN on 2020/03/06.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "VulkanMatMul.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
namespace MNN {
VulkanMatMul::VulkanMatMul(bool transposeA, bool transposeB, Backend* bn, bool hasBias) : VulkanBasicExecution(bn) {
mTransposeA = transposeA;
mTransposeB = transposeB;
auto vkbackend = static_cast<VulkanBackend*>(bn);
mParam = vkbackend->allocUniform();
mHasBias = hasBias;
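    // Select the shader pipeline: the HAS_BIAS variant declares extra storage-buffer bindings for the bias input.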
if (!mHasBias) {
mPipeline = vkbackend->getPipeline("glsl_matmulunit_comp", {
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
});
} else {
mPipeline = vkbackend->getPipeline("glsl_matmulunit_HAS_BIAS_comp", {
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
});
}
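    // Descriptor set used at encode time to bind the tensors and the uniform parameter block.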
mDescribe.reset(mPipeline->createSet());
}
VulkanMatMul::~VulkanMatMul() {
auto vkbackend = static_cast<VulkanBackend*>(backend());
vkbackend->recycleUniform(mParam);
}
ErrorCode VulkanMatMul::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const VulkanCommandPool::Buffer *cmdBuffer) {
const Tensor* A = inputs[0];
const Tensor* B = inputs[1];
Tensor* C = outputs[0];
auto w0 = inputs[0]->length(1);
auto h0 = inputs[0]->length(0);
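    // e: rows of the output, l: shared (reduction) dimension, h: columns of the output, i.e. C(e x h) = A(e x l) * B(l x h).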
auto e = A->length(0);
auto h = B->length(1);
auto l = A->length(1);
if (mTransposeA) {
l = A->length(0);
e = A->length(1);
}
if (mTransposeB) {
h = B->length(0);
}
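    // Fill the uniform block with the sizes and strides the batched matmul shader expects.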
int totalSize = e * h;
auto param = reinterpret_cast<VulkanBatchMatMulInfo*>(mParam->map());
::memset(param, 0, sizeof(VulkanBatchMatMulInfo));
param->size[3] = 1;
param->size[0] = e;
param->size[1] = l;
param->size[2] = h;
param->stride_o[1] = 0;
param->stride_o[0] = h;
param->stride_o[2] = 1;
param->stride_c[0] = 0;
param->stride_c[1] = 0;
param->stride_c[2] = 1;
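    // iter == -1: the shader reads A/B/bias directly, so the offset buffers bound later are only placeholders.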
param->iter[0] = -1;
param->iter[1] = -1;
param->iter[2] = -1;
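    // Strides for reading A: stored as l x e when transposed, e x l otherwise.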
param->stride_a[2] = 0;
if (mTransposeA) {
param->stride_a[0] = 1;
param->stride_a[1] = e;
} else {
param->stride_a[0] = l;
param->stride_a[1] = 1;
}
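    // Strides for reading B: stored as h x l when transposed, l x h otherwise.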
param->stride_b[0] = 0;
if (mTransposeB) {
param->stride_b[1] = 1;
param->stride_b[2] = l;
} else {
param->stride_b[1] = h;
param->stride_b[2] = 1;
}
mParam->unmap();
auto vkBn = static_cast<VulkanBackend*>(backend());
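    // Bind output C, inputs A/B, and the optional bias to the descriptor set.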
mDescribe->writeBuffer(vkBn->getBuffer(C), 0);
mDescribe->writeBuffer(vkBn->getBuffer(A), 1);
mDescribe->writeBuffer(vkBn->getBuffer(B), 2);
int offset = 3;
if (inputs.size() > 2) {
mDescribe->writeBuffer(vkBn->getBuffer(inputs[2]), 3);
offset = 4;
}
    // Offset buffers expected by the shader; unused here (iter == -1), so A is bound as a placeholder.
mDescribe->writeBuffer(vkBn->getBuffer(A), offset);
mDescribe->writeBuffer(vkBn->getBuffer(A), offset + 1);
mDescribe->writeBuffer(vkBn->getBuffer(A), offset + 2);
if (inputs.size() > 2) {
mDescribe->writeBuffer(vkBn->getBuffer(A), offset + 3);
}
mDescribe->writeBuffer(mParam->buffer(), offset * 2, mParam->size());
mPipeline->bind(cmdBuffer->get(), mDescribe->get());
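    // Dispatch enough workgroups to cover all e*h output elements, 256 per workgroup.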
vkCmdDispatch(cmdBuffer->get(), UP_DIV(totalSize,256), 1, 1);
return NO_ERROR;
}
class VulkanMatMulCreator : public VulkanBackend::Creator {
public:
virtual VulkanBasicExecution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op, Backend* bn) const override {
const auto mat = op->main_as_MatMul();
return new VulkanMatMul(mat->transposeA(), mat->transposeB(), bn, inputs.size() > 2);
}
};
static bool gResistor = []() {
VulkanBackend::addCreator(OpType_MatMul, new VulkanMatMulCreator);
return true;
}();
} // namespace MNN