mirror of https://github.com/alibaba/MNN.git
OpenCL: Bugfix: fix OpenCL crash in llm bench
parent 35ec013cab
commit e457f0ce33
@@ -1559,6 +1559,10 @@ void OpenCLBackend::setGpuMode(const int cl_mode_num) {
         MNN_PRINT("set multi record kernel mode is not permitted, please check cl_mode:%x!\n", cl_mode_num);
     }
 }
+const Runtime* OpenCLBackend::getRuntime() {
+    return mCLRuntime;
+}
+
 #ifdef MNN_OPENCL_SEP_BUILD
 bool placeholder = []() {
     static std::once_flag createOnce;
@@ -151,6 +151,7 @@ public:
     bool isCreateError() const;
     virtual void* onMapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* srcTensor) override;
     virtual bool onUnmapTensor(Tensor::MapType mtype, Tensor::DimensionType dtype, const Tensor* dstTensor, void* mapPtr) override;
+    virtual const Runtime* getRuntime() override;
 
 private:
     void copyFromDevice(const Tensor* srcTensor, const Tensor* dstTensor) const;
File diff suppressed because it is too large
@@ -12,6 +12,7 @@
 #define AttentionBufExecution_hpp
 
 #include "backend/opencl/execution/image/CommonExecution.hpp"
+#include "core/OpCommonUtils.hpp"
 
 namespace MNN {
 namespace OpenCL {
@@ -21,19 +22,18 @@ public:
     KVCacheCLManager(Backend *backend, bool kv_cache);
 
     ~KVCacheCLManager() = default;
-    void allocKVCache();
-    bool reallocKVCache();
-    void setArgs(int pastLength, int numHead, int kvNumHead, int headDim){
-        mPastLength = pastLength;
+    void allocKVCache(const KVMeta* meta, bool isDecodeResize = false);
+    bool reallocKVCache(const KVMeta* meta, bool isDecodeResize = false);
+    void setArgs(int numHead, int kvNumHead, int headDim){
         mNumHead = numHead;
         mKvNumHead = kvNumHead;
         mHeadDim = headDim;
     }
-    int kvLength() {
+    int pastKvLength() {
         return mPastLength;
     }
-    void addKvLength(){
-        mPastLength += 1;
+    void addKvLength(int seq_len){
+        mPastLength += seq_len;
     }
     int maxLength() {
         return mMaxLength;
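The hunk above reworks the KVCacheCLManager interface: the past length is no longer seeded through setArgs, allocKVCache/reallocKVCache now take a KVMeta describing the resize, kvLength() becomes pastKvLength(), and addKvLength() advances by a whole sequence chunk instead of a fixed +1. Below is a minimal, self-contained sketch of how a caller might drive that interface; KVCacheStandIn and the literal head/dimension values are illustrative stand-ins for this note, not MNN's actual backend code.

// Stand-in mirroring only the signatures visible in the hunk above; the real
// KVCacheCLManager manages cl::Buffer storage inside the OpenCL backend.
#include <cstdio>

struct KVMeta;  // opaque here; supplied by the framework when the execution is resized

class KVCacheStandIn {
public:
    void setArgs(int numHead, int kvNumHead, int headDim) {
        mNumHead = numHead; mKvNumHead = kvNumHead; mHeadDim = headDim;
    }
    void allocKVCache(const KVMeta* /*meta*/, bool /*isDecodeResize*/ = false) { /* allocate buffers */ }
    bool reallocKVCache(const KVMeta* /*meta*/, bool /*isDecodeResize*/ = false) { return true; }
    int  pastKvLength() const { return mPastLength; }
    void addKvLength(int seq_len) { mPastLength += seq_len; }  // advance by the whole new chunk
private:
    int mPastLength = 0, mNumHead = 0, mKvNumHead = 0, mHeadDim = 0;
};

int main() {
    KVCacheStandIn cache;
    cache.setArgs(/*numHead*/ 8, /*kvNumHead*/ 2, /*headDim*/ 64);  // past length no longer passed here
    cache.allocKVCache(nullptr);  // prefill: cache is sized from the KVMeta, not a fixed start length
    cache.addKvLength(17);        // a 17-token prefill chunk advances the cache in one step
    cache.addKvLength(1);         // decode appends a single token
    std::printf("past kv length = %d\n", cache.pastKvLength());  // prints 18
    return 0;
}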
@@ -50,7 +50,7 @@ public:
 
 private:
     bool mKVCache;
-    const int mExpandChunk = 2048;
+    const int mExpandChunk = 64;
     std::shared_ptr<cl::Buffer> mPastKey, mPastValue;
     int mPastLength = 0, mMaxLength = 0, mNumHead = 0, mKvNumHead = 0, mHeadDim = 0;
     OpenCLBackend *mOpenCLBackend;
@@ -74,11 +74,12 @@ public:
 
 private:
+    KVMeta* mMeta;
     int getLocalSize(int size, int maxGroupSize);
     bool mIsDecode = false;
     void handleKVCache(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs);
-    bool mIsFirstPrefill = true;
-    int mKv_seq_len = 0;
+    int mPastKvSeqlen = 0;
+    int mKvSeqlen = 0;
     int mKeyValueMaxlen = 0;
     int mDecodeTmpMaxlen = 0;
 
@@ -372,7 +372,7 @@ __kernel void rearrange_k(GLOBAL_SIZE_3_DIMS
         vstore4((FLOAT4)(key_vec0.s3, key_vec1.s3, key_vec2.s3, key_vec3.s3), 0, past_key + output_offset + max_len + max_len + max_len);
 #else
     FLOAT4 key_vec = vload4(0, key + (b * kv_head_num + z) * head_dim + y4);
-    const int output_offset = ((b * kv_head_num + z) * head_dim + y4) * max_len + past_len - 1;
+    const int output_offset = ((b * kv_head_num + z) * head_dim + y4) * max_len + past_len;
     past_key[output_offset] = key_vec.s0;
     past_key[output_offset + max_len] = key_vec.s1;
     past_key[output_offset + max_len + max_len] = key_vec.s2;
@@ -413,7 +413,7 @@ __kernel void rearrange_v(GLOBAL_SIZE_3_DIMS
         vstore4(value_vec3, 0, past_value + output_offset + head_dim + head_dim + head_dim);
 #else
     FLOAT4 value_vec = vload4(0, value + (b * kv_head_num + z) * head_dim + x4);
-    const int output_offset = ((b * kv_head_num + z) * max_len + past_len - 1) * head_dim + x4;
+    const int output_offset = ((b * kv_head_num + z) * max_len + past_len) * head_dim + x4;
     vstore4(value_vec, 0, past_value + output_offset);
 #endif
 }
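The two rearrange hunks above drop the "- 1" from the write offset. With pastKvLength() now reporting the cached length before the current chunk, the new key/value is appended at column past_len rather than past_len - 1, so the append no longer lands on the last already-cached column. A self-contained index check with assumed toy sizes:

// Toy sizes chosen only for illustration; the offset formula is taken from the hunk above.
#include <cstdio>

int main() {
    const int max_len = 8, head_dim = 4, kv_head_num = 1, b = 0, z = 0, y4 = 0;
    const int past_len = 3;  // three tokens already cached, occupying columns 0..2

    int old_offset = ((b * kv_head_num + z) * head_dim + y4) * max_len + past_len - 1;  // before the fix
    int new_offset = ((b * kv_head_num + z) * head_dim + y4) * max_len + past_len;      // after the fix

    std::printf("old write column = %d (collides with cached column %d)\n", old_offset % max_len, past_len - 1);
    std::printf("new write column = %d (first free slot)\n", new_offset % max_len);
    return 0;
}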
@@ -424,11 +424,12 @@ __kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS
                               #ifdef ADD_MASK
                               __global const FLOAT* mask,
                               #elif defined(SET_MASK)
-                              __global const int* mask, // [1 1 query_seq_len key_seq_len]
+                              __global const int* mask, // [1 1 query_seq_len mask_key_seq_len]
                               #endif
                               __global FLOAT *qk, // [batch head_num kv_seq_length query_seq_len_4]
                               __private const float scale,
                               __private const int query_seq_len,
+                              __private const int mask_key_seq_len,
                               __private const int key_seq_len,
                               __private const int max_len,
                               __private const int head_num,
@@ -485,10 +486,10 @@ __kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS
     out3 *= (float4)scale;
     {
     #if defined(ADD_MASK) || defined(SET_MASK)
-        int mask_offset = x4 * key_seq_len + y4;
-        float4 mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += key_seq_len;
-        float4 mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += key_seq_len;
-        float4 mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += key_seq_len;
+        int mask_offset = x4 * mask_key_seq_len + y4;
+        float4 mask_tmp0 = convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
+        float4 mask_tmp1 = (x4 + 1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
+        float4 mask_tmp2 = (x4 + 2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset)); mask_offset += mask_key_seq_len;
         float4 mask_tmp3 = (x4 + 3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0, mask + mask_offset));
         float4 mask0 = (float4)(mask_tmp0.s0, mask_tmp1.s0, mask_tmp2.s0, mask_tmp3.s0);
         float4 mask1 = (float4)(mask_tmp0.s1, mask_tmp1.s1, mask_tmp2.s1, mask_tmp3.s1);
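The two prefill hunks above stride through the mask with the new mask_key_seq_len parameter instead of key_seq_len, matching the layout commented as [1 1 query_seq_len mask_key_seq_len]. A plausible reading of the change is that the mask supplied by the framework need not be as wide as the full (cached + new) key length covered by the qk matmul, so stepping mask rows by key_seq_len reads the wrong rows. A small self-contained check with assumed toy sizes:

// Illustrative sizes only: 4 new query tokens attending over 4 cached + 4 new keys,
// with a mask that is 4 columns wide per query row.
#include <cstdio>

int main() {
    const int query_seq_len    = 4;
    const int mask_key_seq_len = 4;   // width of the mask actually provided
    const int key_seq_len      = 8;   // cached + new keys covered by the qk matmul
    const int x4 = 0, y4 = 0;         // first 4x4 tile, as in the kernel

    int wrong_row1 = x4 * key_seq_len      + y4 + key_seq_len;       // old stride: element 8
    int right_row1 = x4 * mask_key_seq_len + y4 + mask_key_seq_len;  // new stride: element 4

    std::printf("mask has %d elements; old indexing starts row 1 at %d, fixed indexing at %d\n",
                query_seq_len * mask_key_seq_len, wrong_row1, right_row1);
    return 0;
}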
@@ -317,7 +317,7 @@ const char* attention_buf =
 "        vstore4((FLOAT4)(key_vec0.s3,key_vec1.s3,key_vec2.s3,key_vec3.s3),0,past_key+output_offset+max_len+max_len+max_len);\n"
 "#else\n"
 "    FLOAT4 key_vec=vload4(0,key+(b*kv_head_num+z)*head_dim+y4);\n"
-"    const int output_offset=((b*kv_head_num+z)*head_dim+y4)*max_len+past_len-1;\n"
+"    const int output_offset=((b*kv_head_num+z)*head_dim+y4)*max_len+past_len;\n"
 "    past_key[output_offset]=key_vec.s0;\n"
 "    past_key[output_offset+max_len]=key_vec.s1;\n"
 "    past_key[output_offset+max_len+max_len]=key_vec.s2;\n"
@@ -357,7 +357,7 @@ const char* attention_buf =
 "        vstore4(value_vec3,0,past_value+output_offset+head_dim+head_dim+head_dim);\n"
 "#else\n"
 "    FLOAT4 value_vec=vload4(0,value+(b*kv_head_num+z)*head_dim+x4);\n"
-"    const int output_offset=((b*kv_head_num+z)*max_len+past_len-1)*head_dim+x4;\n"
+"    const int output_offset=((b*kv_head_num+z)*max_len+past_len)*head_dim+x4;\n"
 "    vstore4(value_vec,0,past_value+output_offset);\n"
 "#endif\n"
 "}\n"
@@ -367,11 +367,12 @@ const char* attention_buf =
 "                              #ifdef ADD_MASK\n"
 "                              __global const FLOAT* mask,\n"
 "                              #elif defined(SET_MASK)\n"
-"                              __global const int* mask,// [1 1 query_seq_len key_seq_len]\n"
+"                              __global const int* mask,// [1 1 query_seq_len mask_key_seq_len]\n"
 "                              #endif\n"
 "                              __global FLOAT *qk,// [batch head_num kv_seq_length query_seq_len_4]\n"
 "                              __private const float scale,\n"
 "                              __private const int query_seq_len,\n"
+"                              __private const int mask_key_seq_len,\n"
 "                              __private const int key_seq_len,\n"
 "                              __private const int max_len,\n"
 "                              __private const int head_num,\n"
@@ -427,10 +428,10 @@ const char* attention_buf =
 "    out3 *= (float4)scale;\n"
 "    {\n"
 "    #if defined(ADD_MASK) || defined(SET_MASK)\n"
-"        int mask_offset=x4*key_seq_len+y4;\n"
-"        float4 mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += key_seq_len;\n"
-"        float4 mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += key_seq_len;\n"
-"        float4 mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += key_seq_len;\n"
+"        int mask_offset=x4*mask_key_seq_len+y4;\n"
+"        float4 mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
+"        float4 mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
+"        float4 mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
 "        float4 mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset));\n"
 "        float4 mask0=(float4)(mask_tmp0.s0,mask_tmp1.s0,mask_tmp2.s0,mask_tmp3.s0);\n"
 "        float4 mask1=(float4)(mask_tmp0.s1,mask_tmp1.s1,mask_tmp2.s1,mask_tmp3.s1);\n"