2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  KVCacheManager.hpp
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  MNN
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  Created by MNN on 2024/08/05.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  Copyright © 2018, Alibaba Group Holding Limited
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#ifndef KVCACHE_MANAGER_HPP
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#define KVCACHE_MANAGER_HPP
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "core/Macro.h"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "core/MNNFileUtils.h"
							 | 
						
					
						
							
								
									
										
										
										
											2024-12-31 15:34:08 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								#include "core/OpCommonUtils.hpp"
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "backend/cpu/CPUBackend.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "backend/cpu/compute/CommonOptFunction.h"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#if defined (__aarch64__)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#define FLOAT16_T __fp16
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#else
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#define FLOAT16_T float
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#endif
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								typedef uint8_t fp8_t;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								namespace MNN {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								class KVCacheManager : public NonCopyable{
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								public:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    struct KVCacheConfig {
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        bool mQuantKey      = false;            // Quantize keys to int8 or not
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        bool mQuantValue    = false;            // Quantize values to fp8 or not
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        bool mUseInt8Kernel = false;            // Whether to use int8 gemm kernel in CPU attention
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        std::string mKVCacheDir = "/tmp";       // Path of the kvcache files in disk
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        size_t mKVCacheSizeLimit = -1;          // The limit of the kvcache size
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        int  mExpandChunk = 64;                 // Number of expand chunks when the buffer is full
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    };
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								private:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    Backend * mBackend;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    KVCacheConfig mConfig;
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    std::shared_ptr<Tensor> mPastKey;               // {numhead, [maxlen/hP, headdim, hP]} or {numhead, [maxlen/hP8, headdim/lP8, hP8, lP8]} 
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    std::shared_ptr<Tensor> mPastValue;             // numhead, [headdim/hP, maxlen, hP]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    std::shared_ptr<Tensor> mKeyScale;              // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    std::shared_ptr<Tensor> mKeyZeroPoint;          // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    std::shared_ptr<Tensor> mKeySum;                // numhead, [maxlen/hP8, hP8]
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    file_t mKeyCacheFD   = INVALID_FILE;            // The file descriptor of keys
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    file_t mValueCacheFD = INVALID_FILE;            // The file descriptor of values
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    char * mMapKeyAddr   = nullptr;                 // Memory-mapped address of keys
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    char * mMapValueAddr = nullptr;                 // Memory-mapped address of values
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    bool mKVCacheInDisk  = false;                   // Whether the kvcache is in disk or in memory now
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    int  mPastLength     = 0;                       // Length of past kvcache
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    int  mMaxLength      = 0;                       // Capacity of current kvcache buffer (how many kv items can be stored at most)
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    int  eP, lP, hP;                                // Packing mode for float matmul
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    int  eP8, lP8, hP8;                             // Packing mode for int8 gemm kernel
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    int  mBytes = 4, mThreadNum = 1;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    int  mKvNumHead = 0, mHeadDim = 0;
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void createKVCacheFile();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void removeKVCacheFile();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void resetKVCacheFileSize(size_t keySize, size_t valueSize);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void mmapKVCache(size_t keySize, size_t valueSize);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void unmapKVCache(size_t keySize, size_t valueSize);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void expandKVCacheInMem(int oldMaxLength);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void moveKVCacheFromMemToDisk(int oldMaxLength);
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    void expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    template <typename T> void pack_key(const Tensor* key, int seq_len, int kv_h);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    template <typename T> void pack_value(const Tensor* value, int seq_len, int kv_h);
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								public:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    KVCacheManager(Backend * backend, KVCacheConfig & kvConfig) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        mBackend   = backend;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        mConfig    = kvConfig; 
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    ~KVCacheManager() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        onClear();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const Backend * backend() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return mBackend;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const KVCacheConfig * config() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return &mConfig;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const Tensor * key() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return mPastKey.get();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const Tensor * value() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return mPastValue.get();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const Tensor * scale() {
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return mKeyScale.get();
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const Tensor * zeroPoint() {
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return mKeyZeroPoint.get();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const Tensor * keySum() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return mKeySum.get();
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    bool inDisk() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return mKVCacheInDisk;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    int kvLength() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return mPastLength;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    int maxLength() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return mMaxLength;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							
								
									
										
										
										
											2024-12-31 15:34:08 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    uint8_t* keyAddr() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return (uint8_t*)baseAddr;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    uint8_t* valudAddr() {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return (uint8_t*)baseAddr;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    char * addrOfKey(int kv_h) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>();
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        if (mConfig.mUseInt8Kernel) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        } else if (mConfig.mQuantKey) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        } else {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    char * addrOfValue(int kv_h) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        if (mConfig.mQuantValue) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        } else {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP * mBytes;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    char * addrOfScale(int kv_h) {
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        if (mConfig.mUseInt8Kernel) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        } else if (mConfig.mQuantKey) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        } else {
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return nullptr;
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    char * addrOfZeroPoint(int kv_h) {
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        if (mConfig.mUseInt8Kernel) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        } else if (mConfig.mQuantKey) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        } else {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return nullptr;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    char * addrOfKeySum(int kv_h) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if (mConfig.mUseInt8Kernel) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return mKeySum->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }else {
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return nullptr;
							 | 
						
					
						
							
								
									
										
										
										
											2024-09-12 12:57:57 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void onResize(int kv_num_head, int head_dim);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void onAlloc(int kv_seq_len);
							 | 
						
					
						
							
								
									
										
										
										
											2024-12-31 15:34:08 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    void onRealloc(const KVMeta* meta);
							 | 
						
					
						
							
								
									
										
										
										
											2024-08-24 15:46:21 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void onClear();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void onPushBack(const Tensor * key, const Tensor * value);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    void onDequantValue(Tensor * dequantedValues);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								};
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								} // namespace MNN
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#endif // KVCACHE_MANAGER_HPP
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2024-12-31 15:34:08 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								#endif // MNN_SUPPORT_TRANSFORMER_FUSE
							 |