mirror of https://github.com/alibaba/MNN.git
61 lines
2.2 KiB
C++
61 lines
2.2 KiB
C++
//
|
|
// VulkanMatMul.hpp
|
|
// MNN
|
|
//
|
|
// Created by MNN on 2020/03/06.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#ifndef VulkanMatMul_hpp
|
|
#define VulkanMatMul_hpp
|
|
|
|
#include "VulkanMatrixMultier4x4.hpp"
|
|
namespace MNN {
|
|
|
|
class VulkanMatMul : public VulkanBasicExecution {
|
|
public:
|
|
class Reorder {
|
|
public:
|
|
struct nchwBuffer {
|
|
ivec4 size;
|
|
ivec4 stride;
|
|
};
|
|
Reorder(const VulkanBackend* bn, bool reorder, bool revert = false);
|
|
~ Reorder() {
|
|
// Do nothing
|
|
}
|
|
int computeMiddleBufferSize(int b, int h, int w, int c) const;
|
|
void encode(VkBuffer source, size_t sourceSize, VkBuffer middleBuffer, size_t middelBufferSize, const VulkanImage* dest, const VulkanCommandPool::Buffer* cmdBuffer, const nchwBuffer& buffer);
|
|
void revert(VkBuffer dest, size_t destSize, VkBuffer middleBuffer, size_t middelBufferSize, const VulkanImage* source, const VulkanCommandPool::Buffer* cmdBuffer, const nchwBuffer& buffer);
|
|
private:
|
|
const VulkanPipeline* mFirst;
|
|
const VulkanPipeline* mSecond;
|
|
std::shared_ptr<VulkanPipeline::DescriptorSet> mBufferBufferSet;
|
|
std::shared_ptr<VulkanPipeline::DescriptorSet> mImageBufferSet;
|
|
const VulkanBackend* mBackend;
|
|
std::shared_ptr<VulkanBuffer> mUnitBuffer;
|
|
};
|
|
VulkanMatMul(bool transposeA, bool transposeB, Backend* vkBn);
|
|
~ VulkanMatMul() {
|
|
// Do nothing
|
|
}
|
|
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
|
const VulkanCommandPool::Buffer *cmdBuffer) override;
|
|
|
|
private:
|
|
std::vector<std::shared_ptr<VulkanBuffer>> mTempBuffer;
|
|
std::shared_ptr<VulkanMatrixMultier4x4> mCore;
|
|
bool mTransposeA;
|
|
bool mTransposeB;
|
|
std::vector<const VulkanPipeline*> mPipelines;
|
|
std::vector<std::shared_ptr<VulkanPipeline::DescriptorSet>> mSets;
|
|
std::shared_ptr<Reorder> mInputReorder;
|
|
std::shared_ptr<Reorder> mWeightReorder;
|
|
std::shared_ptr<Reorder> mOutputReorder;
|
|
std::shared_ptr<VulkanImage> mKernelImage;
|
|
std::shared_ptr<VulkanImage> mInputImage;
|
|
std::shared_ptr<VulkanImage> mOutputImage;
|
|
};
|
|
}
|
|
#endif
|