MNN/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp

163 lines
5.9 KiB
C++
Raw Normal View History

2024-05-11 19:17:02 +08:00
//
// AttentionBufExecution.hpp
// MNN
//
// Created by MNN on 2024/04/11.
// Copyright © 2018, Alibaba Group Holding Limited
//
2024-07-04 11:53:45 +08:00
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
2024-05-11 19:17:02 +08:00
#ifndef AttentionBufExecution_hpp
#define AttentionBufExecution_hpp
#include "backend/opencl/execution/image/CommonExecution.hpp"
#include "core/OpCommonUtils.hpp"
2024-05-11 19:17:02 +08:00
namespace MNN {
namespace OpenCL {
2024-09-12 12:57:57 +08:00
class KVCacheCLManager {
2024-05-11 19:17:02 +08:00
public:
2024-09-12 12:57:57 +08:00
KVCacheCLManager(Backend *backend, bool kv_cache);
2024-05-11 19:17:02 +08:00
2024-09-12 12:57:57 +08:00
~KVCacheCLManager() = default;
2025-06-05 15:15:29 +08:00
void allocKVCache(const KVMeta* meta);
bool reallocKVCache(const KVMeta* meta, bool isExecute = true);
void setArgs(int numHead, int kvNumHead, int headDim){
2024-09-12 12:57:57 +08:00
mNumHead = numHead;
mKvNumHead = kvNumHead;
mHeadDim = headDim;
}
int pastKvLength() {
2024-09-12 12:57:57 +08:00
return mPastLength;
}
void addKvLength(int seq_len){
mPastLength += seq_len;
2024-09-12 12:57:57 +08:00
}
int maxLength() {
return mMaxLength;
}
int numHead() {
return mNumHead;
}
const cl::Buffer * key() {
return mPastKey.get();
}
const cl::Buffer * value() {
return mPastValue.get();
2024-05-11 19:17:02 +08:00
}
private:
bool mKVCache;
const int mExpandChunk = 64;
2024-07-04 11:53:45 +08:00
std::shared_ptr<cl::Buffer> mPastKey, mPastValue;
2024-09-12 12:57:57 +08:00
int mPastLength = 0, mMaxLength = 0, mNumHead = 0, mKvNumHead = 0, mHeadDim = 0;
OpenCLBackend *mOpenCLBackend;
int mByte = 4;
};
class AttentionBufExecution : public CommonExecution {
public:
AttentionBufExecution(const MNN::Op *op, Backend *backend, bool kv_cache);
AttentionBufExecution(std::shared_ptr<KVCacheCLManager> manager, const MNN::Op *op, Backend *backend);
2024-11-18 14:37:45 +08:00
ErrorCode longPrefillResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
2025-03-12 11:35:16 +08:00
ErrorCode prefillResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
ErrorCode decodeResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
2024-09-12 12:57:57 +08:00
2025-03-12 11:35:16 +08:00
ErrorCode UpdateArgs(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
ErrorCode init();
2025-06-05 15:15:29 +08:00
int getExecuteTime();
2024-09-12 12:57:57 +08:00
virtual ~AttentionBufExecution() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
private:
KVMeta* mMeta;
2024-09-12 12:57:57 +08:00
int getLocalSize(int size, int maxGroupSize);
bool mIsDecode = false;
2025-03-12 11:35:16 +08:00
void handleKVCache(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
int mPastKvSeqlen = 0;
int mKvSeqlen = 0;
2024-12-31 15:34:08 +08:00
int mKeyValueMaxlen = 0;
int mDecodeTmpMaxlen = 0;
2025-03-12 11:35:16 +08:00
2024-05-11 19:17:02 +08:00
uint32_t mMaxWorkGroupSize;
OpenCLBackend *mOpenCLBackend;
2024-11-18 14:37:45 +08:00
RecordUpdateInfo mRgUpdateInfo;
2025-02-12 11:14:19 +08:00
RecordUpdateInfo mRgQUpdateInfo;
2025-06-05 15:15:29 +08:00
RecordUpdateInfo mRgMUpdateInfo;
2024-05-11 19:17:02 +08:00
RecordUpdateInfo mQkUpdateInfo;
RecordUpdateInfo mSoftMaxUpdateInfo;
2024-11-18 14:37:45 +08:00
RecordUpdateInfo mRgVUpdateInfo;
2024-05-11 19:17:02 +08:00
RecordUpdateInfo mQkvUpdateInfo;
2024-09-12 12:57:57 +08:00
int mGlobalWorkSizeQk0 = 0;
2024-11-18 14:37:45 +08:00
size_t mQkGlobal_size[2];
2025-06-05 15:15:29 +08:00
size_t mQkPrefillGlobal_size[3];
2024-05-11 19:17:02 +08:00
std::vector<RecordUpdateInfo*> mOpRecordUpdateInfo;
2024-09-12 12:57:57 +08:00
std::shared_ptr<KVCacheCLManager> mKVCacheCLManager;
std::shared_ptr<Tensor> mTempQK, mTempSoftMax;
2024-05-11 19:17:02 +08:00
private:
2024-09-12 12:57:57 +08:00
int mAlignQ, mAlignKV, mAlignHDK, mAlignHDN;
bool mLongPrefill = false;
2025-03-12 11:35:16 +08:00
int mQseqSplitNum = 1;
std::shared_ptr<Tensor> mTempQ, mTempK, mTempV, mTempMask, mTempQKV;
bool mIsAddMask = false;
bool mNeedKvCache = true;
bool mHasMask = false;
private:
std::vector<std::shared_ptr<KernelWrap>> mKernel_rearrange_vec;
std::vector<std::shared_ptr<KernelWrap>> mKernel_mask_vec;
std::vector<std::shared_ptr<KernelWrap>> mKernel_trans_vec;
std::vector<std::shared_ptr<KernelWrap>> mKernel_clip_vec;
std::vector<std::shared_ptr<KernelWrap>> mKernel_qk_vec;
std::vector<std::shared_ptr<KernelWrap>> mKernel_softmax_vec;
std::vector<std::shared_ptr<KernelWrap>> mKernel_qkv_vec;
std::vector<std::vector<uint32_t>> mGwsQkVec;
std::vector<std::vector<uint32_t>> mLwsQkVec;
std::vector<std::vector<uint32_t>> mGwsSoftMaxVec;
std::vector<std::vector<uint32_t>> mLwsSoftMaxVec;
std::vector<std::vector<uint32_t>> mGwsQkvVec;
std::vector<std::vector<uint32_t>> mLwsQkvVec;
std::vector<std::vector<uint32_t>> mGwsRearrgVec;
std::vector<std::vector<uint32_t>> mLwsRearrgVec;
std::vector<std::vector<uint32_t>> mGwsMaskVec;
std::vector<std::vector<uint32_t>> mLwsMaskVec;
std::vector<std::vector<uint32_t>> mGwsTransVec;
std::vector<std::vector<uint32_t>> mLwsTransVec;
std::vector<std::vector<uint32_t>> mGwsClipVec;
std::vector<std::vector<uint32_t>> mLwsClipVec;
private:
2024-11-18 14:37:45 +08:00
std::shared_ptr<KernelWrap> mKernel_rearrangeQ;
std::shared_ptr<KernelWrap> mKernel_rearrangeV;
2025-06-05 15:15:29 +08:00
std::shared_ptr<KernelWrap> mKernel_rearrangeMask;
2024-09-12 12:57:57 +08:00
std::shared_ptr<KernelWrap> mKernel_rearrange;
2025-03-12 11:35:16 +08:00
std::shared_ptr<KernelWrap> mKernel_qk;
std::shared_ptr<KernelWrap> mKernel_softmax;
std::shared_ptr<KernelWrap> mKernel_qkv;
std::vector<uint32_t> mGlobalWorkSizeQk;
std::vector<uint32_t> mLocalWorkSizeQk;
std::vector<uint32_t> mGlobalWorkSizeSoftMax;
std::vector<uint32_t> mLocalWorkSizeSoftMax;
std::vector<uint32_t> mGlobalWorkSizeQkv;
std::vector<uint32_t> mLocalWorkSizeQkv;
std::vector<uint32_t> mGlobalWorkSizeRearrgQ;
std::vector<uint32_t> mLocalWorkSizeRearrgQ;
std::vector<uint32_t> mGlobalWorkSizeRearrgV;
std::vector<uint32_t> mLocalWorkSizeRearrgV;
std::vector<uint32_t> mGlobalWorkSizeRearrg;
std::vector<uint32_t> mLocalWorkSizeRearrg;
2025-06-05 15:15:29 +08:00
std::vector<uint32_t> mGlobalWorkSizeRearrgM;
std::vector<uint32_t> mLocalWorkSizeRearrgM;
2025-03-12 11:35:16 +08:00
2024-05-11 19:17:02 +08:00
};
} // namespace OpenCL
} // namespace MNN
#endif /* AttentionBufExecution_hpp */
2024-07-04 11:53:45 +08:00
#endif/* MNN_SUPPORT_TRANSFORMER_FUSE */