Compare commits

...

3 Commits

Author          SHA1         Message / Date

Yu FranzKafka   b6a99e001c   Merge fad7396a0a into 11634a1d97
                             2025-07-25 15:11:27 +08:00

jxt1234         11634a1d97   Merge pull request #3753 from alibaba/feature/bugfix
                             (LLM:Bugfix: For AR generation, keep current_token until generate(1))
                             2025-07-25 15:04:31 +08:00

xiaying         69d897bcc8   LLM:Bugfix: For AR generation, keep current_token until generate(1), fix convert crash bug after deconstruct
                             2025-07-24 17:02:06 +08:00
4 changed files with 22 additions and 35 deletions
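Taken together, the four diffs below implement a single idea: mContext->current_token is now initialized to -1 as a "nothing sampled yet" sentinel (previously 0), and sampling moves out of Llm::generate into the generation strategies, so the pending token survives until the strategy's generate() consumes it. A minimal sketch of why -1 works as a sentinel where 0 could not (Ctx is an illustrative stand-in, not an MNN type; token id 0 is usually a legitimate vocabulary entry):

    #include <cassert>

    // Illustrative stand-in for the context field touched in generate_init.
    struct Ctx { int current_token; };

    int main() {
        // Before the fix, 0 meant "unset", but 0 is usually also a valid
        // token id, so a genuinely sampled token 0 looked identical to
        // "nothing sampled". -1 can never collide with a real token id.
        Ctx ctx{-1};
        assert(ctx.current_token == -1); // unambiguously "not sampled yet"
        ctx.current_token = 0;           // a real token id, distinguishable now
        assert(ctx.current_token != -1);
        return 0;
    }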

File 1 of 4

@@ -593,8 +593,7 @@ using namespace MNN;
 using namespace MNN::Express;
 std::unique_ptr<MNN::NetT> optimizeNet(std::unique_ptr<MNN::NetT>& originNet, bool forTraining, modelConfig& config, const std::vector<std::string>& expectPasses) {
     BackendConfig bnConfig;
-    auto exe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 1);
-    ExecutorScope _s(exe);
+    auto exe = ExecutorScope::Current();
     Global<modelConfig>::Reset(&config);
     if (!expectPasses.empty()) {
         RunNetPass(expectPasses, originNet);
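This hunk makes optimizeNet borrow the executor already installed by the surrounding ExecutorScope instead of constructing a fresh CPU executor on every call; per the commit message this fixes a converter crash after destruction, presumably because the function-local executor could be torn down while resources created under it were still alive. A rough model of the reuse-the-current-scope pattern, with a simplified Executor and current_executor() standing in for the MNN API:

    #include <cstdio>
    #include <memory>

    // Simplified stand-in for an execution backend.
    struct Executor {
        ~Executor() { std::printf("executor destroyed\n"); }
    };

    // Simplified stand-in for ExecutorScope::Current(): one long-lived
    // executor that outlives every optimization call.
    static std::shared_ptr<Executor> current_executor() {
        static std::shared_ptr<Executor> sExe = std::make_shared<Executor>();
        return sExe;
    }

    static void optimize_net() {
        // After the fix: borrow the long-lived executor rather than create a
        // local one whose destruction order is hard to control.
        auto exe = current_executor();
        // ... run optimization passes against exe ...
    }

    int main() {
        optimize_net();
        optimize_net(); // same executor both times; destroyed once, at exit
        return 0;
    }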

File 2 of 4

@@ -523,7 +523,8 @@ void Llm::generate_init(std::ostream* os, const char* end_with) {
     mContext->gen_seq_len = 0;
     mContext->prefill_us = 0;
     mContext->decode_us = 0;
-    mContext->current_token = 0;
+    mContext->current_token = -1;
+    mContext->sample_us = 0;
     if (!mConfig->reuse_kv()) {
         mContext->all_seq_len = 0;
         mContext->history_tokens.clear();
@@ -589,18 +590,14 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
     int seqLen = input_embeds->getInfo()->dim[mSeqLenIndex];
     mContext->prompt_len = seqLen;
     Timer _t;
-    auto outputs = forwardVec(input_embeds);
-    if(outputs.size() < 1) {
+    forwardVec(input_embeds);
+    if(mGenerateParam->outputs.size() < 1) {
         return {};
     }
-    auto logits = outputs[0];
     updateContext(seqLen, 0);
-    if (nullptr == logits.get()) {
-        return {};
-    }
     // logits compute sync for correct timer
-    logits->readMap<void>();
+    mGenerateParam->outputs[0]->readMap<void>();
     mContext->prefill_us = _t.durationInUs();
 #if DEBUG_MODE == 3
@@ -623,10 +620,6 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
 #endif
-    _t.reset();
-    mContext->current_token = sample(logits);
-    mContext->sample_us += _t.durationInUs();
-    logits = nullptr;
     // call generation function
     mGenerateParam->max_new_tokens = max_tokens;
     mGenerationStrategy->generate(*mGenerateParam);
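After these hunks, Llm::generate no longer samples at all: the forward pass leaves the prefill logits in the shared parameter block (mGenerateParam->outputs in the new code), generate only touches them via readMap so the prefill timer measures a finished computation, and the selected generation strategy decides when to sample. A compact sketch of handing logits to a strategy instead of sampling eagerly; every name below is illustrative:

    #include <cstdio>
    #include <vector>

    // Illustrative parameter block shared between prefill and a strategy.
    struct Params {
        std::vector<float> outputs;  // logits left behind by the forward pass
        int max_new_tokens = 0;
    };

    // Toy argmax sampler, owned by the strategy rather than the caller.
    static int sample(const std::vector<float>& logits) {
        int best = 0;
        for (int i = 1; i < (int)logits.size(); ++i)
            if (logits[i] > logits[best]) best = i;
        return best;
    }

    struct Strategy {
        void generate(Params& p) {
            // The strategy consumes the pending logits itself.
            std::printf("first token: %d\n", sample(p.outputs));
        }
    };

    int main() {
        Params param;
        param.outputs = {0.2f, 0.7f, 0.1f}; // pretend prefill logits
        param.max_new_tokens = 16;
        Strategy{}.generate(param);         // sampling deferred to the strategy
        return 0;
    }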

File 3 of 4

@@ -42,10 +42,18 @@ void ArGeneration::generate(GenerationParams& param) {
     int len = 0;
     while (len < max_token) {
         AUTOTIME;
-        // Update gen seq
+        mContext->current_token = mLlm->sample(param.outputs[0]);
         mContext->history_tokens.push_back(mContext->current_token);
         mContext->output_tokens.push_back(mContext->current_token);
+        // Update gen seq
         mLlm->updateContext(0, 1);
+        if (mLlm->is_stop(mContext->current_token)) {
+            if (nullptr != mContext->os) {
+                *mContext->os << mContext->end_with << std::flush;
+            }
+            break;
+        }
+        // Decode and Output
         MNN::Timer _t;
         auto decodeStr = mLlm->tokenizer_decode(mContext->current_token);
         mContext->generate_str += decodeStr;
@@ -53,32 +61,17 @@
             *mContext->os << decodeStr;
             *mContext->os << std::flush;
         }
         // Compute Next Logits
         mLlm->mMeta->remove = 0;
         auto outputs = mLlm->forwardVec({mContext->current_token});
         if(outputs.empty()) {
             break;
         }
-        auto logits = outputs[0];
-        // Update all seq
+        // Update input seq
         mLlm->updateContext(1, 0);
+        len++;
-        if (nullptr == logits.get()) {
-            break;
-        }
-        if (logits->getInfo()->size == 0) {
-            break;
-        }
-        mContext->current_token = mLlm->sample(logits);
         mContext->decode_us += _t.durationInUs();
-        if (mLlm->is_stop(mContext->current_token)) {
-            mContext->history_tokens.push_back(mContext->current_token);
-            mContext->output_tokens.push_back(mContext->current_token);
-            mLlm->updateContext(0, 1);
-            if (nullptr != mContext->os) {
-                *mContext->os << mContext->end_with << std::flush;
-            }
-            break;
-        }
-        len++;
     }
 }
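The reworked autoregressive loop reorders each iteration: sample from the logits left by the previous step, record and stop-check the token, only then decode and print it, and finish by running the forward pass whose logits feed the next iteration. A self-contained model of that ordering, with a toy forward pass and stop token (none of this is the MNN API):

    #include <cstdio>
    #include <vector>

    static const int kStopToken = 2;

    // Toy argmax sampler.
    static int sample(const std::vector<float>& logits) {
        int best = 0;
        for (int i = 1; i < (int)logits.size(); ++i)
            if (logits[i] > logits[best]) best = i;
        return best;
    }

    // Toy forward pass: always steers toward the stop token.
    static std::vector<float> forward(int /*token*/) {
        std::vector<float> logits(8, 0.0f);
        logits[kStopToken] = 1.0f;
        return logits;
    }

    int main() {
        std::vector<float> logits = {0.0f, 1.0f, 0.0f}; // pretend prefill logits
        std::vector<int> output;
        int len = 0, maxNew = 8;
        while (len < maxNew) {
            int token = sample(logits);            // 1. sample previous logits
            output.push_back(token);               // 2. record the token
            if (token == kStopToken) break;        // 3. stop-check before output
            std::printf("emit token %d\n", token); // 4. decode and emit
            logits = forward(token);               // 5. compute next logits
            ++len;
        }
        return 0;
    }

This mirrors the diff's key property: a freshly sampled stop token is pushed into the history but never decoded or printed, and no extra forward pass runs for it.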

File 4 of 4

@@ -44,7 +44,9 @@ LookaheadGeneration::LookaheadGeneration(Llm* llm, std::shared_ptr<LlmContext> c
 }
 void LookaheadGeneration::generate(GenerationParams& param) {
+    if (-1 == mContext->current_token) {
+        mContext->current_token = mLlm->sample(param.outputs[0]);
+    }
     int max_token = param.max_new_tokens;
     int len = 0;
     ngram_cache<ngram_value> prompt_ngram_cache;
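LookaheadGeneration gets the matching entry guard: when generate() is reached with the -1 sentinel still set, nothing has consumed the prefill logits yet, so it samples once before entering the lookahead loop; if a token is already pending, the guard is a no-op. A small illustration of that idempotent guard, reusing the sentinel convention from the first sketch (stand-in names only):

    #include <cstdio>
    #include <vector>

    // Stand-in context carrying the -1 "not sampled yet" sentinel.
    struct Ctx { int current_token = -1; };

    // Toy argmax sampler.
    static int sample(const std::vector<float>& logits) {
        int best = 0;
        for (int i = 1; i < (int)logits.size(); ++i)
            if (logits[i] > logits[best]) best = i;
        return best;
    }

    // Sample only when no token is pending; safe on any entry path.
    static void ensure_token(Ctx& ctx, const std::vector<float>& logits) {
        if (-1 == ctx.current_token) {
            ctx.current_token = sample(logits);
        }
    }

    int main() {
        Ctx ctx;
        std::vector<float> logits = {0.1f, 0.8f, 0.1f};
        ensure_token(ctx, logits); // samples: sentinel was -1
        ensure_token(ctx, logits); // no-op: a token is already pending
        std::printf("token = %d\n", ctx.current_token);
        return 0;
    }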