// MNN/source/backend/cpu/CPUAttention.cpp
//
// 308 lines, 14 KiB, C++ (repository web-view header: Raw | Normal View | History)
//
// 2024-05-11 19:17:02 +08:00
//
// CPUAttention.cpp
// MNN
//
// Created by MNN on 2024/03/19.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
#include <limits>
#include "CPUAttention.hpp"
#include "CPUBackend.hpp"
#include "compute/CommonOptFunction.h"
#include "core/Macro.h"
#include "core/Concurrency.h"
#include "core/BufferAllocator.hpp"
#include "core/TensorUtils.hpp"
#include "core/OpCommonUtils.hpp"
#if defined (__aarch64__)
#define FLOAT16_T __fp16
#else
#define FLOAT16_T float
#endif
2024-07-04 11:53:45 +08:00
// reduce the value of 'query' to 'query * FP16_QSCALE', avoid fp16 overflow
#define FP16_QSCALE 0.5
2024-05-11 19:17:02 +08:00
namespace MNN {
template <typename T>
2024-07-22 19:51:53 +08:00
static void pack_query(Tensor* query, char* pack_q, int mNumHead, int mHeadDim, int eP, int seq_len, int h, float q_scale) {
T * query_src = query->host<T>();
T * query_dst = reinterpret_cast<T*>(pack_q);
for (int i = 0; i < seq_len; i++) {
int out_index = i / eP;
int in_index = i % eP;
2024-05-11 19:17:02 +08:00
for (int j = 0; j < mHeadDim; j++) {
2024-07-22 19:51:53 +08:00
query_dst[out_index * mHeadDim * eP + j * eP + in_index] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale;
}
}
}
2024-05-11 19:17:02 +08:00
/**
 * Convert the packed Q·K^T scores of one head into a dense row-major float matrix.
 *
 * Source layout: [UP_DIV(kv_seq_len, unit), seq_len, unit] elements of T.
 * Destination layout: [seq_len, kv_seq_len] floats (each element converted from T).
 */
template <typename T>
static void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len, int unit) {
    const T* src = reinterpret_cast<const T*>(pack_qk_src);
    for (int i = 0; i < seq_len; i++) {
        float* dst_row = unpack_qk_dst + i * kv_seq_len;
        for (int j = 0; j < kv_seq_len; j++) {
            // column j lives in block j / unit, lane j % unit
            dst_row[j] = src[(j / unit) * seq_len * unit + i * unit + (j % unit)];
        }
    }
}
/**
 * Repack a dense row-major [seq_len, kv_seq_len] float matrix (the softmaxed scores) into the
 * packed left-hand layout [UP_DIV(seq_len, eP), kv_seq_len, eP] of T for the QK·V matmul.
 * Lanes of the last eP tile beyond seq_len are left untouched.
 */
template <typename T>
static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP) {
    T* dst = reinterpret_cast<T*>(pack_qk_dst);
    for (int i = 0; i < seq_len; i++) {
        const float* src_row = qk_src + i * kv_seq_len;
        // tile i / eP, lane i % eP; consecutive kv positions are eP apart
        T* dst_lane = dst + (i / eP) * kv_seq_len * eP + (i % eP);
        for (int j = 0; j < kv_seq_len; j++) {
            dst_lane[j * eP] = src_row[j];
        }
    }
}
/**
 * Scale raw Q·K^T scores in place and apply the attention mask.
 *
 * - seq_len == 1: the first kv_seq_len scores are only multiplied by mScale; mask_ptr is not read.
 * - float_mask:   score = score * mScale + mask, where the mask buffer is reinterpreted as T.
 * - otherwise:    int mask; a nonzero entry keeps score * mScale, zero replaces it with min_val.
 *
 * In the masked cases unpack_qk and the mask are both read as seq_len * kv_seq_len contiguous
 * elements, so the mask must be laid out as [seq_len, kv_seq_len].
 */
template <typename T>
static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, int * mask_ptr, bool float_mask) {
    if (seq_len == 1) {
        for (int i = 0; i < kv_seq_len; i++) {
            unpack_qk[i] = unpack_qk[i] * mScale;
        }
        return;
    }
    const int count = seq_len * kv_seq_len;
    if (float_mask) {
        // additive float mask stored with element type T
        const T* fpmask_ptr = reinterpret_cast<const T*>(mask_ptr);
        for (int i = 0; i < count; i++) {
            unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i];
        }
        return;
    }
    // int mask: 0 means "not attended"
    for (int i = 0; i < count; i++) {
        unpack_qk[i] = mask_ptr[i] ? unpack_qk[i] * mScale : min_val;
    }
}
/**
 * Row-wise softmax of a dense [seq_len, kv_seq_len] score matrix.
 * Each row of unpack_qk_addr is passed to MNNSoftmax together with the matching row of
 * softmax_qk_addr (presumably MNNSoftmax(dst, src, length) — confirm against its declaration).
 */
static void softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len) {
    for (int i = 0; i < seq_len; i++) {
        const int row_offset = i * kv_seq_len;
        MNNSoftmax(softmax_qk_addr + row_offset, unpack_qk_addr + row_offset, kv_seq_len);
    }
}
/**
 * Scatter one head's packed attention output into its slot of the output rows.
 *
 * Source layout: [UP_DIV(mHeadDim, unit), seq_len, unit] elements of T.
 * Destination: row i starts at unpack_qkv + i * mNumHead * mHeadDim elements; only the first
 * mHeadDim elements of each row are written (the caller offsets unpack_qkv to the head's column).
 * Elements are copied as T, so any same-size type (e.g. int16_t for 16-bit floats) works.
 */
template <typename T>
static void unpack_QKV(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) {
    const T* src_ptr = reinterpret_cast<const T*>(pack_qkv);
    T* dst_ptr       = reinterpret_cast<T*>(unpack_qkv);
    const int dst_row_stride = mNumHead * mHeadDim;
    for (int i = 0; i < seq_len; i++) {
        T* dst_row = dst_ptr + i * dst_row_stride;
        for (int j = 0; j < mHeadDim; j++) {
            // head-dim element j lives in block j / unit, lane j % unit
            dst_row[j] = src_ptr[(j / unit) * seq_len * unit + i * unit + (j % unit)];
        }
    }
}
/**
 * Record the core's packing parameters and head geometry, and reserve per-thread scratch space.
 *
 * Reads eP/lP/hP from MNNGetMatMulPackMode, the SIMD pack width (`unit`), the element size
 * (`bytes`) and the thread count. Query and key are indexed as [?, seq_len, num_heads, head_dim].
 * mPackQ / mPackQKV are acquired and then released from the DYNAMIC pool (the usual MNN resize
 * pattern; presumably the memory stays valid for this op's execute — confirm).
 *
 * @return NOT_SUPPORT when there is no KV cache manager (only created when kv_cache is enabled),
 *         INPUT_DATA_ERROR for query/key of rank < 4 or a head count not divisible by the KV head
 *         count, OUT_OF_MEMORY if scratch space cannot be reserved, NO_ERROR otherwise.
 */
ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto cpuBackend = static_cast<CPUBackend *>(backend());
    auto core = cpuBackend->functions();
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    mThreadNum = cpuBackend->threadNumber();
    unit = core->pack;
    bytes = core->bytes;
    if (nullptr == mKVCacheManager) {
        // Every step of this implementation reads key/value through the KV cache manager.
        return NOT_SUPPORT;
    }
    auto query = inputs[0];
    auto key = inputs[1];
    auto queryShape = query->shape();
    auto keyShape = key->shape();
    if (queryShape.size() < 4 || keyShape.size() < 4) {
        return INPUT_DATA_ERROR;
    }
    int seq_len = queryShape[1];
    mNumHead = queryShape[2];
    mHeadDim = queryShape[3];
    mKvNumHead = keyShape[2];
    // onExecute maps query head h to KV head h / (mNumHead / mKvNumHead); that stays in range
    // only when the query heads split evenly over the KV heads.
    if (mKvNumHead <= 0 || mNumHead % mKvNumHead != 0) {
        return INPUT_DATA_ERROR;
    }
    mKVCacheManager->onResize(mKvNumHead, mHeadDim);
    mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), mHeadDim, eP}));
    mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
    if (!backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC)) {
        return OUT_OF_MEMORY;
    }
    if (!backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC)) {
        return OUT_OF_MEMORY;
    }
    backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
    backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
    return NO_ERROR;
}
ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
2024-08-24 15:46:21 +08:00
auto core = static_cast<CPUBackend *>(backend())->functions();
2024-05-11 19:17:02 +08:00
auto query = inputs[0];
2024-07-22 19:51:53 +08:00
auto key = inputs[1];
2024-05-11 19:17:02 +08:00
auto value = inputs[2];
auto mask = inputs[3];
2024-07-22 19:51:53 +08:00
auto mask_shape = mask->shape();
2024-06-03 20:09:34 +08:00
bool float_mask = (mask->getType() == halide_type_of<float>());
2024-07-22 19:51:53 +08:00
int mask_seqlen = mask_shape[2];
int mask_kvlen = mask_shape[3];
int seq_len = query->shape()[1];
MNN_ASSERT(seq_len == mask_seqlen);
mIsPrefill = (seq_len > 1);
// isPrefill and mask is Square Matrix, is FirstPrefill
mIsFirstPrefill = mIsPrefill && (mask_kvlen == mask_seqlen);
2024-08-24 15:46:21 +08:00
int tileCount = UP_DIV(mNumHead, mThreadNum);
int group_size = mNumHead / mKvNumHead;
2024-07-04 11:53:45 +08:00
// reduce the value of 'query' to avoid fp16 overflow
2024-08-24 15:46:21 +08:00
float mScale = 1.0 / sqrt(mHeadDim);
2024-07-04 11:53:45 +08:00
float q_scale = 1.0;
if (bytes == 2) {
q_scale = FP16_QSCALE;
2024-07-22 19:51:53 +08:00
mScale /= q_scale;
2024-07-04 11:53:45 +08:00
}
2024-05-11 19:17:02 +08:00
2024-07-22 19:51:53 +08:00
if (mIsPrefill) {
if (mIsFirstPrefill) {
2024-08-24 15:46:21 +08:00
mKVCacheManager->onClear();
mKVCacheManager->onAlloc(seq_len);
2024-07-22 19:51:53 +08:00
} else {
2024-08-24 15:46:21 +08:00
mKVCacheManager->onRealloc(mKVCacheManager->kvLength() + seq_len);
2024-07-22 19:51:53 +08:00
}
} else { // Decode
2024-08-24 15:46:21 +08:00
mKVCacheManager->onRealloc(mKVCacheManager->kvLength() + 1);
2024-07-04 11:53:45 +08:00
}
2024-08-24 15:46:21 +08:00
// Add the new kv to the kvcache
mKVCacheManager->onPushBack(key, value);
int kv_seq_len = mKVCacheManager->kvLength();
int max_len = mKVCacheManager->maxLength();
bool quant_key = mKVCacheManager->config()->mQuantKey;
bool quant_value = mKVCacheManager->config()->mQuantValue;
2024-07-22 19:51:53 +08:00
// Temporary tensors for intermediate results
std::shared_ptr<Tensor> packQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit}));
std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
std::shared_ptr<Tensor> softmaxQK(Tensor::createDevice<int>({mThreadNum, seq_len, kv_seq_len}));
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), kv_seq_len, eP}));
2024-08-24 15:46:21 +08:00
std::shared_ptr<Tensor> dequantV(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP}));
2024-07-22 19:51:53 +08:00
backend()->onAcquireBuffer(packQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(softmaxQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC);
2024-08-24 15:46:21 +08:00
if (quant_value) {
2024-07-22 19:51:53 +08:00
backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC);
2024-08-24 15:46:21 +08:00
mKVCacheManager->onDequantValue(dequantV.get());
2024-07-22 19:51:53 +08:00
}
2024-05-11 19:17:02 +08:00
2024-07-22 19:51:53 +08:00
std::function<void(int)> mCompute = [=](int tId) {
2024-08-24 15:46:21 +08:00
auto pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * mHeadDim * eP * bytes;
2024-07-22 19:51:53 +08:00
auto pack_qk = packQK->host<char>() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes;
auto unpack_qk = unpackQK->host<float>() + tId * seq_len * kv_seq_len;
2024-08-24 15:46:21 +08:00
auto softmax_qk = softmaxQK->host<float>() + tId * seq_len * kv_seq_len;
2024-07-22 19:51:53 +08:00
auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * kv_seq_len * eP * bytes;
2024-08-24 15:46:21 +08:00
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes;
auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul;
auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain;
int head_index = tId * tileCount;
for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
int kv_h = h / group_size;
char * key_addr = mKVCacheManager->addrOfKey(kv_h);
char * scale_addr = quant_key ? mKVCacheManager->addrOfScale(kv_h) : nullptr;
char * zero_point_addr = quant_key ? mKVCacheManager->addrOfZeroPoint(kv_h) : nullptr;
char * value_addr = quant_value ? dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * kv_seq_len * hP * bytes : mKVCacheManager->addrOfValue(kv_h);
2024-05-11 19:17:02 +08:00
if (bytes == 2) {
2024-08-24 15:46:21 +08:00
pack_query<FLOAT16_T>(query, pack_q, mNumHead, mHeadDim, eP, seq_len, h, q_scale);
2024-05-11 19:17:02 +08:00
} else {
2024-08-24 15:46:21 +08:00
pack_query<float>(query, pack_q, mNumHead, mHeadDim, eP, seq_len, h, q_scale);
2024-05-11 19:17:02 +08:00
}
// query @ key
int loop_e = seq_len / eP;
int remain = seq_len % eP;
2024-08-24 15:46:21 +08:00
size_t shapeParameters[7] = {(size_t)eP * bytes, (size_t)mHeadDim, (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0};
2024-05-11 19:17:02 +08:00
for (int i = 0 ; i < loop_e; i++) {
2024-08-24 15:46:21 +08:00
QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mHeadDim * eP) * bytes), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
2024-05-11 19:17:02 +08:00
}
2024-08-24 15:46:21 +08:00
QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mHeadDim * eP) * bytes), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
// qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP]
2024-07-22 19:51:53 +08:00
if(bytes == 2) {
unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len, unit);
mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask->host<int>(), float_mask);
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP);
2024-05-11 19:17:02 +08:00
} else {
2024-07-22 19:51:53 +08:00
unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len, unit);
mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask->host<int>(), float_mask);
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP);
}
2024-05-11 19:17:02 +08:00
// qk @ v
2024-08-24 15:46:21 +08:00
shapeParameters[1] = kv_seq_len;
shapeParameters[2] = mHeadDim;
shapeParameters[5] = quant_value ? 0 : (max_len - kv_seq_len) * hP * bytes;
2024-05-11 19:17:02 +08:00
for (int i = 0 ; i < loop_e; i++) {
2024-08-24 15:46:21 +08:00
core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + (i * kv_seq_len * eP) * bytes), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
2024-05-11 19:17:02 +08:00
}
2024-08-24 15:46:21 +08:00
core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + (loop_e * kv_seq_len * eP) * bytes), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
2024-07-22 19:51:53 +08:00
// unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
2024-08-24 15:46:21 +08:00
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
2024-05-11 19:17:02 +08:00
if (bytes == 2) {
2024-08-24 15:46:21 +08:00
unpack_QKV<int16_t>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
2024-05-11 19:17:02 +08:00
} else {
2024-08-24 15:46:21 +08:00
unpack_QKV<float>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
2024-05-11 19:17:02 +08:00
}
}
};
MNN_CONCURRENCY_BEGIN(tId, mThreadNum) {
2024-07-22 19:51:53 +08:00
mCompute((int)tId);
2024-05-11 19:17:02 +08:00
}
MNN_CONCURRENCY_END();
2024-07-22 19:51:53 +08:00
backend()->onReleaseBuffer(packQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(softmaxQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC);
2024-08-24 15:46:21 +08:00
if (quant_value){
2024-07-22 19:51:53 +08:00
backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC);
2024-07-04 11:53:45 +08:00
}
2024-05-11 19:17:02 +08:00
return NO_ERROR;
}
/**
 * Clone this execution onto backend `bn`.
 *
 * With dst == nullptr nothing is created and true is returned. Otherwise a new CPUAttention
 * with the same kv_cache flag is built, and its KV cache manager is replaced by this instance's
 * manager, so both executions use the same cache object. The manager the new instance's
 * constructor created is discarded.
 */
bool CPUAttention::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (nullptr == dst) {
        return true;
    }
    auto tmp = new CPUAttention(bn, mKVCache);
    tmp->mKVCacheManager = mKVCacheManager;
    *dst = tmp;
    return true;
}
/**
 * Construct the CPU attention execution.
 *
 * When kv_cache is false no KV cache manager is created. Otherwise the manager is configured
 * from the runtime hint:
 *  - bit 0 of kvcacheQuantOption enables key quantization;
 *  - bit 1 enables value quantization;
 *  - the cache directory and size limit come from the hint;
 *  - the expand chunk is fixed at 64.
 */
CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend), mKVCache(kv_cache) {
    if (!mKVCache) {
        return;
    }
    const auto& hint = static_cast<CPUBackend *>(backend)->getRuntime()->hint();
    const int quantOption = hint.kvcacheQuantOption;
    MNN::KVCacheManager::KVCacheConfig kvconfig;
    kvconfig.mQuantKey         = (quantOption & 1);
    kvconfig.mQuantValue       = ((quantOption >> 1) & 1);
    kvconfig.mKVCacheDir       = hint.kvcacheDirPath;
    kvconfig.mKVCacheSizeLimit = hint.kvcacheSizeLimit;
    kvconfig.mExpandChunk      = 64;
    mKVCacheManager.reset(new KVCacheManager(backend, kvconfig));
}
// Nothing to release by hand; owned members clean up after themselves.
CPUAttention::~CPUAttention() = default;
// Creator registered for OpType_Attention on the CPU backend.
class CPUAttentionCreator : public CPUBackend::Creator {
public:
    /**
     * Build a CPUAttention, taking kv_cache from the op's AttentionParam.
     * If the op carries no AttentionParam, an error is logged and nullptr is returned instead of
     * dereferencing a null parameter table (presumably reported by the caller as unsupported).
     */
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        auto param = op->main_as_AttentionParam();
        if (nullptr == param) {
            MNN_ERROR("CPUAttention: Attention op has no AttentionParam\n");
            return nullptr;
        }
        return new CPUAttention(backend, param->kv_cache());
    }
};
REGISTER_CPU_OP_CREATOR_TRANSFORMER(CPUAttentionCreator, OpType_Attention);
} // namespace MNN
2024-08-24 15:46:21 +08:00
#endif // MNN_SUPPORT_TRANSFORMER_FUSE