Add prefill/decode (PD) disaggregation and separate acceleration on the CPU backend

hzx 2025-05-28 19:50:53 +08:00
parent 5f0d59958e
commit 664ee20e2b
14 changed files with 259 additions and 54 deletions

View File

@@ -268,6 +268,22 @@ public:
NetModule* module(new NetModule(submodule, newInfo, nullptr, 0, 0.0f));
#ifdef MNN_INTERNAL_ENABLED
module->mLogInfo = mLogInfo;
#endif
return this->cloneBaseTo(ctx, module);
}
virtual Module* clone(CloneContext* ctx, const ScheduleConfig* config) const override {
auto mModule = mChildren[0];
auto origin = mInfo->runTimeManager->getInside();
std::shared_ptr<Executor::RuntimeManager> newRt (Executor::RuntimeManager::createRuntimeManager(*config));
const_cast<RuntimeAttr*>(newRt->getInside())->mContent->mExternalFile = origin->mContent->mExternalFile;
std::shared_ptr<Module::Info> newInfo(new Module::Info);
*newInfo = *mInfo;
ctx->pRuntimeManager = newRt;
newInfo->runTimeManager = newRt;
std::shared_ptr<Module> submodule(mModule->clone(ctx));
NetModule* module(new NetModule(submodule, newInfo, nullptr, 0, 0.0f));
#ifdef MNN_INTERNAL_ENABLED
module->mLogInfo = mLogInfo;
#endif
return this->cloneBaseTo(ctx, module);
}
@@ -506,6 +522,11 @@ Module* Module::clone(const Module* module, const bool shareParams) {
return module->clone(&context);
}
Module* Module::clone(const Module* module, const ScheduleConfig* config, const bool shareParams) {
CloneContext context(shareParams);
return module->clone(&context, config);
}
Module* Module::cloneBaseTo(CloneContext* ctx, Module* module) const {
for (const Express::VARP& var : mParameters) {
module->mParameters.push_back(ctx->getOrClone(var));
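
The hunks above add a Module::clone overload that takes a ScheduleConfig, so an already-loaded module can be duplicated onto a runtime with different resources (thread count, power and backend settings) while still sharing weights and the external weight file. A minimal caller-side sketch, assuming the usual MNN include paths and an illustrative decodeThreads parameter:

// Sketch only: clone a prefill-loaded module under a decode-specific ScheduleConfig.
#include <memory>
#include <MNN/Interpreter.hpp>      // ScheduleConfig (assumed include path)
#include <MNN/expr/Module.hpp>      // MNN::Express::Module (assumed include path)

std::shared_ptr<MNN::Express::Module> cloneForDecode(const MNN::Express::Module* prefillModule,
                                                     int decodeThreads /* illustrative */) {
    MNN::ScheduleConfig decodeConfig;
    decodeConfig.type = MNN_FORWARD_CPU;     // same backend, different resources
    decodeConfig.numThread = decodeThreads;  // e.g. fewer threads for the decode stage
    // The new overload builds a fresh RuntimeManager from decodeConfig, carries over the
    // external weight file reference, and shares parameters as the old clone() did.
    return std::shared_ptr<MNN::Express::Module>(
        MNN::Express::Module::clone(prefillModule, &decodeConfig));
}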

View File

@@ -78,6 +78,7 @@ public:
static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {});
static Module* clone(const Module* module, const bool shareParams = false);
static Module* clone(const Module* module, const ScheduleConfig* config, const bool shareParams = false);
struct Info {
// Input info loaded from the model
@@ -102,6 +103,9 @@ public:
virtual Module* clone(CloneContext* ctx) const {
return nullptr;
}
virtual Module* clone(CloneContext* ctx, const ScheduleConfig* config) const {
return clone(ctx);
}
void registerModel(const std::vector<std::shared_ptr<Module>>& children);
static void destroy(Module* m);

View File

@@ -196,6 +196,7 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
}
int tileCount = UP_DIV(mNumHead, mThreadNum);
int group_size = mNumHead / mKvNumHead;
mKVCacheManager->setThreadNum(mThreadNum);
// reduce the value of 'query' to avoid fp16 overflow
float mScale = 1.0 / sqrt(mHeadDim);
float q_scale = 1.0;
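
The new setThreadNum call above exists because, with prefill and decode now running on separately configured CPU backends, the backend's thread count can change between executions; the KV cache manager therefore receives the current value on every onExecute and clamps it to the number of KV heads, since each worker handles at least one head. A tiny illustrative sketch of that clamping, with made-up names:

// Sketch: effective worker count for KV-cache work (illustrative helper, mirrors the
// setThreadNum added to KVCacheManager later in this commit).
static int effectiveKvThreads(int backendThreads, int kvNumHead) {
    return backendThreads > kvNumHead ? kvNumHead : backendThreads;
}
// e.g. a decode clone configured with 8 threads but only 4 KV heads uses 4 workers.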

View File

@@ -50,6 +50,14 @@ ErrorCode CastWrapExecution::onExecute(const std::vector<Tensor*>& inputs, const
CPUCastCreator::cast(inputs[0], outputs[0], cpuBackend, convertType);
return NO_ERROR;
}
int getMajorCPUNumber(const std::vector<CPUGroup>& groups) {
int sum = 0;
for (const auto& g: groups) {
if (g.cpuType != CPUGroup::Efficient) { sum+=g.ids.size(); }
}
return sum;
}
void CPUBackend::computeDivideSizes(int size, int* dst, float avgDiv) const {
if (mGroupWithComputeRate.size() <= 1 || (avgDiv > 0 && avgDiv < mComputeI)) {
// Avg divide
@@ -136,13 +144,14 @@ void CPURuntime::_bindCPUCore() const {
}
void CPURuntime::_resetThreadPool() {
if (mThreadNumber <= 0) { mThreadNumber=getMajorCPUNumber(MNNGetCPUInfo()->groups); }
mThreadNumber = std::max(1, mThreadNumber);
mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER);
#ifdef MNN_USE_THREAD_POOL
ThreadPool::releaseWorkIndex(mTaskIndex);
auto cpuInfo = MNNGetCPUInfo();
if (mThreadNumber > 1) {
int systemThreadNumber = (int)cpuInfo->cpuNumber;
if (mThreadNumber > 1) {
if (systemThreadNumber == 0) {
systemThreadNumber = mThreadNumber;
}
@@ -389,25 +398,18 @@ BufferAllocator* CPURuntime::createDynamicBufferAlloctor(int index) const {
}
return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocator.get()));
}
CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, MNNForwardType type, size_t flags, int initThreadNumber) : Backend(type) {
#ifdef LOG_VERBOSE
MNN_PRINT("cpu backend create\n");
#endif
mMemory = memory;
mRuntime = const_cast<CPURuntime*>(runtime);
mThreadNumber = mRuntime->mThreadNumber;
// Compute Group Rate
do {
void CPUBackend::computeGroupRate() {
{
if (mThreadNumber <= 1 || mRuntime->mPower == BackendConfig::Power_Low) {
break;
return;
}
auto rate = mRuntime->hint().cpuDecreaseRate;
if (rate >= 100 || rate <= 0) {
break;
return;
}
auto cpuInfo = MNNGetCPUInfo();
if (cpuInfo->groups.size() < 2) {
break;
return;
}
if (cpuInfo->i8mm) {
mComputeI = 28.f;
@@ -435,7 +437,18 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
for (auto& g : mGroupWithComputeRate) {
g.first = g.first / totalComputeRate;
}
} while (false);
}
}
CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, MNNForwardType type, size_t flags, int initThreadNumber) : Backend(type) {
#ifdef LOG_VERBOSE
MNN_PRINT("cpu backend create\n");
#endif
mMemory = memory;
mRuntime = const_cast<CPURuntime*>(runtime);
mThreadNumber = mRuntime->mThreadNumber;
// Compute Group Rate
computeGroupRate();
// initialize Allocator
auto dynamicAlloc = mRuntime->mSharedDmaInfo;
if (nullptr == dynamicAlloc.get()) {
mDmaInfo.reset(new CPURuntime::DynamicAllocator);
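
Two behavioural points in this file: getMajorCPUNumber makes a non-positive thread count default to one thread per non-efficiency core, and the old in-constructor do { ... } while(false) block is lifted into computeGroupRate(). A small sketch of the default-thread rule, using a made-up group layout:

// Sketch only: how a numThread <= 0 request appears to resolve after this change.
#include <vector>

struct FakeGroup { bool efficient; int coreCount; };   // illustrative stand-in for CPUGroup

int defaultThreadCount(const std::vector<FakeGroup>& groups) {
    int sum = 0;
    for (const auto& g : groups) {
        if (!g.efficient) sum += g.coreCount;   // count Prime and Performance clusters only
    }
    return sum > 0 ? sum : 1;                   // _resetThreadPool still clamps to >= 1
}
// A typical 1 + 3 + 4 phone layout {prime, performance, efficient} resolves to 4 threads.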

View File

@@ -181,6 +181,7 @@ public:
void enqueueTask(std::function<int()>&& task);
protected:
void computeGroupRate();
MemObj* allocBuffer(size_t size, Tensor* dest, StorageType storageType);
CoreFunctions* mCoreFunctions;
CoreInt8Functions* mInt8CoreFunctions;
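
computeGroupRate() produces a normalized per-cluster compute rate (the rates sum to 1) that computeDivideSizes can then use to split a workload unevenly across big and little clusters instead of dividing it equally. A rough sketch of that proportional split, not the exact MNN arithmetic:

// Sketch: divide totalSize units of work according to normalized per-group rates.
#include <vector>

std::vector<int> divideByRate(int totalSize, const std::vector<float>& normalizedRate) {
    std::vector<int> sizes(normalizedRate.size(), 0);
    int assigned = 0;
    for (size_t i = 0; i + 1 < normalizedRate.size(); ++i) {
        sizes[i] = static_cast<int>(totalSize * normalizedRate[i]);
        assigned += sizes[i];
    }
    if (!sizes.empty()) {
        sizes.back() = totalSize - assigned;    // remainder goes to the last group
    }
    return sizes;
}
// With rates {0.3f, 0.7f} and 1000 units, the slower cluster gets 300 and the faster 700.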

View File

@@ -38,6 +38,9 @@
#include <algorithm>
#include <string>
#include <iostream>
#include <fstream>
#include <sstream>
#include "core/Macro.h"
#ifdef __ANDROID__
@@ -117,7 +120,7 @@ int MNNSetSchedAffinity(const int* cpuIDs, int size) {
// cpuinfo
// Reference from: https://github.com/pytorch/cpuinfo
#if defined(ENABLE_ARMV82) && defined(__arm__)
#if defined(__ANDROID__) && (defined(__arm__) || defined(__aarch64__))
/* As per include/sys/system_properties.h in Android NDK */
#define CPUINFO_HARDWARE_VALUE_MAX 64
@@ -1360,6 +1363,36 @@ const MNNCPUInfo* MNNGetCPUInfo() {
return gCPUInfo;
}
#ifdef __linux__
// Function to trim leading and trailing spaces from a string
static std::string trim(const std::string& str) {
size_t first = str.find_first_not_of(" \t");
if (first == std::string::npos)
return ""; // Return empty string if all characters are spaces
size_t last = str.find_last_not_of(" \t");
return str.substr(first, (last - first + 1));
}
static std::vector<std::string> _fillCpuPart() {
std::vector<std::string> cpu_parts;
std::ifstream file("/proc/cpuinfo");
std::string line;
if (!file.is_open()) { return cpu_parts; } // return an empty list if the file cannot be opened
while (std::getline(file, line)) {
std::istringstream iss(line);
std::string key, value;
if (std::getline(iss, key, ':') && std::getline(iss, value)) {
key = trim(key); // Trim leading and trailing spaces from key
value = trim(value); // Trim leading and trailing spaces from value
if (key == "CPU part") {
cpu_parts.push_back(value);
}
}
}
file.close();
return cpu_parts;
}
#endif
static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
cpuinfo_isa->dot = false;
cpuinfo_isa->fp16arith = false;
@@ -1371,6 +1404,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
#ifdef __linux__
do {
DIR* root;
// deal with the CPU policy info and frequency info (maxFreq, minFreq).
std::string dir = "/sys/devices/system/cpu/cpufreq";
if ((root = opendir(dir.c_str())) == NULL) {
break;
@@ -1418,20 +1452,46 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
std::sort(cpuinfo_isa->groups.begin(), cpuinfo_isa->groups.end(), [](const CPUGroup& left, const CPUGroup& right) {
return left.maxFreq < right.maxFreq;
});
// Merge group if needed
if (cpuinfo_isa->groups.size() >= 2 && cpuinfo_isa->groups[0].maxFreq == cpuinfo_isa->groups[1].maxFreq) {
auto backupGroups = std::move(cpuinfo_isa->groups);
CPUGroup&& current = std::move(backupGroups[0]);
for (int v=1; v<backupGroups.size(); ++v) {
if (backupGroups[v].maxFreq != current.maxFreq) {
cpuinfo_isa->groups.emplace_back(current);
current = std::move(backupGroups[v]);
// do not merge group
// deal with cpu capacity info
do {
dir = "/sys/devices/system/cpu/";
if (opendir(dir.c_str()) == NULL) {
break;
}
for (auto& group: cpuinfo_isa->groups) {
std::string cpu_name = "cpu"+std::to_string(group.ids[0]);
MNN::AutoStorage<uint8_t> buffer;
if (false == _readAll(dir+cpu_name+"/cpu_capacity", buffer)) {
continue;
}
group.capacity = _readNumber((const char*)buffer.get(), buffer.size())[0];
}
} while(false);
// get CPU part from /proc/cpuinfo
std::vector<std::string> cpu_parts = _fillCpuPart();
// classify cpuType
// 1. Take the prime group's maxFreq, minFreq, capacity and its /proc/cpuinfo "CPU part" code as the reference.
// 2. Cores with the same part code as the prime group (or, when part codes are unavailable, >= 80% of the prime maxFreq and capacity) are classified as Prime.
// 3. Cores with >= 60% of the prime maxFreq and >= 40% of the prime capacity, or whose min and max frequencies both exceed the lowest group's, are classified as Performance.
// 4. The rest are classified as Efficient.
const auto& prime_info = cpuinfo_isa->groups.back();
auto lowest_maxFreq = cpuinfo_isa->groups.front().maxFreq;
auto lowest_minFreq = cpuinfo_isa->groups.front().minFreq;
for (auto& group: cpuinfo_isa->groups) {
if (cpu_parts.empty()) {
if (((float)group.maxFreq >= 0.8*(float)prime_info.maxFreq) && ((float)group.capacity >= 0.8*(float)prime_info.capacity))
{ group.cpuType=CPUGroup::Prime; continue; }
} else {
current.ids.insert(current.ids.end(), backupGroups[v].ids.begin(), backupGroups[v].ids.end());
if (cpu_parts[prime_info.ids.front()] == cpu_parts[group.ids.front()])
{ group.cpuType=CPUGroup::Prime; continue; }
}
if ((((float)group.maxFreq >= 0.6*(float)prime_info.maxFreq) && ((float)group.capacity >= 0.4*(float)prime_info.capacity)) \
|| ((float)group.minFreq > (float)lowest_minFreq) && ((float)group.maxFreq > (float)lowest_maxFreq))
{ group.cpuType=CPUGroup::Performance; continue; }
group.cpuType=CPUGroup::Efficient;
}
cpuinfo_isa->groups.emplace_back(current);
}
// count total cpu number and display info
cpuinfo_isa->cpuNumber = 0;
for (auto& group : cpuinfo_isa->groups) {
cpuinfo_isa->cpuNumber += group.ids.size();
@@ -1440,6 +1500,13 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
message += " " + std::to_string(group.ids[v]) + " ";
}
message += "], " + std::to_string(group.minFreq) + " - " + std::to_string(group.maxFreq);
if (group.capacity!=0) { message += ", capacity: " + std::to_string(group.capacity); }
message += ", cpu type: ";
switch (group.cpuType) {
case CPUGroup::Prime: message += "Prime"; break;
case CPUGroup::Performance: message += "Performance"; break;
case CPUGroup::Efficient: message += "Efficient"; break;
}
MNN_PRINT("%s\n", message.c_str());
}
} while (false);
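
The classification above keys off three signals per cluster: max/min frequency from cpufreq, cpu_capacity from sysfs, and the "CPU part" code from /proc/cpuinfo. The sketch below restates the same decision as a standalone function for readability; the thresholds mirror the code (80% for Prime, 60%/40% for Performance), while the type and helper names are invented for the example:

// Illustrative restatement of the cpuType decision above; not the MNN implementation.
enum class CoreClass { Prime, Performance, Efficient };

struct ClusterInfo {
    float maxFreq, minFreq, capacity;
    bool partMatchesPrime;   // same "CPU part" code as the prime cluster, when available
};

CoreClass classifyCluster(const ClusterInfo& c, const ClusterInfo& prime,
                          float lowestMaxFreq, float lowestMinFreq, bool havePartCodes) {
    bool isPrime = havePartCodes
                       ? c.partMatchesPrime
                       : (c.maxFreq >= 0.8f * prime.maxFreq && c.capacity >= 0.8f * prime.capacity);
    if (isPrime) {
        return CoreClass::Prime;
    }
    if ((c.maxFreq >= 0.6f * prime.maxFreq && c.capacity >= 0.4f * prime.capacity) ||
        (c.minFreq > lowestMinFreq && c.maxFreq > lowestMaxFreq)) {
        return CoreClass::Performance;
    }
    return CoreClass::Efficient;
}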

View File

@@ -12,8 +12,15 @@
#include <vector>
#include "core/Macro.h"
struct CPUGroup {
uint32_t minFreq;
uint32_t maxFreq;
enum CPUCapacityType {
Prime = 0,
Performance,
Efficient
};
uint32_t minFreq = 0;
uint32_t maxFreq = 0;
uint32_t capacity = 0;
CPUCapacityType cpuType = Prime;
std::vector<int> ids;
};
struct MNNCPUInfo {

View File

@@ -326,10 +326,6 @@ void KVCacheManager::onResize(int kv_num_head, int head_dim) {
auto core = static_cast<CPUBackend *>(mBackend)->functions();
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
mBytes = core->bytes;
mThreadNum = static_cast<CPUBackend *>(mBackend)->threadNumber();
if (mThreadNum > mKvNumHead) {
mThreadNum = mKvNumHead;
}
if (mConfig.mUseInt8Kernel) {
static_cast<CPUBackend *>(mBackend)->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8);
}

View File

@@ -94,6 +94,12 @@ public:
const Tensor * keySum() {
return mKeySum.get();
}
void setThreadNum(int numThread) {
mThreadNum = numThread;
if (mThreadNum > mKvNumHead) {
mThreadNum = mKvNumHead;
}
}
bool inDisk() {
return mKVCacheInDisk;
}

View File

@@ -18,6 +18,7 @@ using namespace MNN::Transformer;
static void tuning_prepare(Llm* llm) {
MNN_PRINT("Prepare for tuning opt Begin\n");
llm->tuning(OP_ENCODER_NUMBER, {1, 5, 10, 20, 30, 50, 100});
llm->tuning(PREFILL_BIGLITTLE_CORE, {});
MNN_PRINT("Prepare for tuning opt End\n");
}

View File

@@ -39,6 +39,7 @@ using ChatMessages = std::vector<ChatMessage>;
enum TuneType {
// op encoder number for commit
OP_ENCODER_NUMBER = 0,
PREFILL_BIGLITTLE_CORE,
};
enum class MatchStrictLevel : int;
enum class NgramSelectRule : int;
@@ -129,6 +130,7 @@ protected:
std::shared_ptr<Express::Executor::RuntimeManager> mRuntimeManager, mProcessorRuntimeManager;
std::vector<std::shared_ptr<Express::Module>> mModules, mPrefillModules, mDecodeModules, mCurrentModules;
const Express::Module* mBaseModule = nullptr;
ScheduleConfig mPrefillConfig, mDecodeConfig;
Express::VARP inputsEmbeds, attentionMask, positionIds;
std::vector<Express::VARP> mInputsEmbedsVarVec, mAttentionMaskVarVec, mPositionIdsVarVec;
Express::VARP logitsAllIdx, logitsLastIdx;

View File

@@ -93,17 +93,20 @@ bool Llm::set_config(const std::string& content) {
}
void Llm::initRuntime() {
ScheduleConfig config;
BackendConfig cpuBackendConfig;
config.type = backend_type_convert(mConfig->backend_type());
config.numThread = mConfig->thread_num();
if(config.type == 3){
// setup mPrefillConfig
mPrefillConfig.type = backend_type_convert(mConfig->backend_type());
mPrefillConfig.numThread = (mConfig->prefill_thread_num() < 0) \
? mConfig->thread_num() : mConfig->prefill_thread_num();
if(mPrefillConfig.type == 3){
// for OpenCL, numThread doubles as a mode mask: 64 selects buffer mode
config.numThread |= 64;
mPrefillConfig.numThread |= 64;
}
if (mConfig->power() == "high") {
std::string powerConfig = (mConfig->prefill_power().empty()) \
? mConfig->power() : mConfig->prefill_power();
if (powerConfig == "high") {
cpuBackendConfig.power = BackendConfig::Power_High;
} else if (mConfig->power() == "low") {
} else if (powerConfig == "low") {
cpuBackendConfig.power = BackendConfig::Power_Low;
}
if (mConfig->memory() == "high") {
@@ -116,9 +119,26 @@ void Llm::initRuntime() {
} else if (mConfig->precision() == "low") {
cpuBackendConfig.precision = BackendConfig::Precision_Low;
}
config.backendConfig = &cpuBackendConfig;
ExecutorScope::Current()->setGlobalExecutorConfig(mPrefillConfig.type, cpuBackendConfig, mPrefillConfig.numThread);
mPrefillConfig.backendConfig = new BackendConfig(cpuBackendConfig);
// set up mDecodeConfig
mDecodeConfig = mPrefillConfig;
mDecodeConfig.backendConfig = new BackendConfig(cpuBackendConfig);
mDecodeConfig.numThread = (mConfig->decode_thread_num() < 0) \
? mConfig->thread_num() : mConfig->decode_thread_num();
if(mDecodeConfig.type == 3){
// for OpenCL, numThread doubles as a mode mask: 64 selects buffer mode
mDecodeConfig.numThread |= 64;
}
powerConfig = (mConfig->decode_power().empty()) \
? mConfig->power() : mConfig->decode_power();
if (powerConfig == "high") {
mDecodeConfig.backendConfig->power = BackendConfig::Power_High;
} else if (powerConfig == "low") {
mDecodeConfig.backendConfig->power = BackendConfig::Power_Low;
}
mRuntimeManager.reset(Executor::RuntimeManager::createRuntimeManager(config));
mRuntimeManager.reset(Executor::RuntimeManager::createRuntimeManager(mPrefillConfig));
// Use 4 threads to load the LLM
mRuntimeManager->setHint(MNN::Interpreter::INIT_THREAD_NUMBER, 4);
@@ -152,7 +172,7 @@ void Llm::initRuntime() {
mRuntimeManager->setMode(MNN::Interpreter::Session_Debug);
_initDebug();
#endif
if (config.type != 0) { // not cpu
if (mPrefillConfig.type != 0) { // not cpu
std::string cacheFilePath = tmpPath.length() != 0 ? tmpPath : ".";
mRuntimeManager->setCache(cacheFilePath + "/mnn_cachefile.bin");
}
@@ -244,6 +264,7 @@ void Llm::load() {
mModules[0].reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config));
// set speculative decoding params
ExecutorScope::Current()->setGlobalExecutorConfig(mDecodeConfig.type, *(mDecodeConfig.backendConfig), mDecodeConfig.numThread);
setSpeculativeConfig();
int decode_type_num = 1;
if(mLookAhead) {
@@ -253,7 +274,7 @@ void Llm::load() {
mDecodeModules.resize(decode_type_num);
for (int v = 0; v < mDecodeModules.size(); ++v) {
mDecodeModules[v].reset(Module::clone(mModules[0].get()));
mDecodeModules[v].reset(Module::clone(mModules[0].get(), &mDecodeConfig));
}
mPrefillModules = mModules;
@@ -335,15 +356,55 @@ bool Llm::select_module(size_t index) {
}
void Llm::tuning(TuneType type, std::vector<int> candidates) {
if (type != OP_ENCODER_NUMBER) {
MNN_ERROR("tuning type not supported\n");
if (type == PREFILL_BIGLITTLE_CORE) {
// only tune the CPU backend, and only when power mode is high
if (mPrefillConfig.type != MNN_FORWARD_CPU) {
return;
}
if (mPrefillConfig.backendConfig->power != BackendConfig::Power_High) {
return;
}
if (candidates.empty()){
candidates = {40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95};
}
auto itp_type = Interpreter::CPU_LITTLECORE_DECREASE_RATE;
int length = 64;
int64_t min_time = INT64_MAX;
int prefer_candidate = 0;
for (auto& candidate : candidates) {
mRuntimeManager->setHint(itp_type, candidate);
// reload the prefill modules so the new hint takes effect; this reload must not be removed
for (int v = 0; v < mPrefillModules.size(); ++v) {
mPrefillModules[v].reset(Module::clone(mPrefillModules[v].get()));
}
switchMode(Prefill);
Timer _t;
std::vector<int> input_ids(length, 0);
auto logits = forward(input_ids);
auto token = sample(logits);
auto time = _t.durationInUs();
MNN_PRINT("CPU_LITTLECORE_DECREASE_RATE:%d, prefill time: %lld us\n", candidate, time);
if (time < min_time) {
prefer_candidate = candidate;
min_time = time;
}
setKVCacheInfo(0, getCurrentHistory());
reset();
}
mRuntimeManager->setHint(itp_type, prefer_candidate);
// reload the prefill modules so the new hint takes effect; this reload must not be removed
for (int v = 0; v < mPrefillModules.size(); ++v) {
mPrefillModules[v].reset(Module::clone(mPrefillModules[v].get()));
}
switchMode(Prefill);
}
if (type == OP_ENCODER_NUMBER) {
// FIXME: OpenCL currently does not support KVMeta
if (mConfig->backend_type() == "opencl") {
return;
}
mCurrentModules = mDecodeModules;
auto itp_type = MNN::Interpreter::OP_ENCODER_NUMBER_FOR_COMMIT;
switchMode(Llm::Decode);
int decode_seq = 1;
if(mLookAhead) {
// start autoregressive decoding
@@ -354,7 +415,7 @@ void Llm::tuning(TuneType type, std::vector<int> candidates) {
int64_t min_time = INT64_MAX;
int prefer_candidate = 10;
for (auto& candidate : candidates) {
mRuntimeManager->setHint(MNN::Interpreter::OP_ENCODER_NUMBER_FOR_COMMIT, candidate);
mRuntimeManager->setHint(itp_type, candidate);
Timer _t;
std::vector<int> input_ids(decode_seq, 0);
auto logits = forward(input_ids);
@@ -372,18 +433,21 @@ void Llm::tuning(TuneType type, std::vector<int> candidates) {
// MNN_PRINT("op encode number:%d, decode time: %lld us\n", candidate, time);
}
}
mRuntimeManager->setHint(MNN::Interpreter::OP_ENCODER_NUMBER_FOR_COMMIT, prefer_candidate);
mRuntimeManager->setHint(itp_type, prefer_candidate);
// clear the dirty KV cache history left over from tuning
setKVCacheInfo(0, getCurrentHistory());
reset();
}
}
void Llm::switchMode(Llm::Stage stage) {
switch (stage) {
case Prefill:
ExecutorScope::Current()->setGlobalExecutorConfig(mPrefillConfig.type, *(mPrefillConfig.backendConfig), mPrefillConfig.numThread);
mCurrentModules = mPrefillModules;
break;
case Decode:
ExecutorScope::Current()->setGlobalExecutorConfig(mDecodeConfig.type, *(mDecodeConfig.backendConfig), mDecodeConfig.numThread);
mCurrentModules = mDecodeModules;
break;
default:
@@ -532,7 +596,7 @@ void Llm::generate_init(std::ostream* os, const char* end_with) {
mMeta->remove = mMeta->previous;
}
mContext->output_tokens.clear();
mCurrentModules = mPrefillModules;
switchMode(Llm::Prefill);
}
size_t Llm::getCurrentHistory() const {
@@ -618,7 +682,7 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
}
mContext->prompt_len = static_cast<int>(input_embeds->getInfo()->dim[0]);
Timer _t;
mCurrentModules = mPrefillModules;
switchMode(Llm::Prefill);
auto logits = forward(input_embeds);
if (nullptr == logits.get()) {
return {};
@@ -632,7 +696,7 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
mContext->history_tokens.push_back(mContext->current_token);
mContext->output_tokens.push_back(mContext->current_token);
logits = nullptr;
mCurrentModules = mDecodeModules;
switchMode(Llm::Decode);
generate(max_tokens - 1);
return mContext->output_tokens;
@@ -707,6 +771,8 @@ Llm::~Llm() {
mModules.clear();
mRuntimeManager.reset();
mProcessorRuntimeManager.reset();
if (mPrefillConfig.backendConfig != nullptr) delete mPrefillConfig.backendConfig;
if (mDecodeConfig.backendConfig != nullptr) delete mDecodeConfig.backendConfig;
}
bool Llm::reuse_kv() { return mConfig->reuse_kv(); }
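
With prefill and decode carrying their own ScheduleConfig, switchMode now swaps the global executor configuration as well as the active module set, and tuning(PREFILL_BIGLITTLE_CORE, ...) searches CPU_LITTLECORE_DECREASE_RATE for the fastest prefill. Nothing extra is required from callers beyond the new config keys; the sketch below shows the assumed call order, with the factory/response names taken from the existing MNN-LLM public API (verify against your headers) and configPath purely illustrative:

// Caller-side sketch (assumptions noted above); prefill/decode switching happens inside.
#include <memory>
#include <string>
#include "llm/llm.hpp"   // assumed include path for MNN::Transformer::Llm
using namespace MNN::Transformer;

void runExample(const std::string& configPath /* illustrative */) {
    std::unique_ptr<Llm> llm(Llm::createLLM(configPath));
    llm->set_config(R"({"prefill_power":"high","decode_power":"normal"})");
    llm->load();                                 // builds prefill modules, clones decode ones
    llm->tuning(PREFILL_BIGLITTLE_CORE, {});     // picks CPU_LITTLECORE_DECREASE_RATE for prefill
    llm->response("Hello");                      // generate_init/forward switch modes internally
}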

View File

@@ -341,6 +341,15 @@ public:
return config_.value("thread_num", 4);
}
int prefill_thread_num(bool mllm = false) const {
if (mllm) return mllm_config_.value("prefill_thread_num", -1);
return config_.value("prefill_thread_num", -1);
}
int decode_thread_num(bool mllm = false) const {
if (mllm) return mllm_config_.value("decode_thread_num", -1);
return config_.value("decode_thread_num", -1);
}
std::string precision(bool mllm = false) const {
if (mllm) return mllm_config_.value("precision", "low");
return config_.value("precision", "low");
@@ -349,6 +358,14 @@ public:
if (mllm) return mllm_config_.value("power", "normal");
return config_.value("power", "normal");
}
std::string prefill_power(bool mllm = false) const {
if (mllm) return mllm_config_.value("prefill_power", "");
return config_.value("prefill_power", "");
}
std::string decode_power(bool mllm = false) const {
if (mllm) return mllm_config_.value("decode_power", "");
return config_.value("decode_power", "");
}
std::string memory(bool mllm = false) const {
if (mllm) return mllm_config_.value("memory", "low");
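
The new getters read optional per-stage keys: a negative prefill_thread_num / decode_thread_num falls back to thread_num, and an empty prefill_power / decode_power falls back to power; with the CPU-runtime change earlier in this commit, a value of 0 appears to let the backend size its thread pool from the non-efficiency cores. A hedged example of setting these keys through Llm::set_config (values are only an illustration):

// Example only: per-stage keys added by this commit, passed as a JSON string.
llm->set_config(R"({
    "prefill_thread_num": 0,
    "prefill_power": "high",
    "decode_thread_num": 4,
    "decode_power": "normal"
})");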

View File

@@ -544,7 +544,10 @@ class LlmExporter(torch.nn.Module):
"llm_model": f"{self.dst_name}.mnn",
"llm_weight": f"{self.dst_name}.mnn.weight",
"backend_type": "cpu",
"thread_num": 4,
"prefill_thread_num": 0,
"prefill_power": "high",
"decode_thread_num": 4,
"decode_power": "normal",
"precision": "low",
"memory": "low",
# "system_prompt": "You are a helpful assistant.",