Compare commits: main...dev/wucong
9 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 0a44bca692 | |
| | d94908f9d7 | |
| | 11dbd88947 | |
| | 9a4aebb0ea | |
| | 54e9384fb1 | |
| | f280558bcb | |
| | f6a18ee07a | |
| | 4df0683a37 | |
| | c37c00ff94 | |
```diff
@@ -49,6 +49,7 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        self.vllm_codec_engine = None
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
```
```diff
@@ -126,7 +127,7 @@ class CosyVoice:

 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
```
```diff
@@ -149,6 +150,16 @@ class CosyVoice2(CosyVoice):
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        self.vllm_codec_engine = None
+        if use_vllm:
+            from vllm import EngineArgs, LLMEngine
+            self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
+            engine_args = EngineArgs(model=''.join([model_dir, '/codec_vllm_model']),
+                                     skip_tokenizer_init=True,
+                                     gpu_memory_utilization=0.2)
+            self.vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
+            self.model.vllm_codec_engine = self.vllm_codec_engine
+
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
```
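With the hunks above, vLLM decoding is enabled by a single constructor flag. A minimal usage sketch, assuming the class is imported from `cosyvoice.cli.cosyvoice` as in the upstream repo and that a local checkpoint directory exists at the illustrative path below (neither detail is shown in this diff):

```python
# Minimal sketch, not taken from this PR: construct CosyVoice2 with the new use_vllm flag.
# The checkpoint path is illustrative. use_vllm=True runs export_codec_vllm() on first use,
# then builds an LLMEngine with gpu_memory_utilization=0.2 as added in the hunk above.
from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, fp16=False,
                       use_vllm=True)
```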
```diff
@@ -17,6 +17,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch import nn
 from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
```
```diff
@@ -65,6 +66,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.vllm_codec_engine = None

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
```
```diff
@@ -102,13 +104,23 @@ class CosyVoiceModel:
         with self.llm_context:
             if isinstance(text, Generator):
                 assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
-                for i in self.llm.inference_bistream(text=text,
-                                                     prompt_text=prompt_text.to(self.device),
-                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    self.tts_speech_token_dict[uuid].append(i)
+                if self.vllm_codec_engine is None:
+                    for i in self.llm.inference_bistream(text=text,
+                                                         prompt_text=prompt_text.to(self.device),
+                                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                         embedding=llm_embedding.to(self.device)):
+                        self.tts_speech_token_dict[uuid].append(i)
+                else:
+                    for i in self.llm.inference_bistream_vllm(text=text,
+                                                              prompt_text=prompt_text.to(self.device),
+                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                              embedding=llm_embedding.to(self.device),
+                                                              vllm_codec_engine=self.vllm_codec_engine):
+                        self.tts_speech_token_dict[uuid].append(i)
             else:
                 for i in self.llm.inference(text=text.to(self.device),
                                             text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
```
```diff
@@ -116,7 +128,8 @@ class CosyVoiceModel:
                                             prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device)):
+                                            embedding=llm_embedding.to(self.device),
+                                            vllm_codec_engine=self.vllm_codec_engine):
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

```
```diff
@@ -313,10 +326,50 @@ class CosyVoice2Model(CosyVoiceModel):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        self.vllm_codec_engine = None

     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

+    def export_codec_vllm(self, model_path):
+        if os.path.exists(model_path):
+            return
+        pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
+        vocab_size = self.llm.speech_embedding.num_embeddings
+        feature_size = self.llm.speech_embedding.embedding_dim
+        pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
+
+        dtype = torch.bfloat16
+        # lm_head
+        new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
+        with torch.no_grad():
+            new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
+            new_lm_head.bias[:vocab_size] = self.llm.llm_decoder.bias
+            new_lm_head.weight[vocab_size:] = 0
+            new_lm_head.bias[vocab_size:] = 0
+        self.llm.llm.model.lm_head = new_lm_head
+        new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        # embed_tokens
+        embed_tokens = self.llm.llm.model.model.embed_tokens
+        with torch.no_grad():
+            new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
+            new_codec_embed.weight[vocab_size:] = 0
+        self.llm.llm.model.set_input_embeddings(new_codec_embed)
+        self.llm.llm.model.to(self.device)
+        self.llm.llm.model.to(dtype)
+        tmp_vocab_size = self.llm.llm.model.config.vocab_size
+        tmp_tie_embedding = self.llm.llm.model.config.tie_word_embeddings
+        del self.llm.llm.model.generation_config.eos_token_id
+        del self.llm.llm.model.config.bos_token_id
+        del self.llm.llm.model.config.eos_token_id
+        self.llm.llm.model.config.vocab_size = pad_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = False
+        self.llm.llm.model.config.use_bias = True
+        self.llm.llm.model.save_pretrained(model_path)
+        self.llm.llm.model.config.vocab_size = tmp_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+        self.llm.llm.model.set_input_embeddings(embed_tokens)
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
```
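`export_codec_vllm` pads the codec vocabulary up to a multiple of 64 (vLLM's default vocab padding) before saving the HF-format checkpoint, zero-filling the extra rows of both the lm_head and the input embedding. A minimal sketch of that rounding, with an illustrative vocabulary size (the real value comes from `self.llm.speech_embedding.num_embeddings` and is not shown in this diff):

```python
# Minimal sketch of the vocab padding used by export_codec_vllm above.
pad_to = 64                       # DEFAULT_VOCAB_PADDING_SIZE
vocab_size = 6564                 # illustrative only, e.g. 6561 speech tokens + a few specials
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to  # round up to a multiple of 64
print(pad_vocab_size)             # -> 6592; rows [vocab_size:pad_vocab_size) stay zero-filled
```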
```diff
@@ -296,6 +296,7 @@ class Qwen2LM(TransformerLM):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            vllm_codec_engine=None,
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
```
```diff
@@ -316,22 +317,51 @@ class Qwen2LM(TransformerLM):
         max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

         # 5. step by step decode
-        out_tokens = []
-        cache = None
-        for i in range(max_len):
-            y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
-                                                      cache=cache)
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
-                break
-            if top_ids > self.speech_token_size:
-                continue
-            # in stream mode, yield token one by one
-            yield top_ids
-            out_tokens.append(top_ids)
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        if vllm_codec_engine is None:
+            out_tokens = []
+            cache = None
+            for i in range(max_len):
+                y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                          cache=cache)
+                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+                if top_ids == self.speech_token_size:
+                    break
+                if top_ids > self.speech_token_size:
+                    continue
+                # in stream mode, yield token one by one
+                yield top_ids
+                out_tokens.append(top_ids)
+                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        else:
+            from vllm import SamplingParams, RequestOutput
+            import uuid
+            sampling_params = SamplingParams(top_k=sampling,
+                                             stop_token_ids=[6561, 6563],
+                                             min_tokens=min_len,
+                                             max_tokens=max_len)
+            request_id = uuid.uuid4()
+            vllm_codec_engine.add_request(request_id,
+                                          {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
+                                          sampling_params)
+            while True:
+                speech_token_break = False
+                request_outputs: List[RequestOutput] = vllm_codec_engine.step()
+                for request_output in request_outputs:
+                    if str(request_output.request_id) != str(request_id):
+                        continue
+                    # print(f"request output: {request_output}")
+                    top_ids = list(request_output.outputs[0].token_ids)[-1]
+                    if top_ids == self.speech_token_size:
+                        speech_token_break = True
+                        break
+                    if top_ids > self.speech_token_size:
+                        continue
+                    yield top_ids
+
+                if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
+                    break

     @torch.inference_mode()
     def inference_bistream(
```
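The non-streaming vLLM branch above hands the whole prompt-embedding sequence to the engine as one request and then polls it. A minimal standalone sketch of that add_request/step pattern, assuming an already-built `LLMEngine` and a vLLM build that accepts `prompt_embeds` inputs (which this diff relies on; the stock API takes token ids or text). `drain_request` and its stop condition are illustrative simplifications, not part of the PR:

```python
import uuid
from typing import List

import torch
from vllm import LLMEngine, RequestOutput, SamplingParams


def drain_request(engine: LLMEngine, prompt_embeds: torch.Tensor,
                  eos_id: int, min_tokens: int, max_tokens: int, top_k: int = 25):
    """Yield codec token ids from one vLLM request, mirroring the loop added in this diff."""
    params = SamplingParams(top_k=top_k, stop_token_ids=[eos_id],
                            min_tokens=min_tokens, max_tokens=max_tokens)
    request_id = str(uuid.uuid4())
    # prompt_embeds input assumes an embedding-capable vLLM build, as the exported codec model expects.
    engine.add_request(request_id, {"prompt_embeds": prompt_embeds.to(torch.bfloat16)}, params)
    finished = False
    while not finished:
        request_outputs: List[RequestOutput] = engine.step()
        for out in request_outputs:
            if str(out.request_id) != request_id:
                continue                                     # another request sharing the engine
            top_ids = list(out.outputs[0].token_ids)[-1]     # only the newest token matters
            if top_ids >= eos_id:                            # simplification: treat eos/fillers as stop
                finished = True
                break
            yield top_ids
        if not engine.has_unfinished_requests():
            finished = True
```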
```diff
@@ -432,3 +462,129 @@ class Qwen2LM(TransformerLM):
             # in stream mode, yield token one by one
             yield top_ids
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream_vllm(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+            vllm_codec_engine=None,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        from vllm import SamplingParams, RequestOutput
+        import uuid
+        sampling_params = SamplingParams(top_k=sampling,
+                                         stop_token_ids=[6561, 6563],
+                                         max_tokens=10000)
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        if vllm_codec_engine is None and len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                            lm_input = lm_input_text
+                        else:
+                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+            request_id = uuid.uuid4()
+            vllm_codec_engine.add_request(request_id,
+                                          {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
+                                          sampling_params)
+            ## generator
+            while True:
+                speech_token_break = False
+                request_outputs: List[RequestOutput] = vllm_codec_engine.step()
+                for request_output in request_outputs:
+                    if str(request_output.request_id) != str(request_id):
+                        continue
+                    # print(f"request output: {request_output}")
+                    out_token = list(request_output.outputs[0].token_ids)[-1]
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = out_token
+                    if top_ids == self.speech_token_size + 2:
+                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            speech_token_break = True
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    token_embedding = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+                    lm_input = torch.concat([lm_input, token_embedding], dim=1)
+
+                if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
+                    break
+
+        # 3. final decode
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        request_id = uuid.uuid4()
+        vllm_codec_engine.add_request(request_id,
+                                      {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
+                                      sampling_params)
+        ## generator
+        while True:
+            speech_token_break = False
+            request_outputs: List[RequestOutput] = vllm_codec_engine.step()
+            for request_output in request_outputs:
+                if str(request_output.request_id) != str(request_id):
+                    continue
+                # print(f"request output: {request_output}")
+                top_ids = list(request_output.outputs[0].token_ids)[-1]
+                out_tokens.append(top_ids)
+                if top_ids >= self.speech_token_size:
+                    if top_ids == self.speech_token_size:
+                        speech_token_break = True
+                        break
+                    else:
+                        raise ValueError('should not get token {}'.format(top_ids))
+                # in stream mode, yield token one by one
+                yield top_ids
+
+            if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
+                break
```
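In the streaming path above, vLLM only reports sampled token ids, so the fill token (`speech_token_size + 2`, i.e. 6563 given the `stop_token_ids=[6561, 6563]` in this diff) has to be re-inserted by index bookkeeping: after each fill token, the next one is expected `mix_ratio[1] + 1` positions later. A minimal sketch of that arithmetic; `track` is a hypothetical helper, and `mix_ratio = [5, 15]` is an assumption implied by the diff's "15/5" note rather than shown explicitly:

```python
# Minimal sketch of the next_fill_index bookkeeping used in inference_bistream_vllm above.
mix_ratio = [5, 15]                          # assumed text:speech chunk ratio (per the 15/5 NOTE)
speech_token_size = 6561                     # inferred from stop_token_ids=[6561, 6563]
fill_token = speech_token_size + 2           # 6563

out_tokens = []
next_fill_index = -1

def track(token_id):
    """Record one decoded id; force a fill token whenever the scheduled index is reached."""
    global next_fill_index
    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
        token_id = fill_token                # a fill token is due at this slot
        next_fill_index += mix_ratio[1] + 1  # 15 speech tokens + 1 fill token until the next
    if token_id == fill_token:
        next_fill_index = len(out_tokens) + mix_ratio[1] + 1
    out_tokens.append(token_id)
    return token_id
```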
```diff
@@ -29,8 +29,8 @@ tensorboard==2.14.0
 tensorrt-cu12==10.0.1; sys_platform == 'linux'
 tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
 tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
-torch==2.3.1
-torchaudio==2.3.1
+torch==2.4.0
+torchaudio==2.4.0
 transformers==4.40.1
 uvicorn==0.30.0
 wget==3.2
```