2024-08-24 15:46:21 +08:00
//
//  KVCacheManager.cpp
//  MNN
//
//  Created by MNN on 2024/08/05.
//  Copyright © 2018, Alibaba Group Holding Limited
//
# ifdef MNN_SUPPORT_TRANSFORMER_FUSE
# include "KVCacheManager.hpp"
# include "core/Concurrency.h"
namespace MNN {
2024-09-12 12:57:57 +08:00
// Translate an address to a hex number string
2024-08-24 15:46:21 +08:00
// Render a pointer as a fixed-width, zero-padded, 16-digit uppercase hex string.
static inline std::string addrToHex(void* addr) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    uint64_t bits = (uint64_t)addr;
    std::string hex(16, '0');
    // Fill from the least-significant nibble backwards so the result is big-endian text.
    for (int pos = 15; pos >= 0; pos--) {
        hex[pos] = kHexDigits[bits & 0x0f];
        bits >>= 4;
    }
    return hex;
}
void KVCacheManager : : createKVCacheFile ( ) {
// Each layer has its own kvcache, so we have to create a key file and a value file for each layer and the file name must be unique
// Here we use the address of the mResource as the file name because the addresses of mResource in different layers are guaranteed to be different
std : : string fileName = addrToHex ( this ) ;
std : : string pathk = MNNFilePathConcat ( mConfig . mKVCacheDir , fileName ) + " .k " ;
std : : string pathv = MNNFilePathConcat ( mConfig . mKVCacheDir , fileName ) + " .v " ;
mKeyCacheFD = MNNCreateFile ( pathk . c_str ( ) ) ;
mValueCacheFD = MNNCreateFile ( pathv . c_str ( ) ) ;
if ( mKeyCacheFD = = INVALID_FILE ) {
MNN_PRINT ( " Failed to create the file: %s \n " , pathk . c_str ( ) ) ;
}
if ( mValueCacheFD = = INVALID_FILE ) {
MNN_PRINT ( " Failed to create the file: %s \n " , pathv . c_str ( ) ) ;
}
}
void KVCacheManager : : removeKVCacheFile ( ) {
std : : string fileName = addrToHex ( this ) ;
std : : string pathk = MNNFilePathConcat ( mConfig . mKVCacheDir , fileName ) + " .k " ;
std : : string pathv = MNNFilePathConcat ( mConfig . mKVCacheDir , fileName ) + " .v " ;
if ( mKeyCacheFD ! = INVALID_FILE ) {
MNNCloseFile ( mKeyCacheFD ) ;
mKeyCacheFD = INVALID_FILE ;
if ( MNNRemoveFile ( pathk . c_str ( ) ) ! = MNN : : NO_ERROR ) {
MNN_PRINT ( " Failed to remove the file: %s \n " , pathk . c_str ( ) ) ;
}
}
if ( mValueCacheFD ! = INVALID_FILE ) {
MNNCloseFile ( mValueCacheFD ) ;
mValueCacheFD = INVALID_FILE ;
if ( MNNRemoveFile ( pathv . c_str ( ) ) ! = MNN : : NO_ERROR ) {
MNN_PRINT ( " Failed to remove the file: %s \n " , pathv . c_str ( ) ) ;
}
}
}
void KVCacheManager : : resetKVCacheFileSize ( size_t keySize , size_t valueSize ) {
if ( MNNSetFileSize ( mKeyCacheFD , keySize ) ! = MNN : : NO_ERROR | | MNNSetFileSize ( mValueCacheFD , valueSize ) ! = MNN : : NO_ERROR ) {
MNN_PRINT ( " Failed to resize the kvcache files! \n " ) ;
}
}
/*
* * @ brief Memory - map the kvcache file
* * @ hint After memory - mapping , we can access the kvcache files with pointers , just like accessing memory buffer
* * But the data actually resides in disk .
* * The OS will set some kernel page cache and manage the data swaping , which we do not need to care .
*/
void KVCacheManager : : mmapKVCache ( size_t keySize , size_t valueSize )
{
if ( mMapKeyAddr = = nullptr ) {
mMapKeyAddr = ( char * ) MNNMmapFile ( mKeyCacheFD , keySize ) ;
if ( mMapKeyAddr = = nullptr ) {
MNN_PRINT ( " Failed to memory-map the kvcache! \n " ) ;
}
}
if ( mMapValueAddr = = nullptr ) {
mMapValueAddr = ( char * ) MNNMmapFile ( mValueCacheFD , valueSize ) ;
if ( mMapValueAddr = = nullptr ) {
MNN_PRINT ( " Failed to memory-map the kvcache! \n " ) ;
}
}
}
void KVCacheManager : : unmapKVCache ( size_t keySize , size_t valueSize )
{
if ( mMapKeyAddr ! = nullptr ) {
MNNUnmapFile ( mMapKeyAddr , keySize ) ;
mMapKeyAddr = nullptr ;
}
if ( mMapValueAddr ! = nullptr ) {
MNNUnmapFile ( mMapValueAddr , valueSize ) ;
mMapValueAddr = nullptr ;
}
}
/*
**  @brief  Expand the size of kvcache and copy it from the old tensor in memory to the new tensor in memory
**          Finally reset the pointer to the new tensor
*/
void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
    /*=================================== Key ===================================*/
    if (mConfig.mUseInt8Kernel) {
        // int8-GEMM packed key: [kvHead, maxLen/hP8, headDim/lP8, hP8*lP8], 1 byte/elem.
        auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8});
        mBackend->onAcquireBuffer(new_key, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
            // Per head: copy the old packed block into the head's new (larger) slot.
            memcpy(
                new_key->host<char>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
                mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
                UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
            );
        }
        mPastKey.reset(new_key);
    }
    else if (mConfig.mQuantKey) {
        // int8-quantized key: [kvHead, maxLen/hP, headDim/lP, hP, lP], 1 byte/elem.
        auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP});
        mBackend->onAcquireBuffer(new_key, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
            memcpy(
                new_key->host<char>() + h * new_key->stride(0),
                mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP),
                ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP)
            );
        }
        mPastKey.reset(new_key);
    }
    else {
        // Float key: [kvHead, maxLen/hP, headDim/lP, hP, lP], mBytes per element.
        auto new_key = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP});
        mBackend->onAcquireBuffer(new_key, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
            memcpy(
                new_key->host<char>() + h * new_key->stride(0) * mBytes,
                mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes,
                ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes
            );
            // Zero the newly grown tail of this head's slot so not-yet-written
            // positions read back as 0.
            if ((new_key->stride(0) - mPastKey->stride(0)) > 0) {
                memset(new_key->host<char>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes);
            }
        }
        mPastKey.reset(new_key);
    }
    /*=================================== Value ===================================*/
    if (mConfig.mQuantValue) {
        // fp8-quantized value: [kvHead, headDim/hP, maxLen/lP, hP, lP], 1 byte/elem.
        // Sequence length is the packed inner axis here, so copy per (head, dim-block) row.
        auto new_value = Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP});
        mBackend->onAcquireBuffer(new_value, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
                memcpy(
                    new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
                    mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
                    ROUND_UP(oldMaxLength, lP) * hP
                );
            }
        }
        mPastValue.reset(new_value);
    }
    else {
        // Float value: [kvHead, headDim/hP, maxLen/lP, hP, lP], mBytes per element.
        auto new_value = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP});
        mBackend->onAcquireBuffer(new_value, Backend::STATIC);
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
                memcpy(
                    new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
                    mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
                    ROUND_UP(oldMaxLength, lP) * hP * mBytes
                );
                // Zero the grown tail of each row (positions between old and new max length).
                if ((new_value->stride(1) - mPastValue->stride(1)) > 0) {
                    memset(new_value->host<char>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes);
                }
            }
        }
        mPastValue.reset(new_value);
    }
}
/*
* * @ brief Move the kvcache from memory to the memory - mapped kvcache files in disk
* * Then release the memory buffer of old kvcache
*/
void KVCacheManager : : moveKVCacheFromMemToDisk ( int oldMaxLength ) {
/*=================================== Key ===================================*/
2024-09-12 12:57:57 +08:00
if ( mConfig . mUseInt8Kernel ) {
for ( int h = 0 ; h < mKvNumHead ; h + + ) {
memcpy (
mMapKeyAddr + h * UP_DIV ( mMaxLength , hP8 ) * UP_DIV ( mHeadDim , lP8 ) * hP8 * lP8 ,
mPastKey - > host < char > ( ) + h * UP_DIV ( oldMaxLength , hP8 ) * UP_DIV ( mHeadDim , lP8 ) * hP8 * lP8 ,
UP_DIV ( oldMaxLength , hP8 ) * UP_DIV ( mHeadDim , lP8 ) * hP8 * lP8
) ;
}
mBackend - > onReleaseBuffer ( mPastKey . get ( ) , Backend : : STATIC ) ;
mPastKey . reset ( ) ;
}
2024-08-24 15:46:21 +08:00
if ( mConfig . mQuantKey ) {
for ( int h = 0 ; h < mKvNumHead ; h + + ) {
2024-09-12 12:57:57 +08:00
memcpy (
2025-07-23 14:10:58 +08:00
mMapKeyAddr + h * UP_DIV ( mMaxLength , hP ) * ROUND_UP ( mHeadDim , lP ) * hP ,
mPastKey - > host < char > ( ) + h * UP_DIV ( oldMaxLength , hP ) * ROUND_UP ( mHeadDim , lP ) * hP ,
UP_DIV ( oldMaxLength , hP ) * ROUND_UP ( mHeadDim , lP ) * hP
2024-09-12 12:57:57 +08:00
) ;
2024-08-24 15:46:21 +08:00
}
mBackend - > onReleaseBuffer ( mPastKey . get ( ) , Backend : : STATIC ) ;
mPastKey . reset ( ) ;
}
else {
2025-07-23 14:10:58 +08:00
if ( mHeadDim % lP ) {
memset ( mMapKeyAddr , 0 , mKvNumHead * ROUND_UP ( mMaxLength , hP ) * ROUND_UP ( mHeadDim , lP ) * mBytes ) ;
}
2024-08-24 15:46:21 +08:00
for ( int h = 0 ; h < mKvNumHead ; h + + ) {
2024-09-12 12:57:57 +08:00
memcpy (
2025-07-23 14:10:58 +08:00
mMapKeyAddr + h * UP_DIV ( mMaxLength , hP ) * ROUND_UP ( mHeadDim , lP ) * hP * mBytes ,
mPastKey - > host < char > ( ) + h * UP_DIV ( oldMaxLength , hP ) * ROUND_UP ( mHeadDim , lP ) * hP * mBytes ,
UP_DIV ( oldMaxLength , hP ) * ROUND_UP ( mHeadDim , lP ) * hP * mBytes
2024-09-12 12:57:57 +08:00
) ;
2024-08-24 15:46:21 +08:00
}
mBackend - > onReleaseBuffer ( mPastKey . get ( ) , Backend : : STATIC ) ;
mPastKey . reset ( ) ;
}
/*=================================== Value ===================================*/
if ( mConfig . mQuantValue ) {
for ( int h = 0 ; h < mKvNumHead ; h + + ) {
for ( int i = 0 ; i < UP_DIV ( mHeadDim , hP ) ; i + + ) {
2024-09-12 12:57:57 +08:00
memcpy (
2025-07-23 14:10:58 +08:00
mMapValueAddr + ( h * UP_DIV ( mHeadDim , hP ) + i ) * ROUND_UP ( mMaxLength , lP ) * hP ,
mPastValue - > host < char > ( ) + ( h * UP_DIV ( mHeadDim , hP ) + i ) * ROUND_UP ( oldMaxLength , lP ) * hP ,
ROUND_UP ( oldMaxLength , lP ) * hP
2024-09-12 12:57:57 +08:00
) ;
2024-08-24 15:46:21 +08:00
}
}
mBackend - > onReleaseBuffer ( mPastValue . get ( ) , Backend : : STATIC ) ;
mPastValue . reset ( ) ;
}
else {
2025-07-23 14:10:58 +08:00
if ( lP > 1 ) {
memset ( mMapValueAddr , 0 , mKvNumHead * ROUND_UP ( mHeadDim , hP ) * ROUND_UP ( mMaxLength , lP ) * mBytes ) ;
}
2024-08-24 15:46:21 +08:00
for ( int h = 0 ; h < mKvNumHead ; h + + ) {
for ( int i = 0 ; i < UP_DIV ( mHeadDim , hP ) ; i + + ) {
2024-09-12 12:57:57 +08:00
memcpy (
2025-07-23 14:10:58 +08:00
mMapValueAddr + ( h * UP_DIV ( mHeadDim , hP ) + i ) * ROUND_UP ( mMaxLength , lP ) * hP * mBytes ,
mPastValue - > host < char > ( ) + ( h * UP_DIV ( mHeadDim , hP ) + i ) * ROUND_UP ( oldMaxLength , lP ) * hP * mBytes ,
ROUND_UP ( oldMaxLength , lP ) * hP * mBytes
2024-09-12 12:57:57 +08:00
) ;
2024-08-24 15:46:21 +08:00
}
}
mBackend - > onReleaseBuffer ( mPastValue . get ( ) , Backend : : STATIC ) ;
mPastValue . reset ( ) ;
}
}
/*
**  @brief  Expand the size of kvcache files in disk
*/
void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize) {
    // Step 1: Copy the old kvcache from files to temporary buffers in memory
    std::shared_ptr<Tensor> old_key, old_value;
    if (mConfig.mUseInt8Kernel) {
        old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
    } else if (mConfig.mQuantKey) {
        old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
    } else {
        old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
    }
    if (mConfig.mQuantValue) {
        old_value.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP}));
    } else {
        old_value.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP}));
    }
    mBackend->onAcquireBuffer(old_key.get(), Backend::STATIC);
    mBackend->onAcquireBuffer(old_value.get(), Backend::STATIC);
    // Pre-zero the staging buffers where the packed layouts contain padding lanes.
    if (mHeadDim % lP) {
        memset(old_key->host<uint8_t>(), 0, old_key->length(0) * old_key->stride(0) * mBytes);
    }
    if (lP > 1) {
        // can't be mMaxLength % lP, since mMaxLength may be larger than seq_len for prefilling;
        // we must ensure the (mMaxLength - seq_len) region reads as 0 (computing L is seq_len).
        memset(old_value->host<uint8_t>(), 0, old_value->length(0) * old_value->stride(0) * mBytes);
    }
    mmapKVCache(oldKeySize, oldValueSize);
    memcpy(old_key->host<char>(), mMapKeyAddr, oldKeySize);
    memcpy(old_value->host<char>(), mMapValueAddr, oldValueSize);
    // Step 2: Resize the kvcache files and remap them
    unmapKVCache(oldKeySize, oldValueSize);
    resetKVCacheFileSize(keySize, valueSize);
    mmapKVCache(keySize, valueSize);
    // Step 3: Move the kvcache from temporary buffers in memory to disk,
    // re-spacing each head's block at the new (larger) per-head stride.
    if (mConfig.mUseInt8Kernel) {
        for (int h = 0; h < mKvNumHead; h++) {
            memcpy(
                mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
                old_key->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
                UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
            );
        }
    } else if (mConfig.mQuantKey) {
        for (int h = 0; h < mKvNumHead; h++) {
            memcpy(
                mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
                old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
                UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
            );
        }
    } else {
        for (int h = 0; h < mKvNumHead; h++) {
            memcpy(
                mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
                old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
                UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
            );
        }
    }
    if (mConfig.mQuantValue) {
        // Value is packed along the sequence axis: copy per (head, dim-block) row.
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
                memcpy(
                    mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
                    old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
                    ROUND_UP(oldMaxLength, lP) * hP
                );
            }
        }
    } else {
        for (int h = 0; h < mKvNumHead; h++) {
            for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
                memcpy(
                    mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
                    old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
                    ROUND_UP(oldMaxLength, lP) * hP * mBytes
                );
            }
        }
    }
    // Step 4: Release the temporary buffers
    mBackend->onReleaseBuffer(old_key.get(), Backend::STATIC);
    mBackend->onReleaseBuffer(old_value.get(), Backend::STATIC);
}
void KVCacheManager : : onResize ( int kv_num_head , int head_dim ) {
mKvNumHead = kv_num_head ;
mHeadDim = head_dim ;
auto core = static_cast < CPUBackend * > ( mBackend ) - > functions ( ) ;
core - > MNNGetMatMulPackMode ( & eP , & lP , & hP ) ;
mBytes = core - > bytes ;
mThreadNum = static_cast < CPUBackend * > ( mBackend ) - > threadNumber ( ) ;
if ( mThreadNum > mKvNumHead ) {
mThreadNum = mKvNumHead ;
}
2024-09-12 12:57:57 +08:00
if ( mConfig . mUseInt8Kernel ) {
static_cast < CPUBackend * > ( mBackend ) - > int8Functions ( ) - > MNNGetGemmUnit ( & hP8 , & lP8 , & eP8 ) ;
}
2024-08-24 15:46:21 +08:00
}
// Allocate the kv cache for a sequence of kv_seq_len tokens (plus an expansion
// margin), either as memory-mapped files on disk (when the configured size limit
// is exceeded) or as in-memory tensors.
void KVCacheManager::onAlloc(int kv_seq_len) {
    mMaxLength = kv_seq_len + mConfig.mExpandChunk;
    // Total byte sizes of the packed key / value buffers (layouts match the
    // Tensor shapes created below).
    size_t keySize = 0, valueSize = 0;
    if (mConfig.mUseInt8Kernel) {
        keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
    } else if (mConfig.mQuantKey) {
        keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP);
    } else {
        keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes;
    }
    valueSize = (size_t)mKvNumHead * ROUND_UP(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * (mConfig.mQuantValue ? 1 : mBytes);
    /*============== Put the kvcache in disk ===========*/
    if (mConfig.mKVCacheSizeLimit != -1 && keySize + valueSize > mConfig.mKVCacheSizeLimit) {
        createKVCacheFile();
        resetKVCacheFileSize(keySize, valueSize);
        mmapKVCache(keySize, valueSize);
        mKVCacheInDisk = true;
    }
    /*============== Put the kvcache in memory ===========*/
    else {
        if (mConfig.mUseInt8Kernel) {
            mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
        } else if (mConfig.mQuantKey) {
            mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
        } else {
            mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
        }
        if (mConfig.mQuantValue) {
            mPastValue.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}));
        } else {
            mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}));
        }
        mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
        mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
        // Zero padding lanes so the GEMM kernels read 0 from unfilled positions.
        if (mHeadDim % lP) {
            memset(mPastKey->host<int8_t>(), 0, mPastKey->length(0) * mPastKey->stride(0) * mBytes);
        }
        if (lP > 1) { // can't be mMaxLength % lP, since mMaxLength may be larger than seq_len for prefilling; the (mMaxLength - seq_len) region must read as 0.
            memset(mPastValue->host<int8_t>(), 0, mPastValue->length(0) * mPastValue->stride(0) * mBytes);
        }
    }
    // scale, zero point and sum of key for quantization — these always stay in memory.
    if (mConfig.mUseInt8Kernel) {
        mKeyScale.reset(Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}));
        mKeyZeroPoint.reset(Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}));
        mKeySum.reset(Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}));
        mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC);
        mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC);
        mBackend->onAcquireBuffer(mKeySum.get(), Backend::STATIC);
    } else if (mConfig.mQuantKey) {
        mKeyScale.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), hP}));
        mKeyZeroPoint.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), hP}));
        mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC);
        mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC);
    }
}
2024-12-31 15:34:08 +08:00
// Reconcile the cache with the new sequence described by meta: grow the backing
// storage if the new total length exceeds mMaxLength, then apply the remove /
// reserve (compaction) directives.
void KVCacheManager::onRealloc(const KVMeta* meta) {
    auto kv_seq_len = meta->previous + meta->add - meta->remove + meta->computeReverseSize();
    if (kv_seq_len > mMaxLength) {
        // Realloc: compute old and new byte sizes (same formulas as onAlloc).
        int oldMaxLength = mMaxLength;
        mMaxLength = kv_seq_len + mConfig.mExpandChunk;
        size_t oldKeySize, oldValueSize, keySize, valueSize;
        if (mConfig.mUseInt8Kernel) {
            oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
        } else if (mConfig.mQuantKey) {
            oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
        } else {
            oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
        }
        oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(oldMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
        valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
        /*==== No limit for kvcache ====*/
        if (mConfig.mKVCacheSizeLimit == -1) {
            expandKVCacheInMem(oldMaxLength);
        }
        /*==== Last time the kvcache is in memory, now it should be in memory too ====*/
        else if (keySize + valueSize <= mConfig.mKVCacheSizeLimit) {
            expandKVCacheInMem(oldMaxLength);
        }
        /*==== Last time the kvcache is in memory, but now it should be moved to disk ====*/
        else if (oldKeySize + oldValueSize <= mConfig.mKVCacheSizeLimit) {
            createKVCacheFile();
            resetKVCacheFileSize(keySize, valueSize);
            mmapKVCache(keySize, valueSize);
            moveKVCacheFromMemToDisk(oldMaxLength);
            mKVCacheInDisk = true;
        }
        /*==== Last time the kvcache is in disk, now it should be in disk too ====*/
        else {
            expandKVCacheInDisk(oldMaxLength, oldKeySize, oldValueSize, keySize, valueSize);
        }
        /* No matter where the kvcache is, the scales and zero points always stay in memory, since their size is very small */
        if (mConfig.mUseInt8Kernel) {
            auto new_scale = Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8});
            auto new_zeroPoint = Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8});
            auto new_sum = Tensor::createDevice<int32_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8});
            mBackend->onAcquireBuffer(new_scale, Backend::STATIC);
            mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
            mBackend->onAcquireBuffer(new_sum, Backend::STATIC);
            // Per-head copy of the old quantization metadata (4 bytes per entry).
            for (int h = 0; h < mKvNumHead; h++) {
                memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
                memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
                memcpy(new_sum->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
            }
            mKeyScale.reset(new_scale);
            mKeyZeroPoint.reset(new_zeroPoint);
            mKeySum.reset(new_sum);
        } else if (mConfig.mQuantKey) {
            auto new_scale = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP});
            auto new_zeroPoint = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP});
            mBackend->onAcquireBuffer(new_scale, Backend::STATIC);
            mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
            for (int h = 0; h < mKvNumHead; h++) {
                memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
                memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
            }
            mKeyScale.reset(new_scale);
            mKeyZeroPoint.reset(new_zeroPoint);
        }
    }
    // Remove: drop the trailing meta->remove positions.
    auto start = mPastLength - meta->remove;
    if (0 == meta->n_reserve) {
        mPastLength = start;
        return;
    }
    // Compact the reserved [begin, begin+size) ranges toward the front.
    // Don't support non-aligned reserve: moves happen in whole hP-sized units.
    auto align = hP;
    auto dstStart = start;
    auto lastValidSrcEnd = start;
    for (int n = 0; n < meta->n_reserve; ++n) {
        auto lastEndAlign = UP_DIV(lastValidSrcEnd, align) * align;
        auto begin = meta->reserve[2 * n];
        auto size = meta->reserve[2 * n + 1];
        auto startAlign = ((begin + start) / align) * align;
        if (startAlign <= lastEndAlign) {
            // Fully reserved already: the range overlaps the previously kept data,
            // so only advance the bookkeeping, no copy needed.
            dstStart = dstStart + size;
            lastValidSrcEnd = begin + start + size;
            continue;
        }
        auto end = begin + start + size;
        auto endAlign = UP_DIV(end, align) * align;
        auto sizeUnit = (endAlign - startAlign) / align;
        auto dstStartAlign = UP_DIV(dstStart, align) * align;
        //TODO: Support Quant
        // mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
        // Move K: copy sizeUnit aligned row-blocks per head toward the front.
        auto keyStride = UP_DIV(mMaxLength, align) * align * ROUND_UP(mHeadDim, lP);
        auto dstKAddr = keyAddr() + dstStartAlign * ROUND_UP(mHeadDim, lP) * mBytes;
        auto srcKAddr = keyAddr() + startAlign * ROUND_UP(mHeadDim, lP) * mBytes;
        for (int i = 0; i < mKvNumHead; ++i) {
            auto dst = dstKAddr + i * keyStride * mBytes;
            auto src = srcKAddr + i * keyStride * mBytes;
            for (int j = 0; j < sizeUnit; ++j) {
                ::memcpy(dst + j * align * ROUND_UP(mHeadDim, lP) * mBytes, src + j * align * ROUND_UP(mHeadDim, lP) * mBytes, align * ROUND_UP(mHeadDim, lP) * mBytes);
            }
        }
        // mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP}));
        // Move V: same compaction along the sequence axis of the value layout.
        auto dstVAddr = valudAddr() + dstStartAlign * align * mBytes;
        auto srcVAddr = valudAddr() + startAlign * align * mBytes;
        auto number = mKvNumHead * UP_DIV(mHeadDim, align);
        for (int i = 0; i < number; ++i) {
            auto dst = dstVAddr + i * ROUND_UP(mMaxLength, lP) * align * mBytes;
            auto src = srcVAddr + i * ROUND_UP(mMaxLength, lP) * align * mBytes;
            for (int j = 0; j < sizeUnit; ++j) {
                ::memcpy(dst + j * align * align * mBytes, src + j * align * align * mBytes, align * align * mBytes);
            }
        }
        dstStart = dstStart + size;
        lastValidSrcEnd = begin + start + size;
    }
    mPastLength = dstStart;
}
// Release every cache resource: unmap and delete the disk files if in use,
// drop all in-memory tensors, and reset the length counters.
void KVCacheManager::onClear() {
    if (mKVCacheInDisk) {
        // Recompute the mapped byte sizes (same formulas as onAlloc) so the
        // unmap covers exactly the mapped range.
        size_t keySize = 0, valueSize = 0;
        if (mConfig.mUseInt8Kernel) {
            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
        } else if (mConfig.mQuantKey) {
            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
        } else {
            keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
        }
        valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
        unmapKVCache(keySize, valueSize);
        removeKVCacheFile();
        mKVCacheInDisk = false;
    }
    // Releasing the shared_ptrs frees the in-memory tensors (no-op if already empty).
    mPastKey.reset();
    mPastValue.reset();
    mKeyScale.reset();
    mKeyZeroPoint.reset();
    mKeySum.reset();
    mMaxLength = mPastLength = 0;
}
// Append seq_len new key rows of head kv_h into the packed key cache, starting
// at position mPastLength, quantizing per the configured mode.
template <typename T>
void KVCacheManager::pack_key(const Tensor* key, int seq_len, int kv_h) {
    if (mConfig.mUseInt8Kernel) { // [maxlen/hP8, headdim/lP8, hP8, lP8]
        int8_t* key_dst = reinterpret_cast<int8_t*>(addrOfKey(kv_h));
        float* scale_dst = reinterpret_cast<float*>(addrOfScale(kv_h));
        float* zeroPoint_dst = reinterpret_cast<float*>(addrOfZeroPoint(kv_h));
        float* sum_dst = reinterpret_cast<float*>(addrOfKeySum(kv_h));
        for (int s = 0; s < seq_len; s++) {
            // Source row: keys laid out as [seq, kvHead, headDim].
            T* key_src = key->host<T>() + s * mKvNumHead * mHeadDim + kv_h * mHeadDim;
            // Per-row asymmetric int8 quantization: track min/max/sum over the head dim.
            float minKey = key_src[0];
            float maxKey = key_src[0];
            float sumKey = key_src[0];
            for (int d = 1; d < mHeadDim; d++) {
                minKey = ALIMIN(minKey, key_src[d]);
                maxKey = ALIMAX(maxKey, key_src[d]);
                sumKey += key_src[d];
            }
            int out_index = (mPastLength + s) / hP8;
            int in_index = (mPastLength + s) % hP8;
            // NOTE(review): if maxKey == minKey these divide by zero — presumably rows
            // are never constant; confirm with callers.
            scale_dst[out_index * hP8 + in_index] = (maxKey - minKey) / 255.0f;
            zeroPoint_dst[out_index * hP8 + in_index] = -255.0f * minKey / (maxKey - minKey) - 128.0;
            sum_dst[out_index * hP8 + in_index] = sumKey;
            // Scatter the quantized row into the [hP8, lP8] tiled layout.
            for (int d = 0; d < mHeadDim; d++) {
                int i = d / lP8;
                int j = d % lP8;
                key_dst[out_index * UP_DIV(mHeadDim, lP8) * hP8 * lP8 + i * hP8 * lP8 + in_index * lP8 + j] = roundf((key_src[d] - minKey) / (maxKey - minKey) * 255.0f - 128.0f);
            }
        }
    }
    else if (mConfig.mQuantKey) { // [maxlen/hP, headdim, hP]
        int8_t* key_dst = reinterpret_cast<int8_t*>(addrOfKey(kv_h));
        T* scale_dst = reinterpret_cast<T*>(addrOfScale(kv_h));
        T* zeroPoint_dst = reinterpret_cast<T*>(addrOfZeroPoint(kv_h));
        for (int i = 0; i < seq_len; i++) {
            T* key_src = key->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
            int out_index = (mPastLength + i) / hP;
            int in_index = (mPastLength + i) % hP;
            T minKey, maxKey;
            static_cast<CPUBackend*>(mBackend)->functions()->MNNCountMaxMinValue((float*)key_src, (float*)&minKey, (float*)&maxKey, mHeadDim);
            // Asymmetric int8: scale maps [minKey, maxKey] onto [-128, 127].
            scale_dst[out_index * hP + in_index] = (maxKey - minKey) / 255.0f;
            zeroPoint_dst[out_index * hP + in_index] = 128.0f * (maxKey - minKey) / 255.0f + minKey;
            for (int j = 0; j < mHeadDim; j++) {
                key_dst[out_index * mHeadDim * hP + j * hP + in_index] = roundf((key_src[j] - minKey) / (maxKey - minKey) * 255 - 128);
            }
        }
    }
    else { // target: [maxlen/hP, headdim/lP, hP, lP]
        T* key_dst = reinterpret_cast<T*>(addrOfKey(kv_h));
        auto stride0 = ROUND_UP(mHeadDim, lP) * hP;  // bytes-in-T per hP-block of rows
        auto stride1 = hP * lP;                      // one (hP, lP) tile
        for (int i = 0; i < seq_len; i++) {
            T* key_src = key->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
            int out_index = (mPastLength + i) / hP;
            int in_index = (mPastLength + i) % hP;
            // Copy the row unquantized into the tiled layout.
            for (int j = 0; j < mHeadDim; j++) {
                key_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = key_src[j];
            }
        }
    }
}
template < typename T >
2024-09-12 12:57:57 +08:00
void KVCacheManager : : pack_value ( const Tensor * value , int seq_len , int kv_h ) { // [headdim/hP, maxlen, hP]
if ( mConfig . mQuantValue ) {
fp8_t * value_dst = reinterpret_cast < fp8_t * > ( addrOfValue ( kv_h ) ) ;
2024-08-24 15:46:21 +08:00
uint8_t * buf = ( uint8_t * ) MNNMemoryAllocAlign ( mHeadDim , MNN_MEMORY_ALIGN_DEFAULT ) ;
for ( int i = 0 ; i < seq_len ; i + + ) {
T * value_src = value - > host < T > ( ) + i * mKvNumHead * mHeadDim + kv_h * mHeadDim ;
if ( sizeof ( T ) = = 2 ) {
2024-09-12 12:57:57 +08:00
static_cast < CPUBackend * > ( mBackend ) - > functions ( ) - > MNNFp16ToFp8 ( buf , ( uint16_t * ) value_src , mHeadDim ) ;
2024-08-24 15:46:21 +08:00
} else {
2024-09-12 12:57:57 +08:00
static_cast < CPUBackend * > ( mBackend ) - > functions ( ) - > MNNFp32ToFp8 ( buf , ( float * ) value_src , mHeadDim ) ;
2024-08-24 15:46:21 +08:00
}
for ( int j = 0 ; j < mHeadDim ; j + + ) {
int out_index = j / hP ;
int in_index = j % hP ;
value_dst [ out_index * mMaxLength * hP + ( mPastLength + i ) * hP + in_index ] = buf [ j ] ;
}
}
MNNMemoryFreeAlign ( buf ) ;
}
else {
2025-07-23 14:10:58 +08:00
// [mHeadDim/hP, mMaxLength/lP, hP, lP]
auto stride0 = ROUND_UP ( mMaxLength , lP ) * hP ;
auto stride1 = hP * lP ;
2024-09-12 12:57:57 +08:00
T * value_dst = reinterpret_cast < T * > ( addrOfValue ( kv_h ) ) ;
2024-08-24 15:46:21 +08:00
for ( int i = 0 ; i < seq_len ; i + + ) {
T * value_src = value - > host < T > ( ) + i * mKvNumHead * mHeadDim + kv_h * mHeadDim ;
2025-07-23 14:10:58 +08:00
int seqLenOut = ( mPastLength + i ) / lP ;
int seqLenIn = ( mPastLength + i ) % lP ;
2024-08-24 15:46:21 +08:00
for ( int j = 0 ; j < mHeadDim ; j + + ) {
int out_index = j / hP ;
int in_index = j % hP ;
2025-07-23 14:10:58 +08:00
value_dst [ out_index * stride0 + seqLenOut * stride1 + in_index * lP + seqLenIn ] = value_src [ j ] ;
2024-08-24 15:46:21 +08:00
}
}
}
}
void KVCacheManager : : onPushBack ( const Tensor * key , const Tensor * value ) {
auto core = static_cast < CPUBackend * > ( mBackend ) - > functions ( ) ;
2024-12-02 10:12:08 +08:00
int seq_len = key - > length ( 1 ) ;
2024-08-24 15:46:21 +08:00
int tileCount = UP_DIV ( mKvNumHead , mThreadNum ) ;
std : : function < void ( int ) > packKV = [ = ] ( int tid ) {
for ( int kv_h = tid * tileCount ; kv_h < ( tid + 1 ) * tileCount & & kv_h < mKvNumHead ; kv_h + + ) {
if ( mBytes = = 2 ) {
2024-09-12 12:57:57 +08:00
pack_key < FLOAT16_T > ( key , seq_len , kv_h ) ;
pack_value < FLOAT16_T > ( value , seq_len , kv_h ) ;
2024-08-24 15:46:21 +08:00
} else {
2024-09-12 12:57:57 +08:00
pack_key < float > ( key , seq_len , kv_h ) ;
pack_value < float > ( value , seq_len , kv_h ) ;
2024-08-24 15:46:21 +08:00
}
}
} ;
MNN_CONCURRENCY_BEGIN ( tid , mThreadNum ) {
packKV ( ( int ) tid ) ;
}
MNN_CONCURRENCY_END ( ) ;
mPastLength + = seq_len ;
}
void KVCacheManager : : onDequantValue ( Tensor * dequantedValues ) {
auto core = static_cast < CPUBackend * > ( mBackend ) - > functions ( ) ;
int tileCount = UP_DIV ( mKvNumHead , mThreadNum ) ;
std : : function < void ( int ) > dequant = [ = ] ( int tid ) {
for ( int kv_h = tid * tileCount ; kv_h < ( tid + 1 ) * tileCount & & kv_h < mKvNumHead ; kv_h + + ) {
char * dst = dequantedValues - > host < char > ( ) + kv_h * UP_DIV ( mHeadDim , hP ) * mPastLength * hP * mBytes ;
char * src = addrOfValue ( kv_h ) ;
for ( int i = 0 ; i < UP_DIV ( mHeadDim , hP ) ; i + + ) {
if ( mBytes = = 2 ) {
core - > MNNFp8ToFp16 ( ( uint16_t * ) dst , ( uint8_t * ) src , mPastLength * hP ) ;
} else {
core - > MNNFp8ToFp32 ( ( float * ) dst , ( uint8_t * ) src , mPastLength * hP ) ;
}
dst + = mPastLength * hP * mBytes ;
src + = mMaxLength * hP ;
}
}
} ;
MNN_CONCURRENCY_BEGIN ( tid , mThreadNum ) {
dequant ( ( int ) tid ) ;
}
MNN_CONCURRENCY_END ( ) ;
}
} // namespace MNN
2024-12-31 15:34:08 +08:00
# endif // MNN_SUPPORT_TRANSFORMER_FUSE