|
|
|
@ -241,6 +241,7 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
|
|
|
|
|
mGlobalWorkSizeQk0 = UP_DIV(mKvSeqlen, 4);
|
|
|
|
|
mQkPrefillGlobal_size[1] = ROUND_UP(mGlobalWorkSizeQk0, std::max((uint32_t)1, mLocalWorkSizeQk[1]));
|
|
|
|
|
mGlobalWorkSizeQk[1] = mQkPrefillGlobal_size[1];
|
|
|
|
|
mTempQ.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * ROUND_UP(headDim, 4) * batch * numHead}));
|
|
|
|
|
mTempQK.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch}));
|
|
|
|
|
mTempSoftMax.reset(Tensor::createDevice<float>({ROUND_UP(seqlen, 4) * mKvSeqlen * numHead * batch}));
|
|
|
|
|
if(mIsAddMask) {
|
|
|
|
@ -248,23 +249,23 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
|
|
|
|
|
} else {
|
|
|
|
|
mTempMask.reset(Tensor::createDevice<uint32_t>({ROUND_UP(maskQlen, 4) * ROUND_UP(maskKvlen, 4) * batch}));
|
|
|
|
|
}
|
|
|
|
|
mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC_IN_EXECUTION);
|
|
|
|
|
}
|
|
|
|
|
#ifndef ENABLE_OPENCL_TIME_PROFILER
|
|
|
|
|
if(mOpenCLBackend->isUseRecordQueue()){
|
|
|
|
|
if(mLongPrefill){
|
|
|
|
|
mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
|
|
|
|
|
mQkUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
|
|
|
|
|
mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
|
|
|
|
|
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
|
|
|
|
|
}else{
|
|
|
|
|
#ifndef ENABLE_OPENCL_TIME_PROFILER
|
|
|
|
|
if(mOpenCLBackend->isUseRecordQueue()){
|
|
|
|
|
if(mLongPrefill){
|
|
|
|
|
mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
|
|
|
|
|
mRgUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->value()))();
|
|
|
|
|
}else{
|
|
|
|
|
mRgQUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQ.get())();
|
|
|
|
|
mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
|
|
|
|
|
mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
|
|
|
|
|
mRgMUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempMask.get())();
|
|
|
|
|
mQkUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempQ.get())();
|
|
|
|
|
mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mKVCacheCLManager->key()))();
|
|
|
|
@ -276,28 +277,34 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
|
|
|
|
|
}
|
|
|
|
|
mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempQK.get())();
|
|
|
|
|
mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &openCLDeferBuffer(mTempSoftMax.get())();
|
|
|
|
|
mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
|
|
|
|
|
mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
|
|
|
|
|
mQkvUpdateInfo.update_kernel_args[0].arg_value = &openCLDeferBuffer(mTempSoftMax.get())();
|
|
|
|
|
mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->value()))();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
#endif
|
|
|
|
|
if(mLongPrefill){
|
|
|
|
|
// rearrange key value
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
#endif
|
|
|
|
|
if(mLongPrefill){
|
|
|
|
|
// rearrange key value
|
|
|
|
|
cl_int ret = CL_SUCCESS;
|
|
|
|
|
ret |= mKernel_rearrange_vec[0]->get().setArg(9, *mKVCacheCLManager->key());
|
|
|
|
|
ret |= mKernel_rearrange_vec[0]->get().setArg(10, *mKVCacheCLManager->value());
|
|
|
|
|
ret |= mKernel_rearrange_vec[0]->get().setArg(14, mKeyValueMaxlen);
|
|
|
|
|
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_k");
|
|
|
|
|
}else{
|
|
|
|
|
{
|
|
|
|
|
// rearrange query
|
|
|
|
|
cl_int ret = CL_SUCCESS;
|
|
|
|
|
ret |= mKernel_rearrange_vec[0]->get().setArg(9, *mKVCacheCLManager->key());
|
|
|
|
|
ret |= mKernel_rearrange_vec[0]->get().setArg(10, *mKVCacheCLManager->value());
|
|
|
|
|
ret |= mKernel_rearrange_vec[0]->get().setArg(14, mKeyValueMaxlen);
|
|
|
|
|
ret |= mKernel_rearrangeQ->get().setArg(4, openCLDeferBuffer(mTempQ.get()));
|
|
|
|
|
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_q");
|
|
|
|
|
}
|
|
|
|
|
{
|
|
|
|
|
// rearrange key
|
|
|
|
|
cl_int ret = CL_SUCCESS;
|
|
|
|
|
ret |= mKernel_rearrange->get().setArg(4, *mKVCacheCLManager->key());
|
|
|
|
|
ret |= mKernel_rearrange->get().setArg(5, mPastKvSeqlen);
|
|
|
|
|
ret |= mKernel_rearrange->get().setArg(6, mKeyValueMaxlen);
|
|
|
|
|
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_k");
|
|
|
|
|
}else{
|
|
|
|
|
{
|
|
|
|
|
// rearrange key
|
|
|
|
|
cl_int ret = CL_SUCCESS;
|
|
|
|
|
ret |= mKernel_rearrange->get().setArg(4, *mKVCacheCLManager->key());
|
|
|
|
|
ret |= mKernel_rearrange->get().setArg(5, mPastKvSeqlen);
|
|
|
|
|
ret |= mKernel_rearrange->get().setArg(6, mKeyValueMaxlen);
|
|
|
|
|
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_k");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if(mHasMask){
|
|
|
|
|
// rearrange mask
|
|
|
|
|
cl_int ret = CL_SUCCESS;
|
|
|
|
@ -309,6 +316,7 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
|
|
|
|
|
mGlobalWorkSizeQk = {static_cast<uint32_t>(UP_DIV(seqlen, 4)), static_cast<uint32_t>(UP_DIV(mKvSeqlen, 4)), static_cast<uint32_t>(numHead*batch)};
|
|
|
|
|
cl_int ret = CL_SUCCESS;
|
|
|
|
|
ret |= mKernel_qk->get().setArg(1, mGlobalWorkSizeQk0);
|
|
|
|
|
ret |= mKernel_qk->get().setArg(3, openCLDeferBuffer(mTempQ.get()));
|
|
|
|
|
ret |= mKernel_qk->get().setArg(4, *mKVCacheCLManager->key());
|
|
|
|
|
if(mHasMask) {
|
|
|
|
|
ret |= mKernel_qk->get().setArg(5, openCLDeferBuffer(mTempMask.get()));
|
|
|
|
@ -337,8 +345,8 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
|
|
|
|
|
ret |= mKernel_rearrangeV->get().setArg(6, mKeyValueMaxlen);
|
|
|
|
|
MNN_CHECK_CL_SUCCESS(ret, "reSetArg rearrange_v");
|
|
|
|
|
}
|
|
|
|
|
// qk * value
|
|
|
|
|
{
|
|
|
|
|
// qk * value
|
|
|
|
|
cl_int ret = CL_SUCCESS;
|
|
|
|
|
ret |= mKernel_qkv->get().setArg(3, openCLDeferBuffer(mTempSoftMax.get()));
|
|
|
|
|
ret |= mKernel_qkv->get().setArg(4, *mKVCacheCLManager->value());
|
|
|
|
|