[LLM:Bugfix] Bugfix of Omni audio.

This commit is contained in:
zhaode.wzd 2025-05-08 18:04:18 +08:00
parent c8ca39f02b
commit 0eb1db0c34
4 changed files with 25 additions and 12 deletions

View File

@ -487,7 +487,9 @@ python llmexport.py --path /path/to/Qwen2.5-0.5B-Instruct --lora_path /path/to/l
```
#### 获取语音输出
使用Omni模型时可以使用接口`setWavformCallback`获取语音输出,示例如下:
使用Omni模型时可以使用接口`setWavformCallback`获取语音输出,使用接口`generateWavform`开始输出语音。
注意`setWavformCallback`需要在文本生成前调用, `generateWavform`在文本生成结束后调用,示例如下:
1. 保存语音到文件中
```cpp
#include <audio/audio.hpp>
@ -504,6 +506,9 @@ int main() {
}
return true;
});
llm->response("Hello");
// generate wavform
llm->generateWavform();
return 0;
}
@ -600,12 +605,15 @@ bool AudioPlayer::play(const float* ptr, size_t size, bool last_chunk) {
}
int main() {
//....
AudioPlayer audio_player;
//....
AudioPlayer audio_player;
llm->setWavformCallback([&](const float* ptr, size_t size, bool last_chunk) {
return audio_player.play(ptr, size, last_chunk);
});
//....
return 0;
//....
llm->response("Hello");
// generate wavform
llm->generateWavform();
return 0;
}
```

View File

@ -112,6 +112,7 @@ public:
return mContext.get();
}
virtual void setWavformCallback(std::function<bool(const float*, size_t, bool)> callback) {}
virtual void generateWavform() {}
protected:
void initRuntime();
std::shared_ptr<LlmContext> mContext;

View File

@ -299,7 +299,7 @@ std::vector<int> Omni::audioProcess(const std::string& file) {
auto input_features = MNN::AUDIO::whisper_fbank(waveform);
VARP audio_embedding;
if (mAudioModule->getInfo()->inputNames.size() > 1) {
int seqlen = input_features->getInfo()->dim[2] / 2;
int seqlen = UP_DIV(input_features->getInfo()->dim[2], 2);
constexpr int n_window = 100;
std::vector<int> cu_seqlens;
int curseq = 0;
@ -569,6 +569,15 @@ void Omni::response(const std::vector<int>& input_ids, std::ostream* os, const c
mTalker->generate_init();
}
generate(input_ids, max_new_tokens);
}
void Omni::setWavformCallback(std::function<bool(const float*, size_t, bool)> callback) {
if (mTalker) {
mTalker->setWavformCallback(callback);
}
}
void Omni::generateWavform() {
if (mTalker) {
mTalker->generate();
#ifdef DUMP_TALKER_PERFORMANCE
@ -599,12 +608,6 @@ void Omni::response(const std::vector<int>& input_ids, std::ostream* os, const c
}
}
void Omni::setWavformCallback(std::function<bool(const float*, size_t, bool)> callback) {
if (mTalker) {
mTalker->setWavformCallback(callback);
}
}
void Talker::load() {
MNN::BackendConfig backendConfig;
auto executor = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, backendConfig, 1);

View File

@ -110,6 +110,7 @@ public:
virtual Express::VARP gen_position_ids(int seq_len) override;
virtual void response(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1) override;
virtual void setWavformCallback(std::function<bool(const float*, size_t, bool)> callback) override;
virtual void generateWavform() override;
private:
int mVisionHeight = 448, mVisionWidth = 448, mVisionStart = 151857,
mVisionEnd = 151858, mVisionPad = 151859, mAudioPad = 151646;