mirror of https://github.com/alibaba/MNN.git
Compare commits
3 Commits
d249dfc726...b6a99e001c
| Author | SHA1 | Date |
|---|---|---|
| | b6a99e001c | |
| | 11634a1d97 | |
| | 69d897bcc8 | |
```diff
@@ -593,8 +593,7 @@ using namespace MNN;
 using namespace MNN::Express;
 std::unique_ptr<MNN::NetT> optimizeNet(std::unique_ptr<MNN::NetT>& originNet, bool forTraining, modelConfig& config, const std::vector<std::string>& expectPasses) {
-    BackendConfig bnConfig;
-    auto exe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 1);
-    ExecutorScope _s(exe);
+    auto exe = ExecutorScope::Current();
     Global<modelConfig>::Reset(&config);
     if (!expectPasses.empty()) {
         RunNetPass(expectPasses, originNet);
```
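This converter change drops the locally constructed executor in favor of the one already installed on the calling thread, so the converter stops overriding the caller's backend configuration. A minimal sketch of the before/after pattern, assuming MNN's Express headers (the calls themselves are the ones visible in the hunk):

```cpp
#include <MNN/MNNForwardType.h>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>

using namespace MNN;
using namespace MNN::Express;

void optimizeOld() {
    // Old: build a dedicated single-thread CPU executor and push it as a scope.
    BackendConfig bnConfig;
    auto exe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 1);
    ExecutorScope _s(exe);
    // ... run optimization passes under _s ...
}

void optimizeNew() {
    // New: reuse whatever executor the surrounding code has already scoped,
    // so converter passes inherit the caller's backend configuration.
    auto exe = ExecutorScope::Current();
    // ... run optimization passes on exe ...
}
```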
```diff
@@ -523,7 +523,8 @@ void Llm::generate_init(std::ostream* os, const char* end_with) {
     mContext->gen_seq_len = 0;
     mContext->prefill_us = 0;
     mContext->decode_us = 0;
-    mContext->current_token = 0;
+    mContext->current_token = -1;
+    mContext->sample_us = 0;
     if (!mConfig->reuse_kv()) {
         mContext->all_seq_len = 0;
         mContext->history_tokens.clear();
```
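Resetting current_token to -1 instead of 0 gives the generation strategies an unambiguous "nothing sampled yet" marker, since token id 0 is a legitimate vocabulary entry; the LookaheadGeneration hunk below keys off exactly this sentinel. A stand-alone illustration of the distinction (not MNN code):

```cpp
#include <cassert>

int main() {
    int current_token = -1;     // sentinel set by generate_init: nothing sampled yet
    assert(current_token < 0);  // strategies can test "unset" without ambiguity
    current_token = 0;          // token id 0 is a real sample, not "unset"
    assert(current_token >= 0); // now clearly distinguishable from the sentinel
    return 0;
}
```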
```diff
@@ -589,18 +590,14 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
     int seqLen = input_embeds->getInfo()->dim[mSeqLenIndex];
     mContext->prompt_len = seqLen;
     Timer _t;
-    auto outputs = forwardVec(input_embeds);
-    if(outputs.size() < 1) {
+    forwardVec(input_embeds);
+    if(mGenerateParam->outputs.size() < 1) {
         return {};
     }
-    auto logits = outputs[0];
     updateContext(seqLen, 0);

-    if (nullptr == logits.get()) {
-        return {};
-    }
     // logits compute sync for correct timer
-    logits->readMap<void>();
+    mGenerateParam->outputs[0]->readMap<void>();
     mContext->prefill_us = _t.durationInUs();

 #if DEBUG_MODE == 3
```
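With this refactor, forwardVec no longer hands its result back to the caller; prefill and the decode strategies both read logits from the shared mGenerateParam->outputs slot. The readMap<void>() call is what keeps the timer honest: Express VARPs evaluate lazily, so reading the buffer forces the computation to finish inside the timed region. A minimal sketch of that synchronization pattern, assuming the MNN Express and AutoTime headers:

```cpp
#include <cstdint>
#include <MNN/AutoTime.hpp>
#include <MNN/expr/Expr.hpp>

// Force the lazily evaluated logits to materialize inside the timed region,
// so the measured time covers real inference work, not just graph building.
uint64_t timedSync(MNN::Express::VARP logits) {
    MNN::Timer _t;
    logits->readMap<void>();   // blocks until the tensor is computed
    return _t.durationInUs();  // comparable to mContext->prefill_us above
}
```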
```diff
@@ -623,10 +620,6 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
 #endif

-    _t.reset();
-    mContext->current_token = sample(logits);
-    mContext->sample_us += _t.durationInUs();
-    logits = nullptr;
+    // call generation function
+    mGenerateParam->max_new_tokens = max_tokens;
+    mGenerationStrategy->generate(*mGenerateParam);
```
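The sampling step that used to live here has moved behind mGenerationStrategy, so Llm::generate now only fills in the shared parameter object and dispatches. A stand-alone sketch of that dispatch shape; GenerationParams and the strategy interface below are simplified stand-ins mirroring the names in the diff, not MNN's real definitions:

```cpp
#include <memory>
#include <vector>

// Simplified stand-ins for the names used in the diff.
struct GenerationParams {
    int max_new_tokens = 0;
    std::vector<float> outputs;   // stand-in for the shared logits slot
};

struct GenerationStrategy {
    virtual ~GenerationStrategy() = default;
    virtual void generate(GenerationParams& param) = 0;
};

struct ArStrategy : GenerationStrategy {
    void generate(GenerationParams& param) override {
        (void)param;              // autoregressive decode loop would live here
    }
};

// Caller side, mirroring the tail of Llm::generate after the refactor:
void runGenerate(GenerationStrategy& strategy, GenerationParams& param, int maxTokens) {
    param.max_new_tokens = maxTokens;  // fill the shared parameters once
    strategy.generate(param);          // dispatch to AR, lookahead, etc.
}
```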
```diff
@@ -42,10 +42,18 @@ void ArGeneration::generate(GenerationParams& param) {
     int len = 0;
     while (len < max_token) {
+        AUTOTIME;
+        // Update gen seq
+        mContext->current_token = mLlm->sample(param.outputs[0]);
         mContext->history_tokens.push_back(mContext->current_token);
         mContext->output_tokens.push_back(mContext->current_token);
-        // Update gen seq
         mLlm->updateContext(0, 1);
         if (mLlm->is_stop(mContext->current_token)) {
             if (nullptr != mContext->os) {
                 *mContext->os << mContext->end_with << std::flush;
             }
             break;
         }
         // Decode and Output
         MNN::Timer _t;
         auto decodeStr = mLlm->tokenizer_decode(mContext->current_token);
         mContext->generate_str += decodeStr;
```
```diff
@@ -53,32 +61,17 @@ void ArGeneration::generate(GenerationParams& param) {
             *mContext->os << decodeStr;
             *mContext->os << std::flush;
         }

         // Compute Next Logits
         mLlm->mMeta->remove = 0;
         auto outputs = mLlm->forwardVec({mContext->current_token});
         if(outputs.empty()) {
             break;
         }
-        auto logits = outputs[0];
-        // Update all seq
+        // Update input seq
         mLlm->updateContext(1, 0);
+        len++;
-        if (nullptr == logits.get()) {
-            break;
-        }
-        if (logits->getInfo()->size == 0) {
-            break;
-        }
-        mContext->current_token = mLlm->sample(logits);
         mContext->decode_us += _t.durationInUs();
-        if (mLlm->is_stop(mContext->current_token)) {
-            mContext->history_tokens.push_back(mContext->current_token);
-            mContext->output_tokens.push_back(mContext->current_token);
-            mLlm->updateContext(0, 1);
-            if (nullptr != mContext->os) {
-                *mContext->os << mContext->end_with << std::flush;
-            }
-            break;
-        }
-        len++;
     }
 }
```
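Taken together, the two ArGeneration hunks move sampling, token bookkeeping, and the stop check to the top of the loop, which lets the duplicated stop-handling block after the bottom sample be deleted. A schematic of the resulting control flow; sample, isStop, and forward are hypothetical stand-ins, not MNN's API:

```cpp
#include <cstdio>

// Hypothetical stand-ins for the schematic; not MNN's real signatures.
static int  gStep = 0;
static int  sample()       { return gStep < 4 ? gStep++ : 99; } // 99 = fake stop id
static bool isStop(int t)  { return t == 99; }
static void forward(int t) { (void)t; /* compute next-token logits */ }

int main() {
    const int maxTokens = 16;
    for (int len = 0; len < maxTokens; ++len) {
        int current = sample();        // sample once, at the top of the loop
        if (isStop(current)) {         // single stop path replaces the old
            std::printf("<eos>\n");    // duplicated bottom-of-loop check
            break;
        }
        std::printf("token %d\n", current);
        forward(current);              // logits for the next iteration
    }
    return 0;
}
```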
```diff
@@ -44,7 +44,9 @@ LookaheadGeneration::LookaheadGeneration(Llm* llm, std::shared_ptr<LlmContext> c
 }

 void LookaheadGeneration::generate(GenerationParams& param) {
+    if (-1 == mContext->current_token) {
+        mContext->current_token = mLlm->sample(param.outputs[0]);
+    }
     int max_token = param.max_new_tokens;
     int len = 0;
     ngram_cache<ngram_value> prompt_ngram_cache;
```
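The new guard pairs with the -1 reset in generate_init above: lookahead decoding samples its seed token from the prefill logits only when nothing has been sampled yet, and leaves an already-chosen token alone. A stand-alone illustration with hypothetical stand-ins for the members used in the hunk:

```cpp
#include <cassert>

// Hypothetical stand-ins mirroring the guard's intent; not MNN's types.
struct Ctx { int current_token = -1; };   // as reset by generate_init
static int sampleFromPrefill() { return 42; }

int main() {
    Ctx ctx;
    if (-1 == ctx.current_token) {        // first entry: guard fires
        ctx.current_token = sampleFromPrefill();
    }
    assert(ctx.current_token == 42);
    int before = ctx.current_token;
    if (-1 == ctx.current_token) {        // later entries: guard is skipped
        ctx.current_token = sampleFromPrefill();
    }
    assert(ctx.current_token == before);  // existing token left untouched
    return 0;
}
```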