//
//  MetalAttention.mm
//  MNN
//
//  Created by MNN on 2024/04/29.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <set>
#import "core/Macro.h"
#import "MetalCast.hpp"
#import "MetalBackend.hpp"
#import "MNNMetalContext.h"
#import "MetalAttentionShader.hpp"
#include "MNN_generated.h"
#include "core/OpCommonUtils.hpp"

#if MNN_METAL_ENABLED
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE

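// Fused attention execution for the Metal backend. Compiled only when MNN_METAL_ENABLED and
// MNN_SUPPORT_TRANSFORMER_FUSE are set; registered for OpType_Attention at the bottom of this file.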
namespace MNN {
class AttentionBufExecution : public MetalExecution {
public:
    struct SharedCache {
        std::shared_ptr<Tensor> mPastKey;
        std::shared_ptr<Tensor> mPastValue;
        int mPastLength = 0, mMaxLength = 0, mKv_seq_len = 0;
    };
    AttentionBufExecution(Backend *backend, bool kv_cache);

    virtual ~AttentionBufExecution() = default;
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

    virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override;
    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override {
        if (nullptr == dst) {
            return true;
        }
        auto exe = new AttentionBufExecution(bn, mKVCache);
        exe->mCache = mCache;
        *dst = exe;
        return true;
    }

private:
    void _init();
    void reallocKVCache();
    void compilerShader(const std::vector<Tensor *> &inputs);
    void handleKVAllocMemory();
    bool mKVCache;
    std::shared_ptr<SharedCache> mCache;
    float mScale;
    const int mExpandChunk = 64;
    bool mShortSeq = false;
    std::shared_ptr<Tensor> mTempQK, mTempSoftMax;
    int mNumHead = 0, mHeadDim = 0, mValueH = 0, mKvNumHead = 0;
    int mSeqLen;
    id<MTLComputePipelineState> mKernel_softmax = nil;

    id<MTLComputePipelineState> mKernel_qk = nil;
    id<MTLComputePipelineState> mKernel_qkv = nil;
    id<MTLComputePipelineState> mKernel_copy = nil;
    id<MTLComputePipelineState> mKernelPrefill_qk = nil;
    id<MTLComputePipelineState> mKernelPrefill_qkv = nil;
    id<MTLBuffer> mParamQKV;
    id<MTLBuffer> mParamSoftmax;
    id<MTLBuffer> mParamCopy;

private:
    KVMeta* mMeta;
    bool mQkSimdReduce = false;
    bool mQkSimdMatrix = false;
    bool mSftmSimdReduce = false;
    bool mQkvSimdReduce = false;
    bool mQkvSimdMatrix = false;
private:
    bool mHasMask = false;
    bool mIsAddMask = false;
    int mBatch, mKvSeqLen, mKvMaxLen;
    int mQseqSplitNum = 1;
    std::shared_ptr<Tensor> mTempK, mTempV;
};

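// Kernel parameter block uploaded through mParamQKV (see _init); its field order is
// assumed to mirror the parameter struct consumed by the QK/QKV shaders.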
struct Param {
    int query_seq_len;
    int q_seq_piece_len;
    int key_seq_len;
    int head_num;
    int group;
    int head_dim;
    float scale;
    int max_kv_len;
    int batch;
};
AttentionBufExecution::AttentionBufExecution(Backend *backend, bool kv_cache)
    : MetalExecution(backend), mKVCache(kv_cache) {
    _init();
}
void AttentionBufExecution::_init() {
    mCache.reset(new SharedCache);
    auto mtbn = static_cast<MetalBackend *>(backend());
    auto context = (__bridge MNNMetalContext *)mtbn->context();
    mMeta = (KVMeta*)(mtbn->getMetaPtr());

    mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly];
    mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
    mParamCopy = [context newDeviceBuffer:6 * sizeof(int) access:CPUWriteOnly];
    mTempQK.reset(Tensor::createDevice<float>({0, 0}));
    mTempSoftMax.reset(Tensor::createDevice<float>({0, 0}));
}

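// Grows the persistent K/V cache when the incoming kv length exceeds mMaxLength (expanding by
// mExpandChunk entries), copies the retained history into the new buffers, then compacts the
// ranges listed in mMeta->reserve after removal.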
void AttentionBufExecution::reallocKVCache() {
    if (!mKVCache) {
        return;
    }
    auto kv_seq_len = mMeta->previous + mMeta->add - mMeta->remove + mMeta->computeReverseSize();

    auto mtbn = static_cast<MetalBackend *>(backend());
    int byte = 4;
    if(mtbn->useFp16InsteadFp32()) {
        byte = 2;
    }

    auto start = mCache->mPastLength - mMeta->remove;
    // latest length larger than maxLen
    if (kv_seq_len > mCache->mMaxLength) {

        auto copy_len = mCache->mPastLength - mMeta->remove + mMeta->computeReverseSize();
        bool needCopy = copy_len > 0;

        size_t old_size = mKvNumHead * start * mHeadDim * byte;
        size_t old_piece_size = start * byte;
        size_t old_piece_stride = mCache->mMaxLength * byte;

        mCache->mMaxLength = kv_seq_len + mExpandChunk;
        // past_key: [maxLength, kvNumHead, headDim]
        auto new_key = Tensor::createDevice<float>({mCache->mMaxLength, mKvNumHead, mHeadDim});
        // past_value: [kvNumHead, headDim, maxLength]
        auto new_value = Tensor::createDevice<float>({mKvNumHead, mHeadDim, mCache->mMaxLength});
        size_t size = mKvNumHead * mCache->mMaxLength * mHeadDim * byte;
        auto res = backend()->onAcquireBuffer(new_key, Backend::STATIC);
        res = res && backend()->onAcquireBuffer(new_value, Backend::STATIC);
        if(!res) {
            MNN_ERROR("attention kv cache realloc memory error:%d\n", res);
        }
        if (needCopy) {
            auto newKeyBuf = MetalBackend::getBuffer(new_key);
            auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second;
            auto keyBuf = MetalBackend::getBuffer(mCache->mPastKey.get());
            auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second;
            ::memcpy(new_key_ptr, key_ptr, old_size);

            auto newValueBuf = MetalBackend::getBuffer(new_value);
            auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second;
            auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get());
            auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second;
            for(int i = 0; i < mKvNumHead * mHeadDim; i++) {
                ::memcpy(new_value_ptr + i * mCache->mMaxLength * byte, value_ptr + i * old_piece_stride, old_piece_size);
            }
        }
        mCache->mPastLength = (int)start;

        mCache->mPastKey.reset(new_key);
        mCache->mPastValue.reset(new_value);
    }

    // Remove
    {
        if (0 == mMeta->n_reserve) {
            mCache->mPastLength = start;
            return;
        }

        auto keyBuf = MetalBackend::getBuffer(mCache->mPastKey.get());
        auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second;
        auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get());
        auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second;

        // TODO: need to ensure reserve info is sorted
        for (int n = 0; n < mMeta->n_reserve; ++n) {
            auto begin = mMeta->reserve[2 * n];
            auto length = mMeta->reserve[2 * n + 1];
            // past_key : [mCache->mPastLength, mKvNumHead, mHeadDim]
            // past_value : [mKvNumHead, mHeadDim, mCache->mMaxLength]

            auto copy_src_index = start + begin;
            auto copy_dst_index = start;
            for(int i = 0; i < length; i++) {
                ::memcpy(key_ptr + (copy_dst_index + i) * mKvNumHead * mHeadDim * byte, key_ptr + (copy_src_index + i) * mKvNumHead * mHeadDim * byte, mKvNumHead * mHeadDim * byte);
            }
            for(int j = 0; j < mKvNumHead * mHeadDim; j++) {
                for(int i = 0; i < length; i++) {
                    ::memcpy(value_ptr + (j * mCache->mMaxLength + copy_dst_index + i) * byte, value_ptr + (j * mCache->mMaxLength + copy_src_index + i) * byte, byte);
                }
            }
            start += length;
        }
        mCache->mPastLength = (int)start;
    }
}

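// Builds the five compute pipelines (decode/prefill QK, decode/prefill QKV, past-KV copy),
// reusing any pipeline already cached in the MetalRuntime. Each cache key is the shader name
// plus its macro set: T/T4 (fp16 or fp32), GROUP_SIZE (query heads per kv head), and optional
// flags such as FOR_PREFILL, ADD_MASK/SET_MASK and the simdgroup variants.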
void AttentionBufExecution::compilerShader(const std::vector<Tensor *> &inputs) {
    auto mtbn = static_cast<MetalBackend *>(backend());
    auto rt = (MetalRuntime*)mtbn->runtime();
    auto context = (__bridge MNNMetalContext *)mtbn->context();

    auto seq_len = inputs[0]->length(1);
    int group_size = inputs[0]->length(2) / inputs[1]->length(2);
    std::string group_str = std::to_string(group_size);

    // Init Kernel
    std::string T = "float";
    std::string T4 = "float4";
    if (mtbn->useFp16InsteadFp32()) {
        T = "half";
        T4 = "half4";
    }
    std::vector<std::string> qkKeys = {
        {"matmul_qk_div_mask", T, group_str}
    };
    if(mHeadDim % 4 != 0) {
        qkKeys.emplace_back("HEAD_DIM_UNALIGNED_4");
    }

    std::vector<std::string> qkvKeys = {
        {"matmul_qkv", T, group_str}
    };
    if(mQkvSimdReduce) {
        qkvKeys.emplace_back("SIMD_GROUP_REDUCE");
    }
    std::vector<std::string> qkPrefillKeys = {
        {"matmul_qk_div_mask", T, group_str, "FOR_PREFILL"}
    };
    if(mHasMask) {
        if (mIsAddMask) {
            qkPrefillKeys.emplace_back("ADD_MASK");
            if(seq_len > 1) {
                qkKeys.emplace_back("ADD_MASK");
            }
        } else {
            qkPrefillKeys.emplace_back("SET_MASK");
            if(seq_len > 1) {
                qkKeys.emplace_back("SET_MASK");
            }
        }
    }
    if(mQkSimdMatrix) {
        qkPrefillKeys.emplace_back("SIMD_GROUP_MATRIX");
    }
    std::vector<std::string> qkvPrefillKeys = {
        {"matmul_qkv", T, group_str, "FOR_PREFILL"}
    };
    if(mQkvSimdMatrix) {
        qkvPrefillKeys.emplace_back("SIMD_GROUP_MATRIX");
    }
    std::vector<std::string> copyPastKeys = {
        {"pastkv_copy", T, group_str}
    };
    std::vector<std::vector<std::string>> keys = {
        qkKeys,
        qkvKeys,
        qkPrefillKeys,
        qkvPrefillKeys,
        copyPastKeys
    };
    std::vector<const char*> sources = {
        gMatMulDivMask,
        gMatMulQKV,
        gMatMulDivMask,
        gMatMulQKV,
        gCopyPastKV
    };
    std::vector<std::string> shaders = {
        "decode_qk",
        "decode_qkv",
        "prefill_qk",
        "prefill_qkv",
        "copy"
    };
    std::vector<id<MTLComputePipelineState>> pipelines(keys.size());
    for (int i=0; i<keys.size(); ++i) {
        auto pipeline = rt->findPipeline(keys[i]);
        if (nil == pipeline) {
            // Rebuild Pipeline
            MTLCompileOptions *option = [[MTLCompileOptions alloc] init];
            auto dic = [NSMutableDictionary dictionaryWithCapacity:0];
            [dic setValue:@(keys[i][1].c_str()) forKey:@"T"];
            [dic setValue:@(T4.c_str()) forKey:@"T4"];
            [dic setValue:@(keys[i][2].c_str()) forKey:@"GROUP_SIZE"];
            for (int j=3; j<keys[i].size(); ++j) {
                [dic setValue:@"1" forKey:@(keys[i][j].c_str())];
            }
            option.preprocessorMacros = dic;

            pipeline = mtbn->makeComputePipelineWithSourceOption(sources[i], shaders[i].c_str(), option);
            rt->insertPipeline(keys[i], pipeline);
        }
        pipelines[i] = pipeline;
    }
    mKernel_qk = pipelines[0];
    mKernel_qkv = pipelines[1];
    mKernelPrefill_qk = pipelines[2];
    mKernelPrefill_qkv = pipelines[3];
    mKernel_copy = pipelines[4];
    MNN_ASSERT(nil != mKernel_qk);
    MNN_ASSERT(nil != mKernel_qkv);
    MNN_ASSERT(nil != mKernelPrefill_qk);
    MNN_ASSERT(nil != mKernelPrefill_qkv);
    MNN_ASSERT(nil != mKernel_copy);

    if(mSftmSimdReduce) {
        // basic macro info
        std::string ftype = "float";
        std::string ftype4 = "float4";
        if (mtbn->useFp16InsteadFp32()) {
            ftype = "half";
            ftype4 = "half4";
        }

        MTLCompileOptions *option = [[MTLCompileOptions alloc] init];
        auto dic = [NSMutableDictionary dictionaryWithCapacity:0];
        option.preprocessorMacros = @{
            @"ftype" : @(ftype.c_str()),
            @"ftype4" : @(ftype4.c_str()),
        };
        std::vector<std::string> keys = {"softmax_sg_reduce", ftype};
        keys.emplace_back("softmax_plane_sg");
        auto pipeline = rt->findPipeline(keys);
        if (nil == pipeline) {
            pipeline = mtbn->makeComputePipelineWithSourceOption(gSoftmaxSgReduce, keys.back().c_str(), option);
            rt->insertPipeline(keys, pipeline);
        }
        mKernel_softmax = pipeline;
    } else {
        mKernel_softmax = [context pipelineWithName:@"softmax_plane" fp16:mtbn->useFp16InsteadFp32()];
    }

}

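// kv-cache path only: refresh the cache length from mMeta, grow the cache via reallocKVCache(),
// and resize the temporary QK/softmax buffers. When the score buffer
// (batch * numHead * seqLen * kvMaxLen) would exceed about 32M elements, the query sequence
// is split into mQseqSplitNum pieces to bound the scratch memory.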
void AttentionBufExecution::handleKVAllocMemory() {
    if(mKVCache) {
        mCache->mPastLength = mMeta != nullptr ? mMeta->previous : 0;
        // kv-cache realloc function
        reallocKVCache();
        mCache->mKv_seq_len = mCache->mPastLength + mSeqLen;
        mKvSeqLen = mCache->mKv_seq_len;
        mKvMaxLen = mCache->mMaxLength;

        float useMemorySize = 1.0 * mKvMaxLen / 1024.0 * mSeqLen / 1024.0 * mBatch * mNumHead;
        // elementSize larger than 32M
        mQseqSplitNum = 1;
        if(useMemorySize > 32.0) {
            mQseqSplitNum = useMemorySize >= 256.0 ? 16 : ((useMemorySize < 128.0) ? 4 : 8);
        }

        int qSeqLenPiece = UP_DIV(mSeqLen, mQseqSplitNum);
        // temp tensor alloc memory
        bool needMalloc = mTempQK->length(0) != mBatch * mNumHead;
        if (mTempQK->length(1) != qSeqLenPiece * mKvMaxLen) {
            needMalloc = true;
        }
        mTempQK->setLength(0, mBatch * mNumHead);
        mTempQK->setLength(1, qSeqLenPiece * mKvMaxLen);
        mTempSoftMax->setLength(0, mBatch * mNumHead);
        mTempSoftMax->setLength(1, qSeqLenPiece * mKvMaxLen);

        if (needMalloc) {
            auto res = backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC) && backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC);
            if (!res) {
                MNN_ERROR("MNN::Metal: OUT_OF_MEMORY when executing attention, %d\n", res);
                return;
            }
        }
    }
}
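// onResize records shapes. With kv_cache the temporary buffers depend on the runtime cache
// length, so their allocation is deferred to handleKVAllocMemory(); without kv_cache the
// temporary K/V/QK/softmax buffers are sized and reserved right here.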
ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    mHasMask = inputs.size() > 3;
    if(mHasMask) {
        mIsAddMask = (inputs[3]->getType() == halide_type_of<float>());
    }
    auto query = inputs[0];
    auto key = inputs[1];
    auto value = inputs[2];
    auto mtbn = static_cast<MetalBackend *>(backend());
    auto context = (__bridge MNNMetalContext *)mtbn->context();
    auto shape = query->shape();
    mBatch = shape[0];
    mSeqLen = shape[1];
    mNumHead = shape[2];
    mHeadDim = shape[3];
    mScale = 1.0 / sqrt(mHeadDim);
    // TODO : define short_seq more accurately
    mShortSeq = mSeqLen <= 10;
    mKvNumHead = key->shape()[2];
    mKvSeqLen = key->shape()[1];
    mKvMaxLen = ROUND_UP(mKvSeqLen, 4);

    if(mKVCache) {
        return NO_ERROR;
    }

    float useMemorySize = 1.0 * mKvMaxLen / 1024.0 * mSeqLen / 1024.0 * mBatch * mNumHead;
    // elementSize larger than 32M
    mQseqSplitNum = 1;
    if(useMemorySize > 32.0) {
        mQseqSplitNum = useMemorySize >= 256.0 ? 8 : ((useMemorySize < 128.0) ? 2 : 4);
    }

    // no kv_cache: create temporary K/V buffers here
    mTempK.reset(Tensor::createDevice<float>({mKvMaxLen * mHeadDim * mBatch * mKvNumHead}));
    mTempV.reset(Tensor::createDevice<float>({mKvMaxLen * mHeadDim * mBatch * mKvNumHead}));
    mTempQK.reset(Tensor::createDevice<float>({mKvMaxLen * UP_DIV(mSeqLen, mQseqSplitNum) * mBatch * mNumHead}));
    mTempSoftMax.reset(Tensor::createDevice<float>({mKvMaxLen * UP_DIV(mSeqLen, mQseqSplitNum) * mBatch * mNumHead}));

    backend()->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC);
    backend()->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC);
    backend()->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC);
    backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC);
    return NO_ERROR;
}
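// Encodes one attention pass: copy the new K/V into the cache layout, then for each
// query-sequence piece run QK (scaled, optionally masked), softmax over the kv axis,
// and the QKV matmul. group_size = mNumHead / mKvNumHead is passed to the kernels as GROUP_SIZE.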
void AttentionBufExecution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) {
    if(mKVCache) {
        // when kv_cache is enabled, a mask input is expected
        MNN_ASSERT(inputs.size() > 3);
    }
    auto query = inputs[0];
    auto key = inputs[1];
    auto value = inputs[2];
    auto mtbn = static_cast<MetalBackend *>(backend());
    auto context = (__bridge MNNMetalContext *)mtbn->context();
    auto rt = (MetalRuntime*)mtbn->runtime();

    int group_size = mNumHead / mKvNumHead;

    // temp memory alloc, handle variable set
    Tensor* tempTensorK;
    Tensor* tempTensorV;
    handleKVAllocMemory();

    if(mKVCache) {
        tempTensorK = mCache->mPastKey.get();
        tempTensorV = mCache->mPastValue.get();
    } else {
        tempTensorK = mTempK.get();
        tempTensorV = mTempV.get();
    }

    // whether to use simdgroup
    bool supportSimdReduce = rt->supportSimdGroupReduce();
    bool supportSimdMatrix = rt->supportSimdGroupMatrix();

    // decode and thread number not too large
    mQkSimdReduce = supportSimdReduce && mShortSeq;
    // loop_k can divide 8, thus avoid branch
    mQkSimdMatrix = supportSimdMatrix && mSeqLen >= 16 && mHeadDim % 8 == 0;

    mSftmSimdReduce = supportSimdReduce;
    mQkvSimdReduce = supportSimdReduce && mShortSeq && mHeadDim * mNumHead < mKvSeqLen * 32;
    mQkvSimdMatrix = supportSimdMatrix && mSeqLen >= 16;

    // start to compile attention shaders
    compilerShader(inputs);

    // Run Copy and Format-Convert Kernel
    {
        auto copyp = (int*)mParamCopy.contents;
        /*
         Key   -> K-Cache : [mBatch, mKvSeqLen, mKvNumHead, mHeadDim] -> [mKvMaxLen, mBatch, mKvNumHead, mHeadDim]
         Value -> V-Cache : [mBatch, mKvSeqLen, mKvNumHead, mHeadDim] -> [mBatch, mKvNumHead, mHeadDim, mKvMaxLen (fill when decode)]
         */
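        // Six ints consumed by the pastkv_copy kernel, as written below: [0] = elements per
        // token (kvNumHead * headDim), [1] = newly appended kv length, [2] = kv max length,
        // [3] = pastLength * [0] (presumably the element offset of the first new key),
        // [4] = pastLength (offset along the V-cache length axis), [5] = batch.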
        copyp[0] = mKvNumHead * mHeadDim;
        // current new kv_len
        copyp[1] = key->shape()[1];
        copyp[2] = mKvMaxLen;
        copyp[3] = mCache->mPastLength * copyp[0];
        copyp[4] = mCache->mPastLength;
        copyp[5] = mBatch;
        int copy_line = key->shape()[1];

        id<MTLComputePipelineState> pipeline = mKernel_copy;
        [encoder setComputePipelineState:pipeline];
        MetalBackend::setTensor(key, encoder, 0);
        MetalBackend::setTensor(value, encoder, 1);
        MetalBackend::setTensor(tempTensorK, encoder, 2);
        MetalBackend::setTensor(tempTensorV, encoder, 3);
        [encoder setBuffer:mParamCopy offset:0 atIndex:4];

        std::pair<MTLSize, MTLSize> gl;
        gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(mKvNumHead * mHeadDim, copy_line, mBatch)];

        [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];

    }

    // Update Parameters
    int seqLenPiece = UP_DIV(mSeqLen, mQseqSplitNum);
    {
        auto param = (Param*)mParamQKV.contents;
        param->scale = mScale;
        param->head_dim = mHeadDim;
        param->key_seq_len = mKvSeqLen;
        param->head_num = mNumHead;
        param->group = group_size;
        param->query_seq_len = mSeqLen;
        param->q_seq_piece_len = seqLenPiece;
        param->max_kv_len = mKvMaxLen;
        param->batch = mBatch;
    }

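    // Process the query sequence in mQseqSplitNum pieces; each piece reuses the same
    // mTempQK / mTempSoftMax scratch buffers, and seq_idx (bound at index 3) tells the
    // kernels which piece of the query they are working on.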
    for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) {
        // Run QK Kernel
        {
            id<MTLComputePipelineState> pipeline;
            if (mShortSeq) {
                pipeline = mKernel_qk;
            } else {
                pipeline = mKernelPrefill_qk;
            }
            [encoder setComputePipelineState:pipeline];
            // [mBatch, mSeqLen, mNumHead, mHeadDim]
            MetalBackend::setTensor(query, encoder, 0);
            // [mBatch, mNumHead, mSeqLen, mKvSeqLen]
            MetalBackend::setTensor(mTempQK.get(), encoder, 1);
            // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim]
            MetalBackend::setTensor(tempTensorK, encoder, 2);
            [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3];
            [encoder setBuffer:mParamQKV offset:0 atIndex:4];
            if(mHasMask) {
                MetalBackend::setTensor(inputs[3], encoder, 5);
            }

            int decode_grid_y = mBatch * mNumHead;
            std::pair<MTLSize, MTLSize> gl;
            if(mShortSeq) {
                gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, decode_grid_y / group_size, mKvSeqLen)];
            } else if(mQkSimdMatrix) {
                gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mKvSeqLen, 16), decode_grid_y), MTLSizeMake(32, 1, 1));
            } else {
                gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, decode_grid_y, mKvSeqLen)];
            }
            [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];

        }
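        // Softmax normalizes each row of the QK scores over the kv axis (length mKvSeqLen).
        // With simdgroup reduce, one 32-thread threadgroup is dispatched per row; otherwise
        // the plain softmax_plane kernel is used.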
        // Run Softmax Kernel
        {
            // For softmax parameter
            // [mBatch, mNumHead, mSeqLen, mKvSeqLen]
            int inside = 1;
            int outside = mBatch * mNumHead * seqLenPiece;
            int axis = mKvSeqLen;
            {
                auto softmax = (int*)mParamSoftmax.contents;
                // Inside, axis, outside, plane(invalid)
                softmax[0] = inside;
                softmax[1] = axis;
                softmax[2] = outside;
                softmax[3] = 0;
            }
            [encoder setComputePipelineState:mKernel_softmax];
            MetalBackend::setTensor(mTempQK.get(), encoder, 0);
            MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1);
            [encoder setBuffer:mParamSoftmax offset:0 atIndex:2];

            int thread_group_size = 32;
            std::pair<MTLSize, MTLSize> gl;
            if(mSftmSimdReduce) {
                gl = std::make_pair(MTLSizeMake(inside, outside, 1), MTLSizeMake(thread_group_size, 1, 1));
            } else {
                gl = [context computeBestGroupAndLocal: mKernel_softmax threads:MTLSizeMake(inside, outside, 1)];
            }

            [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];

        }
        // Run QKV Kernel
        {

            id<MTLComputePipelineState> pipeline;
            if (mShortSeq) {
                pipeline = mKernel_qkv;
            } else {
                pipeline = mKernelPrefill_qkv;
            }
            [encoder setComputePipelineState:pipeline];
            // [mBatch, mNumHead, mSeqLen, mKvSeqLen]
            MetalBackend::setTensor(mTempSoftMax.get(), encoder, 0);
            // [mBatch, mSeqLen, mNumHead, mHeadDim]
            MetalBackend::setTensor(outputs[0], encoder, 1);
            // [mBatch, mKvNumHead, mHeadDim, mMaxSeqLen]
            MetalBackend::setTensor(tempTensorV, encoder, 2);
            [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3];
            [encoder setBuffer:mParamQKV offset:0 atIndex:4];
            std::pair<MTLSize, MTLSize> gl;
            if(mQkvSimdReduce) {
                gl = std::make_pair(MTLSizeMake(seqLenPiece, mBatch * mNumHead, mHeadDim), MTLSizeMake(32, 1, 1));
            } else if(mQkvSimdMatrix) {
                gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mHeadDim, 16), mBatch * mNumHead), MTLSizeMake(32, 1, 1));
            } else {
                gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, mBatch * mNumHead, mHeadDim)];
            }
            [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second];

        }
    }
    // Update status
    if(mKVCache) {
        mCache->mPastLength += mSeqLen;
    }
    return;
}

class AttentionBufCreator : public MetalBackend::Creator {
public:
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend, const std::vector<Tensor *> &outputs) const override {
        auto param = op->main_as_AttentionParam();
        return new AttentionBufExecution(backend, param->kv_cache());
    }
};
REGISTER_METAL_OP_TRANSFORMER_CREATOR(AttentionBufCreator, OpType_Attention);

} // namespace MNN
#endif /* MNN_SUPPORT_TRANSFORMER_FUSE */
#endif