Merge pull request #3874 from alibaba/feature/llmrefractor

Feature/llmrefractor
This commit is contained in:
jxt1234 2025-09-05 18:16:29 +08:00 committed by GitHub
commit 2ddf67d37f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 77 additions and 42 deletions

View File

@ -1,4 +1,5 @@
#include <cmath>
#include <limits.h>
#include "VulkanGaussianRender.hpp"
namespace MNN {
struct ImageConstant {

View File

@ -71,8 +71,6 @@ std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>&
static int benchmark(Llm* llm, const std::vector<std::string>& prompts, int max_token_number) {
int prompt_len = 0;
int decode_len = 0;
int64_t vision_time = 0;
int64_t audio_time = 0;
int64_t prefill_time = 0;
int64_t decode_time = 0;
int64_t sample_time = 0;
@ -117,29 +115,39 @@ static int benchmark(Llm* llm, const std::vector<std::string>& prompts, int max_
}
prompt_len += context->prompt_len;
decode_len += context->gen_seq_len;
vision_time += context->vision_us;
audio_time += context->audio_us;
prefill_time += context->prefill_us;
decode_time += context->decode_us;
sample_time += context->sample_us;
}
llm->generateWavform();
float vision_s = vision_time / 1e6;
float audio_s = audio_time / 1e6;
float vision_s = context->vision_us / 1e6;
float audio_s = context->audio_us / 1e6;
float prefill_s = prefill_time / 1e6;
float decode_s = decode_time / 1e6;
float sample_s = sample_time / 1e6;
float vision_speed = 0.0f;
if (context->pixels_mp > 0.0f) {
vision_speed = context->pixels_mp / vision_s;
}
float audio_speed = 0.0f;
if (context->audio_input_s > 0.0f) {
audio_speed = context->audio_input_s / audio_s;
}
printf("\n#################################\n");
printf("prompt tokens num = %d\n", prompt_len);
printf("decode tokens num = %d\n", decode_len);
printf(" vision time = %.2f s\n", vision_s);
printf(" audio time = %.2f s\n", audio_s);
printf(" pixels_mp = %.2f MP\n", context->pixels_mp);
printf(" audio process time = %.2f s\n", audio_s);
printf(" audio input time = %.2f s\n", context->audio_input_s);
printf("prefill time = %.2f s\n", prefill_s);
printf(" decode time = %.2f s\n", decode_s);
printf(" sample time = %.2f s\n", sample_s);
printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
printf(" vision speed = %.3f MP/s\n", vision_speed);
printf(" audio RTF = %.3f \n", audio_s / context->audio_input_s);
printf("##################################\n");
return 0;
}
@ -256,7 +264,11 @@ int main(int argc, const char* argv[]) {
llm->set_config("{\"tmp_path\":\"tmp\"}");
{
AUTOTIME;
llm->load();
bool res = llm->load();
if (!res) {
MNN_ERROR("LLM init error\n");
return 0;
}
}
if (true) {
AUTOTIME;

View File

@ -76,8 +76,8 @@ struct LlmContext {
int64_t prefill_us = 0;
int64_t decode_us = 0;
int64_t sample_us = 0;
float prefill_mb = 0;
float decode_mb = 0;
float pixels_mp = 0;
float audio_input_s = 0;
// tokens
int current_token;
std::vector<int> history_tokens;
@ -95,7 +95,7 @@ public:
static void destroy(Llm* llm);// For Windows RT mode should use destroy
Llm(std::shared_ptr<LlmConfig> config);
virtual ~Llm();
virtual void load();
virtual bool load();
virtual Express::VARP gen_attention_mask(int seq_len);
virtual Express::VARP gen_position_ids(int seq_len);
virtual Express::VARP embedding(const std::vector<int>& input_ids);
@ -152,7 +152,7 @@ protected:
std::shared_ptr<DiskEmbedding> mDiskEmbedding;
std::shared_ptr<Sampler> mSampler;
std::shared_ptr<Express::Executor::RuntimeManager> mRuntimeManager, mProcessorRuntimeManager;
std::vector<std::shared_ptr<Express::Module>> mModules;
std::shared_ptr<Express::Module> mModule;
/**
key: <seq_len, all_logits>
value : module
@ -190,7 +190,7 @@ public:
static Embedding* createEmbedding(const std::string& config_path, bool load = true);
static float dist(Express::VARP var0, Express::VARP var1);
static float cos_sim(Express::VARP var0, Express::VARP var1);
virtual void load() override;
virtual bool load() override;
Express::VARP ids_embedding(const std::vector<int>& ids);
Express::VARP txt_embedding(const std::string& txt);
int dim() const;

View File

@ -45,7 +45,7 @@ int Embedding::dim() const {
return mConfig->hidden_size();
}
void Embedding::load() {
bool Embedding::load() {
initRuntime();
printf("load tokenizer\n");
std::cout << mConfig->tokenizer_file() << std::endl;
@ -59,10 +59,13 @@ void Embedding::load() {
module_config.rearrange = true;
auto model_path = mConfig->llm_model();
MNN_PRINT("load %s ... ", model_path.c_str());
mModules.resize(1);
mModules[0].reset(Module::load({"input_ids", "attention_mask", "position_ids"}, {"sentence_embeddings"},
mModule.reset(Module::load({"input_ids", "attention_mask", "position_ids"}, {"sentence_embeddings"},
model_path.c_str(), mRuntimeManager, &module_config));
if (nullptr == mModule.get()) {
return false;
}
MNN_PRINT("Done!\n");
return true;
}
VARP Embedding::ids_embedding(const std::vector<int>& ids) {
@ -70,7 +73,7 @@ VARP Embedding::ids_embedding(const std::vector<int>& ids) {
auto inputs_ids = embedding(ids);
auto attention_mask = gen_attention_mask(prompt_len);
auto position_ids = gen_position_ids(prompt_len);
auto outputs = mModules[0]->onForward({inputs_ids, attention_mask, position_ids});
auto outputs = mModule->onForward({inputs_ids, attention_mask, position_ids});
auto sentence_embeddings = outputs[0];
return sentence_embeddings;
}

View File

@ -234,7 +234,7 @@ static bool canSpecDecode(std::shared_ptr<Express::Module> module) {
void Llm::setSpeculativeConfig() {
auto specultive_type = mConfig->speculative_type();
if(!specultive_type.empty()) {
if(!canSpecDecode(mModules[0])) {
if(!canSpecDecode(mModule)) {
mInSpec = false;
return;
}
@ -243,7 +243,7 @@ void Llm::setSpeculativeConfig() {
}
}
void Llm::load() {
bool Llm::load() {
initRuntime();
// init module status
// 1. load vocab
@ -264,7 +264,6 @@ void Llm::load() {
module_config.base = mBaseModule;
}
// load single model
mModules.resize(1);
std::string model_path = mConfig->llm_model();
std::vector<std::string> inputNames {"input_ids", "attention_mask", "position_ids", "logits_index"};
@ -284,14 +283,14 @@ void Llm::load() {
}
mRuntimeManager->setExternalFile(mConfig->llm_weight());
mModules[0].reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config));
mModule.reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config));
mRuntimeManager->setExternalFile("");
if(nullptr == mModules[0]) {
if(nullptr == mModule) {
MNN_ERROR("[Error]: Load module failed, please check model.\n");
if(outputNames.size() > 1) {
MNN_ERROR("[Warning]: Set module multi outputs, please double check.\n");
}
return;
return false;
}
// set speculative decoding params
setSpeculativeConfig();
@ -305,13 +304,13 @@ void Llm::load() {
decode_type_num = 2;
verify_length = mDraftLength + 1;
// speculative decode module
mModulePool[std::make_pair(verify_length, true)].reset(Module::clone(mModules[0].get()));
mModulePool[std::make_pair(verify_length, true)].reset(Module::clone(mModule.get()));
}
// autoregressive decode module
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModules[0].get()));
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModule.get()));
// prefill module
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModules[0];
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModule;
// module input varp setting
logitsLastIdx = _var<int>({-1}, {1});
@ -340,12 +339,13 @@ void Llm::load() {
// MTP model load
mGenerationStrategy->load(module_config);
return true;
}
Llm* Llm::create_lora(const std::string& lora_path) {
auto llm = new Llm(std::make_shared<LlmConfig>(*mConfig));
llm->set_config("{\"llm_model\": \"" + lora_path + "\", \"use_mmap\": false, \"use_cached_mmap\": false}");
llm->mBaseModule = mModules.begin()->get();
llm->mBaseModule = mModule.get();
llm->load();
return llm;
}
@ -426,7 +426,7 @@ std::vector<Express::VARP> Llm::forwardRaw(Express::VARP hiddenState, Express::V
if(mModulePool.find(moduleKey) == mModulePool.end()) {
MNN_PRINT("Warning: module need new clone, cloning now.\n");
mRuntimeManager->setHintPtr(Interpreter::KVCACHE_INFO, mMeta.get());
mModulePool[moduleKey].reset(Module::clone(mModules[0].get()));
mModulePool[moduleKey].reset(Module::clone(mModule.get()));
}
if (isAllLogists) {
@ -554,6 +554,10 @@ void Llm::reset() {
mContext->history_tokens.clear();
mContext->all_seq_len = 0;
mContext->gen_seq_len = 0;
mContext->vision_us = 0;
mContext->pixels_mp = 0.0f;
mContext->audio_us = 0;
mContext->audio_input_s = 0.0f;
mMeta->remove = mMeta->previous;
}
@ -756,7 +760,7 @@ Llm::~Llm() {
}
#endif
mGenerateParam.reset();
mModules.clear();
mModule.reset();
mRuntimeManager.reset();
mProcessorRuntimeManager.reset();
}

View File

@ -4,6 +4,7 @@
// Created by MNN on 2025/04/08.
// Copyright © 2018, Alibaba Group Holding Limited
//
//#define MNN_OPEN_TIME_TRACE
#ifdef _WIN32
#define _USE_MATH_DEFINES
@ -25,7 +26,6 @@
#ifdef LLM_SUPPORT_AUDIO
#include <audio/audio.hpp>
#endif
namespace MNN {
using namespace Express;
namespace Transformer {
@ -69,11 +69,17 @@ Omni::Omni(std::shared_ptr<LlmConfig> config) : Llm(config) {
if (config->is_audio()) {}
}
void Omni::load() {
Llm::load();
bool Omni::load() {
auto res = Llm::load();
if (!res) {
return false;
}
if (mConfig->has_talker()) {
mTalker.reset(new Talker(mConfig, this));
mTalker->load();
res = mTalker->load();
}
if (!res) {
return false;
}
ScheduleConfig config;
if (mConfig->mllm_config_.empty()) {
@ -118,6 +124,7 @@ void Omni::load() {
if (mConfig->is_audio()) {
mAudioModule.reset(Module::load({}, {}, mConfig->audio_model().c_str(), mProcessorRuntimeManager, &module_config));
}
return mAudioModule.get() != nullptr && mVisionModule.get() != nullptr;
}
#ifdef LLM_SUPPORT_VISION
@ -142,6 +149,7 @@ std::vector<int> Omni::defaultVisionProcess(VARP image) {
}
std::vector<int> Omni::qwen2VisionProcess(VARP image) {
AUTOTIME;
const auto inputNames = mVisionModule->getInfo()->inputNames;
bool hasWindowIndex = inputNames.size() == 4 && inputNames[3] == "window_index";
// Qwen2-VL / Qwen2.5-VL
@ -583,6 +591,7 @@ std::vector<int> Omni::visionProcess(VARP image) {
imgIds = defaultVisionProcess(image);
}
mContext->vision_us += _t.durationInUs();
mContext->pixels_mp += (mVisionWidth / 1000.0f) * (mVisionHeight / 1000.0f);
// set vision number for image idx
mVisionNum += 1;
return imgIds;
@ -600,6 +609,7 @@ std::vector<int> Omni::audioProcess(const std::string& file) {
MNN_PRINT("Omni Can't open audio: %s\n", file.c_str());
return std::vector<int>(0);
}
mContext->audio_input_s += (float)(waveform->getInfo()->size) / sample_rate;
return audioProcess(waveform);
#else
return std::vector<int>(0);
@ -899,7 +909,7 @@ static inline bool needNewVar(VARP var, int axis, int seq_len) {
}
VARP Omni::gen_position_ids(int seq_len) {
auto positionIdsDims = mModules[0]->getInfo()->inputs[2].dim;
auto positionIdsDims = mModule->getInfo()->inputs[2].dim;
if (positionIdsDims[0] == 1) {
return Llm::gen_position_ids(seq_len);
}
@ -989,7 +999,7 @@ void Omni::generateWavform() {
}
}
void Talker::load() {
bool Talker::load() {
initRuntime();
mSeqLenIndex = 1;
set_config("{\"sampler_type\": \"mixed\", \"temperature\": 0.9, \"topK\": 40, \"topP\": 0.8, \"penalty\": 1.05}");
@ -1011,11 +1021,13 @@ void Talker::load() {
Module::Config module_config;
module_config.shapeMutable = false;
module_config.rearrange = true;
mModules.resize(1);
std::vector<std::string> inputNames {"inputs_embeds", "attention_mask", "position_ids", "logits_index"};
mModules[0].reset(Module::load(inputNames,
mModule.reset(Module::load(inputNames,
{"logits"}, mConfig->talker_model().c_str(), mRuntimeManager, &module_config));
if (mModule.get() == nullptr) {
return false;
}
// dit
mPreDit.reset(Module::load({"cond", "spk", "code"}, {"code_embeds", "rope", "mask"},
mConfig->predit_model().c_str(), mRuntimeManager, &module_config));
@ -1025,9 +1037,13 @@ void Talker::load() {
mBigvgan.reset(Module::load({"generated_mel"},
{"waveform"}, mConfig->bigvgan_model().c_str(), mRuntimeManager, &module_config));
// autoregressive decode module
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModules[0].get()));
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModule.get()));
// prefill module
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModules[0];
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModule;
if (mBigvgan.get() == nullptr || mPreDit.get() == nullptr || mDit.get() == nullptr) {
return false;
}
return true;
}
void Talker::generate_init(std::ostream* os, const char* end_with) {
@ -1128,7 +1144,6 @@ VARP Talker::ditForward(const int codec_size, const int* codec_tokens, const flo
y0 = y0 + dy;
}
}
mContext->vision_us += _t.durationInUs();
auto generated_mel = _Permute(y0, {0, 2, 1});
return generated_mel;
}

View File

@ -59,7 +59,7 @@ public:
Talker(std::shared_ptr<LlmConfig> config) : Llm(config), mThinker(nullptr) {}
Talker(std::shared_ptr<LlmConfig> config, Llm* thinker) : Llm(config), mThinker(thinker) {}
~Talker() {}
virtual void load() override;
virtual bool load() override;
virtual void generate_init(std::ostream* os = nullptr, const char* end_with = nullptr) override;
virtual Express::VARP embedding(const std::vector<int>& input_ids) override;
virtual Express::VARP gen_position_ids(int seq_len) override;
@ -102,7 +102,7 @@ public:
mVisionModule.reset();
mAudioModule.reset();
}
virtual void load() override;
virtual bool load() override;
virtual std::vector<Express::VARP> forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override;
virtual std::vector<int> tokenizer_encode(const std::string& query) override;
virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input) override;