MNN/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp

//
// AttentionBufExecution.hpp
// MNN
//
// Created by MNN on 2024/04/11.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
#ifndef AttentionBufExecution_hpp
#define AttentionBufExecution_hpp
#include "backend/opencl/execution/image/CommonExecution.hpp"
namespace MNN {
namespace OpenCL {
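// Manages the OpenCL buffers that hold the attention key/value cache:
// allocation and growth as the decoded sequence lengthens, plus bookkeeping
// of the cached length and head configuration.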
class KVCacheCLManager {
public:
KVCacheCLManager(Backend *backend, bool kv_cache);
~KVCacheCLManager() = default;
void allocKVCache();
bool reallocKVCache();
void setArgs(int pastLength, int numHead, int kvNumHead, int headDim){
mPastLength = pastLength;
mNumHead = numHead;
mKvNumHead = kvNumHead;
mHeadDim = headDim;
}
int kvLength() {
return mPastLength;
}
void addKvLength(){
mPastLength += 1;
}
int maxLength() {
return mMaxLength;
}
int numHead() {
return mNumHead;
}
const cl::Buffer * key() {
return mPastKey.get();
}
const cl::Buffer * value() {
return mPastValue.get();
}
private:
bool mKVCache;
const int mExpandChunk = 2048; // granularity used when growing the cached key/value buffers
std::shared_ptr<cl::Buffer> mPastKey, mPastValue;
int mPastLength = 0, mMaxLength = 0, mNumHead = 0, mKvNumHead = 0, mHeadDim = 0;
OpenCLBackend *mOpenCLBackend;
int mByte = 4; // bytes per cached element (4 for fp32)
};
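// OpenCL buffer execution of the fused attention op. Builds and dispatches the
// rearrange / QK / softmax / QK*V kernels and shares the key/value cache with
// other executions through KVCacheCLManager.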
class AttentionBufExecution : public CommonExecution {
public:
AttentionBufExecution(const MNN::Op *op, Backend *backend, bool kv_cache);
AttentionBufExecution(std::shared_ptr<KVCacheCLManager> manager, const MNN::Op *op, Backend *backend);
ErrorCode longPrefillResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
virtual ~AttentionBufExecution() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
private:
int getLocalSize(int size, int maxGroupSize);
void reallocKVCache();
bool mIsDecode = false;
bool mIsFirstPrefill = true;
int mKv_seq_len = 0;
int mMax_len = 0;
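// Kernels and launch sizes for the QK matmul, softmax and QK*V stages.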
std::shared_ptr<KernelWrap> mKernel_qk;
std::shared_ptr<KernelWrap> mKernel_softmax;
std::shared_ptr<KernelWrap> mKernel_qkv;
std::vector<uint32_t> mGlobalWorkSizeQk{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeQk{1, 1, 1, 1};
std::vector<uint32_t> mGlobalWorkSizeSoftMax{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeSoftMax{1, 1, 1, 1};
std::vector<uint32_t> mGlobalWorkSizeQkv{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeQkv{1, 1, 1, 1};
uint32_t mMaxWorkGroupSize;
OpenCLBackend *mOpenCLBackend;
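// Update records for kernel arguments that change between runs (used by the backend's kernel-recording path).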
RecordUpdateInfo mRgUpdateInfo;
RecordUpdateInfo mQkUpdateInfo;
RecordUpdateInfo mSoftMaxUpdateInfo;
RecordUpdateInfo mRgVUpdateInfo;
RecordUpdateInfo mQkvUpdateInfo;
int mGlobalWorkSizeQk0 = 0;
size_t mQkGlobal_size[2];
std::vector<RecordUpdateInfo*> mOpRecordUpdateInfo;
std::shared_ptr<KVCacheCLManager> mKVCacheCLManager;
std::shared_ptr<Tensor> mTempQK, mTempSoftMax;
private:
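// Long-prefill path: tiling alignments, extra rearrange/mask/transpose/clip kernels and their temporaries.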
int mAlignQ, mAlignKV, mAlignHDK, mAlignHDN;
bool mLongPrefill = false;
std::shared_ptr<KernelWrap> mKernel_rearrangeQ;
std::vector<uint32_t> mGlobalWorkSizeRearrgQ{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeRearrgQ{1, 1, 1, 1};
std::shared_ptr<KernelWrap> mKernel_rearrangeV;
std::vector<uint32_t> mGlobalWorkSizeRearrgV{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeRearrgV{1, 1, 1, 1};
std::shared_ptr<KernelWrap> mKernel_rearrange;
std::vector<uint32_t> mGlobalWorkSizeRearrg{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeRearrg{1, 1, 1, 1};
std::shared_ptr<KernelWrap> mKernel_mask;
std::vector<uint32_t> mGlobalWorkSizeMask{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeMask{1, 1, 1, 1};
std::shared_ptr<KernelWrap> mKernel_trans;
std::vector<uint32_t> mGlobalWorkSizeTrans{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeTrans{1, 1, 1, 1};
std::shared_ptr<KernelWrap> mKernel_clip;
std::vector<uint32_t> mGlobalWorkSizeClip{1, 1, 1};
std::vector<uint32_t> mLocalWorkSizeClip{1, 1, 1, 1};
std::shared_ptr<Tensor> mTempQ, mTempK, mTempV, mTempMask, mTempQKV;
bool mIsAddMask = false;
};
} // namespace OpenCL
} // namespace MNN
#endif /* AttentionBufExecution_hpp */
#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */