MNN/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp

//
//  AttentionBufExecution.cpp
// MNN
//
// Created by MNN on 2024/04/11.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MNN_OPENCL_BUFFER_CLOSED
#include "backend/opencl/execution/buffer/AttentionBufExecution.hpp"
namespace MNN {
namespace OpenCL {
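// AttentionBufImpl runs scaled dot-product attention as three OpenCL kernels:
//   1. matmul_qk_div_mask : score = mask(Q * K^T * scale)
//   2. softmax            : normalize the scores along the kv dimension
//   3. matmul_qkv         : out = score * V
// With kv_cache enabled, K and V are kept in persistent cl::Buffers so that a
// decode step (seq_len == 1) only computes attention for the newest token.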
AttentionBufImpl::AttentionBufImpl(const MNN::Op *op, Backend *backend, bool kv_cache)
: mKVCache(kv_cache){
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
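// A softmax kernel is built here only to query the device's maximum workgroup
// size; the shape-dependent kernels are built later, in onResize.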
auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_channel", {"-DSOFTMAX_LOCAL_SIZE=512"});
mMaxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel));
}
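// Lazily allocate the KV cache. The early return fires while the current
// allocation (mMaxLength) still covers mPastLength; mExpandChunk extra slots
// are reserved so decode can append tokens without reallocating every step.
// The length axis is padded to a multiple of 4 for the C4-vectorized kernels,
// and mByte is the element size (2 for FP16, otherwise 4).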
void AttentionBufImpl::allocKVCache() {
if (!mKVCache || mPastLength < mMaxLength) {
return;
}
mMaxLength = mPastLength + mExpandChunk;
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * mHeadDim * 4 * mByte;
// past_key: [1, numhead, headdim, maxlen]
mPastKey.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
// past_value: [1, numhead, maxlen, headdim]
mPastValue.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
}
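// Grow the KV cache once mPastLength catches up with mMaxLength: allocate
// larger buffers, copy the old contents across via blocking host mappings,
// then re-point every kernel argument (or recorded update-arg) that
// referenced the old buffers. The QK and softmax scratch buffers are resized
// to match the new capacity as well.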
void AttentionBufImpl::reallocKVCache() {
if (!mKVCache || mPastLength < mMaxLength) {
return;
}
size_t old_size = mNumHead * UP_DIV(mMaxLength, 4) * mHeadDim * 4 * mByte;
mMaxLength = mPastLength + mExpandChunk;
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * mHeadDim * 4 * mByte;
// past_key: [1, numhead, headdim, maxlen]
auto new_key = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
// past_value: [1, numhead, maxlen, headdim]
auto new_value = new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size);
// copy the old cache contents into the new buffers via mapped host pointers
cl_int res;
auto new_key_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*new_key, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
auto key_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastKey.get(), true, CL_MAP_READ, 0, old_size, nullptr, nullptr, &res);
if(new_key_ptr != nullptr && key_ptr != nullptr && res == CL_SUCCESS){
::memcpy(new_key_ptr, key_ptr, old_size);
}else{
MNN_ERROR("Map error key_ptr == nullptr \n");
MNN_ASSERT(false);
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*new_key, new_key_ptr);
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mPastKey.get(), key_ptr);
auto new_value_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*new_value, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res);
auto value_ptr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mPastValue.get(), true, CL_MAP_READ, 0, old_size, nullptr, nullptr, &res);
if(new_value_ptr != nullptr && value_ptr != nullptr && res == CL_SUCCESS){
::memcpy(new_value_ptr, value_ptr, old_size);
}else{
MNN_ERROR("Map error value_ptr == nullptr \n");
MNN_ASSERT(false);
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*new_value, new_value_ptr);
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mPastValue.get(), value_ptr);
mPastKey.reset(new_key);
mPastValue.reset(new_value);
size_t temp_size = UP_DIV(mMaxLength, 4) * mNumHead * 4 * mByte;
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, temp_size));
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, temp_size));
// re-point the kernel args (or recorded update args) at the reallocated buffers
if(mOpenCLBackend->isUseRecordQueue()){
mQkUpdateInfo.update_kernel_args[1].arg_value = &(*(mTempQK.get()))();
mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mPastKey.get()))();
mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &(*(mTempQK.get()))();
mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &(*(mTempSoftMax.get()))();
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mTempSoftMax.get()))();
mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mPastValue.get()))();
}else{
cl_int ret = CL_SUCCESS;
ret |= mKernel_qk->get().setArg(5, *mTempQK.get());
ret |= mKernel_qk->get().setArg(6, *mPastKey.get());
ret |= mKernel_softmax->get().setArg(3, *mTempQK.get());
ret |= mKernel_softmax->get().setArg(4, *mTempSoftMax.get());
ret |= mKernel_qkv->get().setArg(3, *mTempSoftMax.get());
ret |= mKernel_qkv->get().setArg(6, *mPastValue.get());
MNN_CHECK_CL_SUCCESS(ret, "reset memory arg for AttentionBufExecution");
}
}
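// Largest power of two that fits both the reduction size and the workgroup
// limit, e.g. getLocalSize(100, 512) == 64 and getLocalSize(1000, 512) == 512.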
int AttentionBufImpl::getLocalSize(int size, int maxGroupSize){
int local_size = 1;
while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){
local_size *= 2;
}
return local_size;
}
ErrorCode AttentionBufImpl::onResize(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
mOpenCLBackend->startRecord(mRecording);
// clear the update-arg vectors, in case prefill and decode reuse the same execution
mOpRecordUpdateInfo.clear();
mQkUpdateInfo.update_kernel_args.clear();
mQkUpdateInfo.update_global_size.clear();
mQkUpdateInfo.update_local_size.clear();
mSoftMaxUpdateInfo.update_kernel_args.clear();
mSoftMaxUpdateInfo.update_global_size.clear();
mSoftMaxUpdateInfo.update_local_size.clear();
mQkvUpdateInfo.update_kernel_args.clear();
mQkvUpdateInfo.update_global_size.clear();
mQkvUpdateInfo.update_local_size.clear();
auto query = inputs[0];
auto key = inputs[1];
auto value = inputs[2];
auto mask = inputs[3];
auto runtime = mOpenCLBackend->getOpenCLRuntime();
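// Query layout is [batch, seq_len, num_head, head_dim]. A prefill
// (seq_len > 1) resets mPastLength; a decode step extends the cache by one.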
auto shape = query->shape();
int seq_len = shape[1];
mNumHead = shape[2];
mHeadDim = shape[3];
mScale = 1.0 / sqrt(mHeadDim);
mIsDecode = seq_len == 1;
mIsFirstDecode = true;
if (mPastLength == 0 || seq_len > 1) {
mPastLength = seq_len;
}
mKv_seq_len = mPastLength;
if(mIsDecode){
mKv_seq_len = mPastLength + 1;
}
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
mByte = 2;
}
allocKVCache();
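// Scratch buffers for the intermediate scores: decode reserves one score row
// per head sized to the cache capacity (mMaxLength, padded to 4), so it can
// grow across decode steps without reallocation; prefill needs the full
// seq_len x kv_seq_len score matrix per head (here seq_len == mPastLength).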
if (mIsDecode) {
size_t buffer_size = UP_DIV(mMaxLength, 4) * mNumHead * 4 * mByte;
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
} else {
size_t buffer_size = UP_DIV(mPastLength, 4) * mPastLength * mNumHead * 4 * mByte;
mTempQK.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
mTempSoftMax.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
}
// query * key -> divide by sqrt(head_dim) -> apply mask (added for a float mask, selected otherwise)
{
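// The grid covers (seq_len/4, num_head, kv_seq_len/4), i.e. presumably one
// work item per 4x4 tile of the score matrix per head. mPastKey is passed so
// the kernel can also append the current K into the cache; args 2 and 10
// (the padded kv extent and kv_seq_len) are refreshed on each decode step.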
std::set<std::string> buildOption;
if(!mIsDecode){
buildOption.emplace("-DOPENCL_PREFILL_ATTENTION");
}
if((mHeadDim % 4) != 0){
buildOption.emplace("-DHEADDIM_LEAVE");
}
if(mask->getType() == halide_type_of<float>()){
buildOption.emplace("-DADD_MASK");
}
mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_div_mask", buildOption, inputs[0], outputs[0]);
mGlobalWorkSizeQk = {static_cast<uint32_t>(UP_DIV(seq_len, 4)), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(UP_DIV(mKv_seq_len, 4))};
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qk));
mGlobalWorkSizeQk2 = UP_DIV(mKv_seq_len, 4);
uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]);
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]);
ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk2);
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(query));
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(key));
ret |= mKernel_qk->get().setArg(index++, *mTempQK.get());
ret |= mKernel_qk->get().setArg(index++, *mPastKey.get());
ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mask));
ret |= mKernel_qk->get().setArg(index++, mScale);
ret |= mKernel_qk->get().setArg(index++, seq_len);
ret |= mKernel_qk->get().setArg(index++, mKv_seq_len);
ret |= mKernel_qk->get().setArg(index++, mNumHead);
ret |= mKernel_qk->get().setArg(index++, mHeadDim);
MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_div_mask");
mLocalWorkSizeQk = localWS3DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_div_mask", mKernel_qk).first;
mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0]));
mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1]));
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2]));
mQkUpdateInfo.update_kernel_args.push_back({0, 2, sizeof(mGlobalWorkSizeQk2), &mGlobalWorkSizeQk2});
mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &(*(mTempQK.get()))()});
mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastKey.get()))()});
mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKv_seq_len), &mKv_seq_len});
mQkGlobal_size[0] = mGlobalWorkSizeQk[0];
mQkGlobal_size[1] = mGlobalWorkSizeQk[1];
mQkGlobal_size[2] = mGlobalWorkSizeQk[2];
mQkUpdateInfo.update_global_size.push_back({0, mQkGlobal_size});
mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo);
}
// softmax
{
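// The softmax reduces along the kv dimension. When kv_seq_len is large
// enough (localSize >= 4), a power-of-two workgroup performs the reduction;
// otherwise localSize falls back to 1 and the workgroup is auto-tuned below.
// mSoftMaxRemainChannels counts the padded tail entries
// (past_len4 * 4 - kv_seq_len) that the kernel presumably excludes from the
// normalization.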
auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast<uint32_t>(512));
int localSize = getLocalSize(mKv_seq_len, MaxLocalSize);
if(localSize < 4){
localSize = 1;
}
int past_len4 = UP_DIV(mKv_seq_len, 4);
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[0] = mNumHead;
mSoftmaxShape[1] = past_len4;
mSoftmaxShape[2] = 1;
mSoftmaxShape[3] = mPastLength;
std::set<std::string> buildOption;
buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
if(!mIsDecode){
mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_width", buildOption, inputs[0], outputs[0]);
mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(past_len4), static_cast<uint32_t>(mNumHead)};
} else{
mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_channel", buildOption, inputs[0], outputs[0]);
mSoftmaxShape[3] = 1;
mGlobalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), static_cast<uint32_t>(1), static_cast<uint32_t>(mNumHead)};
}
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_softmax));
uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]);
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]);
ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]);
ret |= mKernel_softmax->get().setArg(index++, *mTempQK.get());
ret |= mKernel_softmax->get().setArg(index++, *mTempSoftMax.get());
ret |= mKernel_softmax->get().setArg(index++, mSoftMaxRemainChannels);
ret |= mKernel_softmax->get().setArg(index++, mSoftmaxShape);
MNN_CHECK_CL_SUCCESS(ret, "setArg softmax");
mLocalWorkSizeSoftMax = {static_cast<uint32_t>(localSize), 1, 1};
if(localSize == 1){
mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax).first;
}
mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0]));
mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1]));
mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2]));
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mTempQK.get()))()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &(*(mTempSoftMax.get()))()});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mSoftMaxRemainChannels), &mSoftMaxRemainChannels});
mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mSoftmaxShape), &mSoftmaxShape});
mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo);
}
// qk * value
{
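// Weighted sum over kv_seq_len with grid (seq_len/4, num_head, head_dim/4).
// mPastValue is passed so the kernel can presumably append the current V to
// the cache; arg 8 (kv_seq_len) is refreshed on each decode step.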
std::set<std::string> buildOption;
if(!mIsDecode){
buildOption.emplace("-DOPENCL_PREFILL_ATTENTION");
}
if((mHeadDim % 4) != 0){
buildOption.emplace("-DHEADDIM_LEAVE");
}
mKernel_qkv = runtime->buildKernel("attention_buf", "matmul_qkv", buildOption, inputs[0], outputs[0]);
auto maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel_qkv));
mGlobalWorkSizeQkv = {static_cast<uint32_t>(UP_DIV(seq_len, 4)), static_cast<uint32_t>(mNumHead), static_cast<uint32_t>(UP_DIV(mHeadDim, 4))};
uint32_t index = 0;
cl_int ret = CL_SUCCESS;
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]);
ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[2]);
ret |= mKernel_qkv->get().setArg(index++, *mTempSoftMax.get());
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(value));
ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0]));
ret |= mKernel_qkv->get().setArg(index++, *mPastValue.get());
ret |= mKernel_qkv->get().setArg(index++, seq_len);
ret |= mKernel_qkv->get().setArg(index++, mKv_seq_len);
ret |= mKernel_qkv->get().setArg(index++, mNumHead);
ret |= mKernel_qkv->get().setArg(index++, mHeadDim);
MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv");
mLocalWorkSizeQkv = localWS3DDefault(mGlobalWorkSizeQkv, maxWorkGroupSize, runtime, "matmul_qkv", mKernel_qkv).first;
mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0]));
mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1]));
mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2]));
mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &(*(mTempSoftMax.get()))()});
mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastValue.get()))()});
mQkvUpdateInfo.update_kernel_args.push_back({0, 8, sizeof(mKv_seq_len), &mKv_seq_len});
mOpRecordUpdateInfo.emplace_back(&mQkvUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, &mQkvUpdateInfo);
}
mOpenCLBackend->endRecord(mRecording);
return NO_ERROR;
}
ErrorCode AttentionBufImpl::onExecute(Backend *backend, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
#ifdef LOG_VERBOSE
MNN_PRINT("start AttentionBufExecution onExecute !\n");
#endif
mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
reallocKVCache();
#ifdef ENABLE_OPENCL_TIME_PROFILER
{
cl::Event event;
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk,
mOpenCLBackend->getOpenCLRuntime(), &event);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qk_div_mask", event});
}
{
cl::Event event;
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax,
mOpenCLBackend->getOpenCLRuntime(), &event);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"softmax", event});
}
{
cl::Event event;
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv,
mOpenCLBackend->getOpenCLRuntime(), &event);
mOpenCLBackend->getOpenCLRuntime()->pushEvent({"matmul_qkv", event});
}
#else
if(mOpenCLBackend->isUseRecordQueue()){
mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo);
if(mIsDecode){
if(mIsFirstDecode){
mIsFirstDecode = false;
}else{
mPastLength += 1;
mKv_seq_len = mPastLength + 1;
int past_len4 = UP_DIV(mKv_seq_len, 4);
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[1] = past_len4;
mGlobalWorkSizeQk2 = past_len4;
mQkGlobal_size[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2]));
}
}
#ifdef LOG_VERBOSE
MNN_PRINT("End AttentionBufExecution onExecute... \n");
#endif
return NO_ERROR;
}
run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime());
run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime());
#endif
// decode
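// Advance the bookkeeping for the next decode step and refresh every kernel
// argument that depends on kv_seq_len: the padded score width (qk arg 2),
// the softmax tail count and shape, and the qkv reduction length.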
if(mIsDecode){
mPastLength += 1;
mKv_seq_len = mPastLength + 1;
int past_len4 = UP_DIV(mKv_seq_len, 4);
mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len;
mSoftmaxShape[1] = past_len4;
cl_int ret = CL_SUCCESS;
mGlobalWorkSizeQk2 = past_len4;
mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2]));
ret |= mKernel_qk->get().setArg(2, mGlobalWorkSizeQk2);
ret |= mKernel_qk->get().setArg(10, mKv_seq_len);
ret |= mKernel_softmax->get().setArg(5, mSoftMaxRemainChannels);
ret |= mKernel_softmax->get().setArg(6, mSoftmaxShape);
ret |= mKernel_qkv->get().setArg(8, mKv_seq_len);
MNN_CHECK_CL_SUCCESS(ret, "reset arg for AttentionBufExecution");
}
#ifdef LOG_VERBOSE
MNN_PRINT("end AttentionBufExecution onExecute !\n");
#endif
return NO_ERROR;
}
AttentionBufExecution::AttentionBufExecution(const MNN::Op *op, Backend* backend, bool kv_cache) : CommonExecution(backend, op) {
mImpl.reset(new AttentionBufImpl(op, backend, kv_cache));
}
AttentionBufExecution::AttentionBufExecution(std::shared_ptr<AttentionBufImpl> impl, const MNN::Op *op, Backend *backend) : CommonExecution(backend, op), mImpl(impl) {}
ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onResize(backend(), inputs, outputs);
}
ErrorCode AttentionBufExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
return mImpl->onExecute(backend(), inputs, outputs);
}
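// Clones share mImpl (and with it the KV cache, kernels and recorded state),
// so a prefill execution and its decode clone operate on the same cache.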
bool AttentionBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
if (nullptr == dst) {
return true;
}
*dst = new AttentionBufExecution(mImpl, op, bn);
return true;
}
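// C4 packing is disabled on all attention inputs/outputs, presumably so the
// kernels can address Q/K/V and the output as plain contiguous buffers.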
class AttentionBufCreator : public OpenCLBackend::Creator {
public:
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op, Backend *backend) const override {
for (int i = 0; i < inputs.size(); ++i) {
TensorUtils::setTensorSupportPack(inputs[i], false);
}
for (int i = 0; i < outputs.size(); ++i) {
TensorUtils::setTensorSupportPack(outputs[i], false);
}
auto param = op->main_as_AttentionParam();
return new AttentionBufExecution(op, backend, param->kv_cache());
}
};
REGISTER_OPENCL_OP_CREATOR(AttentionBufCreator, OpType_Attention, BUFFER);
} // namespace OpenCL
} // namespace MNN
#endif/* MNN_OPENCL_BUFFER_CLOSED */