mirror of https://github.com/alibaba/MNN.git
Merge pull request #3874 from alibaba/feature/llmrefractor
Feature/llmrefractor
This commit is contained in:
commit
2ddf67d37f
|
@ -1,4 +1,5 @@
|
|||
#include <cmath>
|
||||
#include <limits.h>
|
||||
#include "VulkanGaussianRender.hpp"
|
||||
namespace MNN {
|
||||
struct ImageConstant {
|
||||
|
|
|
@ -71,8 +71,6 @@ std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>&
|
|||
static int benchmark(Llm* llm, const std::vector<std::string>& prompts, int max_token_number) {
|
||||
int prompt_len = 0;
|
||||
int decode_len = 0;
|
||||
int64_t vision_time = 0;
|
||||
int64_t audio_time = 0;
|
||||
int64_t prefill_time = 0;
|
||||
int64_t decode_time = 0;
|
||||
int64_t sample_time = 0;
|
||||
|
@ -117,29 +115,39 @@ static int benchmark(Llm* llm, const std::vector<std::string>& prompts, int max_
|
|||
}
|
||||
prompt_len += context->prompt_len;
|
||||
decode_len += context->gen_seq_len;
|
||||
vision_time += context->vision_us;
|
||||
audio_time += context->audio_us;
|
||||
prefill_time += context->prefill_us;
|
||||
decode_time += context->decode_us;
|
||||
sample_time += context->sample_us;
|
||||
}
|
||||
llm->generateWavform();
|
||||
|
||||
float vision_s = vision_time / 1e6;
|
||||
float audio_s = audio_time / 1e6;
|
||||
float vision_s = context->vision_us / 1e6;
|
||||
float audio_s = context->audio_us / 1e6;
|
||||
float prefill_s = prefill_time / 1e6;
|
||||
float decode_s = decode_time / 1e6;
|
||||
float sample_s = sample_time / 1e6;
|
||||
float vision_speed = 0.0f;
|
||||
if (context->pixels_mp > 0.0f) {
|
||||
vision_speed = context->pixels_mp / vision_s;
|
||||
}
|
||||
float audio_speed = 0.0f;
|
||||
if (context->audio_input_s > 0.0f) {
|
||||
audio_speed = context->audio_input_s / audio_s;
|
||||
}
|
||||
printf("\n#################################\n");
|
||||
printf("prompt tokens num = %d\n", prompt_len);
|
||||
printf("decode tokens num = %d\n", decode_len);
|
||||
printf(" vision time = %.2f s\n", vision_s);
|
||||
printf(" audio time = %.2f s\n", audio_s);
|
||||
printf(" pixels_mp = %.2f MP\n", context->pixels_mp);
|
||||
printf(" audio process time = %.2f s\n", audio_s);
|
||||
printf(" audio input time = %.2f s\n", context->audio_input_s);
|
||||
printf("prefill time = %.2f s\n", prefill_s);
|
||||
printf(" decode time = %.2f s\n", decode_s);
|
||||
printf(" sample time = %.2f s\n", sample_s);
|
||||
printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
|
||||
printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
|
||||
printf(" vision speed = %.3f MP/s\n", vision_speed);
|
||||
printf(" audio RTF = %.3f \n", audio_s / context->audio_input_s);
|
||||
printf("##################################\n");
|
||||
return 0;
|
||||
}
|
||||
|
@ -256,7 +264,11 @@ int main(int argc, const char* argv[]) {
|
|||
llm->set_config("{\"tmp_path\":\"tmp\"}");
|
||||
{
|
||||
AUTOTIME;
|
||||
llm->load();
|
||||
bool res = llm->load();
|
||||
if (!res) {
|
||||
MNN_ERROR("LLM init error\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
if (true) {
|
||||
AUTOTIME;
|
||||
|
|
|
@ -76,8 +76,8 @@ struct LlmContext {
|
|||
int64_t prefill_us = 0;
|
||||
int64_t decode_us = 0;
|
||||
int64_t sample_us = 0;
|
||||
float prefill_mb = 0;
|
||||
float decode_mb = 0;
|
||||
float pixels_mp = 0;
|
||||
float audio_input_s = 0;
|
||||
// tokens
|
||||
int current_token;
|
||||
std::vector<int> history_tokens;
|
||||
|
@ -95,7 +95,7 @@ public:
|
|||
static void destroy(Llm* llm);// For Windows RT mode should use destroy
|
||||
Llm(std::shared_ptr<LlmConfig> config);
|
||||
virtual ~Llm();
|
||||
virtual void load();
|
||||
virtual bool load();
|
||||
virtual Express::VARP gen_attention_mask(int seq_len);
|
||||
virtual Express::VARP gen_position_ids(int seq_len);
|
||||
virtual Express::VARP embedding(const std::vector<int>& input_ids);
|
||||
|
@ -152,7 +152,7 @@ protected:
|
|||
std::shared_ptr<DiskEmbedding> mDiskEmbedding;
|
||||
std::shared_ptr<Sampler> mSampler;
|
||||
std::shared_ptr<Express::Executor::RuntimeManager> mRuntimeManager, mProcessorRuntimeManager;
|
||||
std::vector<std::shared_ptr<Express::Module>> mModules;
|
||||
std::shared_ptr<Express::Module> mModule;
|
||||
/**
|
||||
key: <seq_len, all_logits>
|
||||
value : module
|
||||
|
@ -190,7 +190,7 @@ public:
|
|||
static Embedding* createEmbedding(const std::string& config_path, bool load = true);
|
||||
static float dist(Express::VARP var0, Express::VARP var1);
|
||||
static float cos_sim(Express::VARP var0, Express::VARP var1);
|
||||
virtual void load() override;
|
||||
virtual bool load() override;
|
||||
Express::VARP ids_embedding(const std::vector<int>& ids);
|
||||
Express::VARP txt_embedding(const std::string& txt);
|
||||
int dim() const;
|
||||
|
|
|
@ -45,7 +45,7 @@ int Embedding::dim() const {
|
|||
return mConfig->hidden_size();
|
||||
}
|
||||
|
||||
void Embedding::load() {
|
||||
bool Embedding::load() {
|
||||
initRuntime();
|
||||
printf("load tokenizer\n");
|
||||
std::cout << mConfig->tokenizer_file() << std::endl;
|
||||
|
@ -59,10 +59,13 @@ void Embedding::load() {
|
|||
module_config.rearrange = true;
|
||||
auto model_path = mConfig->llm_model();
|
||||
MNN_PRINT("load %s ... ", model_path.c_str());
|
||||
mModules.resize(1);
|
||||
mModules[0].reset(Module::load({"input_ids", "attention_mask", "position_ids"}, {"sentence_embeddings"},
|
||||
mModule.reset(Module::load({"input_ids", "attention_mask", "position_ids"}, {"sentence_embeddings"},
|
||||
model_path.c_str(), mRuntimeManager, &module_config));
|
||||
if (nullptr == mModule.get()) {
|
||||
return false;
|
||||
}
|
||||
MNN_PRINT("Done!\n");
|
||||
return true;
|
||||
}
|
||||
|
||||
VARP Embedding::ids_embedding(const std::vector<int>& ids) {
|
||||
|
@ -70,7 +73,7 @@ VARP Embedding::ids_embedding(const std::vector<int>& ids) {
|
|||
auto inputs_ids = embedding(ids);
|
||||
auto attention_mask = gen_attention_mask(prompt_len);
|
||||
auto position_ids = gen_position_ids(prompt_len);
|
||||
auto outputs = mModules[0]->onForward({inputs_ids, attention_mask, position_ids});
|
||||
auto outputs = mModule->onForward({inputs_ids, attention_mask, position_ids});
|
||||
auto sentence_embeddings = outputs[0];
|
||||
return sentence_embeddings;
|
||||
}
|
||||
|
|
|
@ -234,7 +234,7 @@ static bool canSpecDecode(std::shared_ptr<Express::Module> module) {
|
|||
void Llm::setSpeculativeConfig() {
|
||||
auto specultive_type = mConfig->speculative_type();
|
||||
if(!specultive_type.empty()) {
|
||||
if(!canSpecDecode(mModules[0])) {
|
||||
if(!canSpecDecode(mModule)) {
|
||||
mInSpec = false;
|
||||
return;
|
||||
}
|
||||
|
@ -243,7 +243,7 @@ void Llm::setSpeculativeConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
void Llm::load() {
|
||||
bool Llm::load() {
|
||||
initRuntime();
|
||||
// init module status
|
||||
// 1. load vocab
|
||||
|
@ -264,7 +264,6 @@ void Llm::load() {
|
|||
module_config.base = mBaseModule;
|
||||
}
|
||||
// load single model
|
||||
mModules.resize(1);
|
||||
std::string model_path = mConfig->llm_model();
|
||||
|
||||
std::vector<std::string> inputNames {"input_ids", "attention_mask", "position_ids", "logits_index"};
|
||||
|
@ -284,14 +283,14 @@ void Llm::load() {
|
|||
}
|
||||
|
||||
mRuntimeManager->setExternalFile(mConfig->llm_weight());
|
||||
mModules[0].reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config));
|
||||
mModule.reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config));
|
||||
mRuntimeManager->setExternalFile("");
|
||||
if(nullptr == mModules[0]) {
|
||||
if(nullptr == mModule) {
|
||||
MNN_ERROR("[Error]: Load module failed, please check model.\n");
|
||||
if(outputNames.size() > 1) {
|
||||
MNN_ERROR("[Warning]: Set module multi outputs, please double check.\n");
|
||||
}
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
// set speculative decoding params
|
||||
setSpeculativeConfig();
|
||||
|
@ -305,13 +304,13 @@ void Llm::load() {
|
|||
decode_type_num = 2;
|
||||
verify_length = mDraftLength + 1;
|
||||
// speculative decode module
|
||||
mModulePool[std::make_pair(verify_length, true)].reset(Module::clone(mModules[0].get()));
|
||||
mModulePool[std::make_pair(verify_length, true)].reset(Module::clone(mModule.get()));
|
||||
}
|
||||
|
||||
// autoregressive decode module
|
||||
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModules[0].get()));
|
||||
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModule.get()));
|
||||
// prefill module
|
||||
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModules[0];
|
||||
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModule;
|
||||
|
||||
// module input varp setting
|
||||
logitsLastIdx = _var<int>({-1}, {1});
|
||||
|
@ -340,12 +339,13 @@ void Llm::load() {
|
|||
|
||||
// MTP model load
|
||||
mGenerationStrategy->load(module_config);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Creates a second Llm instance for a LoRA model.
// The new instance gets a copy of this instance's config, has "llm_model"
// overridden with lora_path (with use_mmap / use_cached_mmap set to false),
// is given this instance's module as its base module, and is then loaded.
// Returns a raw pointer obtained from `new`.
// NOTE(review): lora_path is concatenated into the config string without
// escaping; a path containing '"' or '\' would presumably yield malformed
// JSON — confirm how set_config parses its argument.
Llm* Llm::create_lora(const std::string& lora_path) {
|
||||
auto llm = new Llm(std::make_shared<LlmConfig>(*mConfig));
|
||||
llm->set_config("{\"llm_model\": \"" + lora_path + "\", \"use_mmap\": false, \"use_cached_mmap\": false}");
|
||||
// The next two assignments appear to be the pre-change and post-change
// versions of the same line in this diff view (mModules -> mModule).
llm->mBaseModule = mModules.begin()->get();
|
||||
llm->mBaseModule = mModule.get();
|
||||
// NOTE(review): Llm::load() returns bool in this change, but the result is
// discarded here, so a failed load still hands back a non-null Llm*.
llm->load();
|
||||
return llm;
|
||||
}
|
||||
|
@ -426,7 +426,7 @@ std::vector<Express::VARP> Llm::forwardRaw(Express::VARP hiddenState, Express::V
|
|||
if(mModulePool.find(moduleKey) == mModulePool.end()) {
|
||||
MNN_PRINT("Warning: module need new clone, cloning now.\n");
|
||||
mRuntimeManager->setHintPtr(Interpreter::KVCACHE_INFO, mMeta.get());
|
||||
mModulePool[moduleKey].reset(Module::clone(mModules[0].get()));
|
||||
mModulePool[moduleKey].reset(Module::clone(mModule.get()));
|
||||
}
|
||||
|
||||
if (isAllLogists) {
|
||||
|
@ -554,6 +554,10 @@ void Llm::reset() {
|
|||
mContext->history_tokens.clear();
|
||||
mContext->all_seq_len = 0;
|
||||
mContext->gen_seq_len = 0;
|
||||
mContext->vision_us = 0;
|
||||
mContext->pixels_mp = 0.0f;
|
||||
mContext->audio_us = 0;
|
||||
mContext->audio_input_s = 0.0f;
|
||||
mMeta->remove = mMeta->previous;
|
||||
}
|
||||
|
||||
|
@ -756,7 +760,7 @@ Llm::~Llm() {
|
|||
}
|
||||
#endif
|
||||
mGenerateParam.reset();
|
||||
mModules.clear();
|
||||
mModule.reset();
|
||||
mRuntimeManager.reset();
|
||||
mProcessorRuntimeManager.reset();
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
// Created by MNN on 2025/04/08.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
//#define MNN_OPEN_TIME_TRACE
|
||||
|
||||
#ifdef _WIN32
|
||||
#define _USE_MATH_DEFINES
|
||||
|
@ -25,7 +26,6 @@
|
|||
#ifdef LLM_SUPPORT_AUDIO
|
||||
#include <audio/audio.hpp>
|
||||
#endif
|
||||
|
||||
namespace MNN {
|
||||
using namespace Express;
|
||||
namespace Transformer {
|
||||
|
@ -69,11 +69,17 @@ Omni::Omni(std::shared_ptr<LlmConfig> config) : Llm(config) {
|
|||
if (config->is_audio()) {}
|
||||
}
|
||||
|
||||
void Omni::load() {
|
||||
Llm::load();
|
||||
bool Omni::load() {
|
||||
auto res = Llm::load();
|
||||
if (!res) {
|
||||
return false;
|
||||
}
|
||||
if (mConfig->has_talker()) {
|
||||
mTalker.reset(new Talker(mConfig, this));
|
||||
mTalker->load();
|
||||
res = mTalker->load();
|
||||
}
|
||||
if (!res) {
|
||||
return false;
|
||||
}
|
||||
ScheduleConfig config;
|
||||
if (mConfig->mllm_config_.empty()) {
|
||||
|
@ -118,6 +124,7 @@ void Omni::load() {
|
|||
if (mConfig->is_audio()) {
|
||||
mAudioModule.reset(Module::load({}, {}, mConfig->audio_model().c_str(), mProcessorRuntimeManager, &module_config));
|
||||
}
|
||||
return mAudioModule.get() != nullptr && mVisionModule.get() != nullptr;
|
||||
}
|
||||
|
||||
#ifdef LLM_SUPPORT_VISION
|
||||
|
@ -142,6 +149,7 @@ std::vector<int> Omni::defaultVisionProcess(VARP image) {
|
|||
}
|
||||
|
||||
std::vector<int> Omni::qwen2VisionProcess(VARP image) {
|
||||
AUTOTIME;
|
||||
const auto inputNames = mVisionModule->getInfo()->inputNames;
|
||||
bool hasWindowIndex = inputNames.size() == 4 && inputNames[3] == "window_index";
|
||||
// Qwen2-VL / Qwen2.5-VL
|
||||
|
@ -583,6 +591,7 @@ std::vector<int> Omni::visionProcess(VARP image) {
|
|||
imgIds = defaultVisionProcess(image);
|
||||
}
|
||||
mContext->vision_us += _t.durationInUs();
|
||||
mContext->pixels_mp += (mVisionWidth / 1000.0f) * (mVisionHeight / 1000.0f);
|
||||
// set vision number for image idx
|
||||
mVisionNum += 1;
|
||||
return imgIds;
|
||||
|
@ -600,6 +609,7 @@ std::vector<int> Omni::audioProcess(const std::string& file) {
|
|||
MNN_PRINT("Omni Can't open audio: %s\n", file.c_str());
|
||||
return std::vector<int>(0);
|
||||
}
|
||||
mContext->audio_input_s += (float)(waveform->getInfo()->size) / sample_rate;
|
||||
return audioProcess(waveform);
|
||||
#else
|
||||
return std::vector<int>(0);
|
||||
|
@ -899,7 +909,7 @@ static inline bool needNewVar(VARP var, int axis, int seq_len) {
|
|||
}
|
||||
|
||||
VARP Omni::gen_position_ids(int seq_len) {
|
||||
auto positionIdsDims = mModules[0]->getInfo()->inputs[2].dim;
|
||||
auto positionIdsDims = mModule->getInfo()->inputs[2].dim;
|
||||
if (positionIdsDims[0] == 1) {
|
||||
return Llm::gen_position_ids(seq_len);
|
||||
}
|
||||
|
@ -989,7 +999,7 @@ void Omni::generateWavform() {
|
|||
}
|
||||
}
|
||||
|
||||
void Talker::load() {
|
||||
bool Talker::load() {
|
||||
initRuntime();
|
||||
mSeqLenIndex = 1;
|
||||
set_config("{\"sampler_type\": \"mixed\", \"temperature\": 0.9, \"topK\": 40, \"topP\": 0.8, \"penalty\": 1.05}");
|
||||
|
@ -1011,11 +1021,13 @@ void Talker::load() {
|
|||
Module::Config module_config;
|
||||
module_config.shapeMutable = false;
|
||||
module_config.rearrange = true;
|
||||
mModules.resize(1);
|
||||
std::vector<std::string> inputNames {"inputs_embeds", "attention_mask", "position_ids", "logits_index"};
|
||||
|
||||
mModules[0].reset(Module::load(inputNames,
|
||||
mModule.reset(Module::load(inputNames,
|
||||
{"logits"}, mConfig->talker_model().c_str(), mRuntimeManager, &module_config));
|
||||
if (mModule.get() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
// dit
|
||||
mPreDit.reset(Module::load({"cond", "spk", "code"}, {"code_embeds", "rope", "mask"},
|
||||
mConfig->predit_model().c_str(), mRuntimeManager, &module_config));
|
||||
|
@ -1025,9 +1037,13 @@ void Talker::load() {
|
|||
mBigvgan.reset(Module::load({"generated_mel"},
|
||||
{"waveform"}, mConfig->bigvgan_model().c_str(), mRuntimeManager, &module_config));
|
||||
// autoregressive decode module
|
||||
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModules[0].get()));
|
||||
mModulePool[std::make_pair(1, false)].reset(Module::clone(mModule.get()));
|
||||
// prefill module
|
||||
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModules[0];
|
||||
mModulePool[std::make_pair(mPrefillKey, mConfig->all_logits())] = mModule;
|
||||
if (mBigvgan.get() == nullptr || mPreDit.get() == nullptr || mDit.get() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Talker::generate_init(std::ostream* os, const char* end_with) {
|
||||
|
@ -1128,7 +1144,6 @@ VARP Talker::ditForward(const int codec_size, const int* codec_tokens, const flo
|
|||
y0 = y0 + dy;
|
||||
}
|
||||
}
|
||||
mContext->vision_us += _t.durationInUs();
|
||||
auto generated_mel = _Permute(y0, {0, 2, 1});
|
||||
return generated_mel;
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ public:
|
|||
// Constructs a Talker from config with no thinker attached (mThinker is nullptr).
Talker(std::shared_ptr<LlmConfig> config) : Llm(config), mThinker(nullptr) {}
|
||||
// Constructs a Talker from config and stores the given raw thinker pointer in mThinker.
Talker(std::shared_ptr<LlmConfig> config, Llm* thinker) : Llm(config), mThinker(thinker) {}
|
||||
// Destructor with an empty body; no explicit cleanup is performed here.
~Talker() {}
|
||||
virtual void load() override;
|
||||
virtual bool load() override;
|
||||
virtual void generate_init(std::ostream* os = nullptr, const char* end_with = nullptr) override;
|
||||
virtual Express::VARP embedding(const std::vector<int>& input_ids) override;
|
||||
virtual Express::VARP gen_position_ids(int seq_len) override;
|
||||
|
@ -102,7 +102,7 @@ public:
|
|||
mVisionModule.reset();
|
||||
mAudioModule.reset();
|
||||
}
|
||||
virtual void load() override;
|
||||
virtual bool load() override;
|
||||
virtual std::vector<Express::VARP> forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override;
|
||||
virtual std::vector<int> tokenizer_encode(const std::string& query) override;
|
||||
virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input) override;
|
||||
|
|
Loading…
Reference in New Issue