Compare commits

17 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 4c9a4c2fed |  |
|  | e8a26827ae |  |
|  | ab74475604 |  |
|  | 369f3c2c18 |  |
|  | 7f4c9a2c64 |  |
|  | fd9b7d45e2 |  |
|  | 62e04e8856 |  |
|  | 96950745a6 |  |
|  | 9b3f351496 |  |
|  | 00b454cf30 |  |
|  | c0f6a474f3 |  |
|  | ab5b8eb160 |  |
|  | b4fe05d466 |  |
|  | a1314e573a |  |
|  | 2fbeba50ae |  |
|  | d4d187bd8c |  |
|  | 90b666ea20 |  |
@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.class_utils import get_model_type

@@ -54,15 +54,20 @@ class CosyVoice:
                                 '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
+            self.estimator_count = configs.get('estimator_count', 1)
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
-                                self.fp16)
+                                self.fp16, self.estimator_count)
         del configs

     def list_available_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks

+    def add_spk_info(self, spk_id, spk_info):
+        self.frontend.add_spk_info(spk_id, spk_info)
+
     def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_sft(i, spk_id)

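As a usage sketch of the change above: `estimator_count` is read from the model's YAML config (defaulting to 1) and controls how many TensorRT execution contexts `load_trt` creates. A minimal call, assuming the standard pretrained-model layout used by the repo (the model path below is illustrative, not part of the patch):

```python
import sys
sys.path.append('third_party/Matcha-TTS')

from cosyvoice.cli.cosyvoice import CosyVoice

# With load_trt=True the flow decoder estimator runs through TensorRT; the number of
# pooled execution contexts comes from the `estimator_count` key in the model's yaml
# config, so nothing extra is passed here (illustrative model path).
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=True, fp16=True)
print(cosyvoice.list_available_spks())
```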
@@ -88,6 +93,22 @@ class CosyVoice:
             yield model_output
             start_time = time.time()

+    def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+        """Run zero_shot inference with a pre-registered speaker."""
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
+            start_time = time.time()
+            last_time = start_time
+            chunk_index = 0
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format(
+                    chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time))
+                yield model_output
+                last_time = time.time()
+                chunk_index += 1
+
     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)

@@ -126,7 +147,7 @@ class CosyVoice:

 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16

@@ -145,18 +166,27 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        if use_vllm:
+            try:
+                self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
+            except Exception as e:
+                logging.warning(f'use vllm inference failed. \n{e}')
+                raise e
+        else:
+            self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
+            self.estimator_count = configs.get('estimator_count', 1)
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
-                                self.fp16)
+                                self.fp16, self.estimator_count)
         del configs

     def inference_instruct(self, *args, **kwargs):
         raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')

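A minimal usage sketch of the new `use_vllm` switch. The model directory, prompt audio and output file names follow the notebook later in this diff and are illustrative, not part of the patch; when `use_vllm=True` only the LLM is swapped for `VllmCosyVoice2Model`, the flow and hift modules are unchanged:

```python
import sys
sys.path.append('third_party/Matcha-TTS')

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# use_vllm=True routes LLM decoding through vLLM (requires the prepared CosyVoice2-0.5B folder).
cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False,
                       fp16=True, use_vllm=True)

prompt_text = '希望你以后能够做得比我还好哟'
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
# Standard zero-shot API; the synthesis text below is a placeholder.
for k, out in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物。', prompt_text, prompt_speech_16k, stream=False)):
    torchaudio.save('vllm_zero_shot_{}.wav'.format(k), out['tts_speech'], cosyvoice.sample_rate)
```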
@@ -171,3 +201,14 @@ class CosyVoice2(CosyVoice):
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
+
+    def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Generator
+from typing import Generator, Optional
 import json
 import onnxruntime
 import torch

@@ -24,6 +24,8 @@ import torchaudio
 import os
 import re
 import inflect
+from pydantic import BaseModel, ConfigDict
+
 try:
     import ttsfrd
     use_ttsfrd = True

@@ -36,6 +38,18 @@ from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


+class SpeakerInfo(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    name: Optional[str] = None
+    spk_id: str
+    prompt_text: str
+    prompt_text_token: torch.Tensor
+    speech_feat: torch.Tensor
+    speech_token: torch.Tensor
+    embedding: torch.Tensor
+
+
 class CosyVoiceFrontEnd:

     def __init__(self,

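`SpeakerInfo` uses `arbitrary_types_allowed=True` so the tensor fields pass pydantic validation; `add_spk_info` later flattens it to a plain dict with `model_dump()`. A tiny construction sketch, with placeholder tensor shapes that are my own assumption rather than part of the patch:

```python
import torch

# Illustrative only: shapes are made up, the fields match the model defined above.
info = SpeakerInfo(
    spk_id='demo',
    prompt_text='希望你以后能够做得比我还好哟',
    prompt_text_token=torch.zeros(1, 8, dtype=torch.int32),
    speech_feat=torch.zeros(1, 100, 80),
    speech_token=torch.zeros(1, 50, dtype=torch.int32),
    embedding=torch.zeros(1, 192),
)
print(sorted(info.model_dump().keys()))  # plain dict, tensors kept as-is
```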
@@ -55,8 +69,9 @@ class CosyVoiceFrontEnd:
         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                      providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                 "CPUExecutionProvider"])
+        self.spk2info_path = spk2info
         if os.path.exists(spk2info):
-            self.spk2info = torch.load(spk2info, map_location=self.device)
+            self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=False)
         else:
             self.spk2info = {}
         self.allowed_special = allowed_special

@@ -68,7 +83,8 @@ class CosyVoiceFrontEnd:
                 'failed to initialize ttsfrd resource'
             self.frd.set_lang_type('pinyinvg')
         else:
-            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
+            # self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
+            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False)
             self.en_tn_model = EnNormalizer()
             self.inflect_parser = inflect.engine()

@@ -138,11 +154,15 @@ class CosyVoiceFrontEnd:
             text = text.replace(" - ", ",")
             text = remove_bracket(text)
             text = re.sub(r'[,,、]+$', '。', text)
+            if not split:
+                return text
             texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                          token_min_n=60, merge_len=20, comma_split=False))
         else:
             text = self.en_tn_model.normalize(text)
             text = spell_out_number(text, self.inflect_parser)
+            if not split:
+                return text
             texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                          token_min_n=60, merge_len=20, comma_split=False))
         texts = [i for i in texts if not is_only_punctuation(i)]

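For illustration, the `re.sub` pass above rewrites a dangling comma-style ending into a sentence-final full stop before the paragraph split. A quick check with made-up strings:

```python
import re

# Trailing Chinese/ASCII commas or enumeration commas become a full stop.
for s in ['你好,世界,', 'hello, world、']:
    print(re.sub(r'[,,、]+$', '。', s))
# 你好,世界。
# hello, world。
```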
@@ -151,6 +171,7 @@ class CosyVoiceFrontEnd:
     def frontend_sft(self, tts_text, spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         embedding = self.spk2info[spk_id]['embedding']
+        assert embedding is not None
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input

@@ -209,3 +230,60 @@ class CosyVoiceFrontEnd:
                        'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                        'flow_embedding': embedding}
         return model_input
+
+    def generate_spk_info(self, spk_id: str, prompt_text: str, prompt_speech_16k: torch.Tensor, resample_rate:int=24000, name: str=None):
+        assert isinstance(spk_id, str)
+        assert spk_id not in self.spk2info, "spk_id already exists"
+        prompt_text_token, _ = self._extract_text_token(prompt_text)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, _ = self._extract_speech_feat(prompt_speech_resample)
+        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        if resample_rate == 24000:
+            # cosyvoice2, force speech_feat % speech_token = 2
+            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+            speech_feat = speech_feat[:, :2 * token_len]
+            speech_token = speech_token[:, :token_len]
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        spk_info = SpeakerInfo(
+            name=name,
+            spk_id=spk_id,
+            prompt_text=prompt_text,
+            prompt_text_token=prompt_text_token,
+            speech_feat=speech_feat,
+            speech_token=speech_token,
+            embedding=embedding,
+        )
+        self.add_spk_info(spk_id, spk_info)
+
+    def add_spk_info(self, spk_id: str, spk_info: dict|SpeakerInfo):
+        if isinstance(spk_info, BaseModel):
+            spk_info = spk_info.model_dump()
+        self.spk2info[spk_id] = spk_info
+        if self.spk2info_path:
+            torch.save(self.spk2info, self.spk2info_path)
+
+    def frontend_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id):
+        assert spk_id in self.spk2info
+        tts_text_token, _ = self._extract_text_token(tts_text)
+        prompt_text_token, _ = self._extract_text_token(instruct_text + '<|endofprompt|>')
+        model_input = {'text': tts_text_token,
+                       'prompt_text': prompt_text_token,
+                       'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
+                       'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
+                       'llm_embedding': self.spk2info[spk_id]['embedding'],
+                       'flow_embedding': self.spk2info[spk_id]['embedding'],
+                       }
+        return model_input
+
+    def frontend_zero_shot_by_spk_id(self, tts_text, spk_id):
+        assert spk_id in self.spk2info
+        tts_text_token, _ = self._extract_text_token(tts_text)
+        model_input = {'text': tts_text_token,
+                       'prompt_text': self.spk2info[spk_id]['prompt_text_token'],
+                       'llm_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
+                       'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
+                       'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
+                       'llm_embedding': self.spk2info[spk_id]['embedding'],
+                       'flow_embedding': self.spk2info[spk_id]['embedding']
+                       }
+        return model_input

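A rough end-to-end sketch of the speaker-cache workflow added above, assuming `cosyvoice` is an already constructed CosyVoice2 instance and the prompt recording is 16 kHz; the speaker id, texts and file names are placeholders:

```python
import torchaudio
from cosyvoice.utils.file_utils import load_wav

prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)

# Register the speaker once; prompt text tokens, speech features, tokens and embedding
# are cached in spk2info and persisted to spk2info_path by add_spk_info.
cosyvoice.frontend.generate_spk_info('my_spk', '希望你以后能够做得比我还好哟', prompt_speech_16k,
                                     resample_rate=cosyvoice.sample_rate, name='demo speaker')

# Later calls can synthesize by speaker id only, without the prompt audio.
for k, out in enumerate(cosyvoice.inference_zero_shot_by_spk_id('今天天气真不错。', 'my_spk')):
    torchaudio.save('spk_id_{}.wav'.format(k), out['tts_speech'], cosyvoice.sample_rate)
```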
@@ -22,7 +22,8 @@ from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt
-
+from cosyvoice.flow.flow_matching import EstimatorWrapper
+import queue

 class CosyVoiceModel:

@@ -66,6 +67,12 @@ class CosyVoiceModel:
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}

+        self.stream_context_pool = queue.Queue()
+        for _ in range(10):
+            self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
+
+        self.is_cuda_available = torch.cuda.is_available()
+
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()

@@ -84,7 +91,7 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=1):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model):
             convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)

@@ -96,7 +103,7 @@ class CosyVoiceModel:
             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
         if self.flow.decoder.estimator_engine is None:
             raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+        self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count)

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:

@@ -122,13 +129,13 @@ class CosyVoiceModel:
|
|||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
||||
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
flow_cache=self.flow_cache_dict[uuid])
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
flow_cache=self.flow_cache_dict[uuid])
|
||||
self.flow_cache_dict[uuid] = flow_cache
|
||||
|
||||
# mel overlap fade in out
|
||||
|
|
@@ -148,8 +155,8 @@ class CosyVoiceModel:
|
|||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||
else:
|
||||
if speed != 1.0:
|
||||
|
|
@@ -166,63 +173,70 @@ class CosyVoiceModel:
|
|||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
||||
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
token_hop_len = self.token_min_hop_len
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
||||
.unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=False)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
||||
# increase token_hop_len for better speech quality
|
||||
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
stream_context = self.stream_context_pool.get()
|
||||
with stream_context:
|
||||
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
||||
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
token_hop_len = self.token_min_hop_len
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
||||
.unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=False)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
||||
# increase token_hop_len for better speech quality
|
||||
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
|
||||
self.synchronize_stream()
|
||||
self.stream_context_pool.put(stream_context)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
|
|
@@ -278,6 +292,10 @@ class CosyVoiceModel:
|
|||
self.hift_cache_dict.pop(this_uuid)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def synchronize_stream(self):
|
||||
if self.is_cuda_available:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
|
||||
class CosyVoice2Model(CosyVoiceModel):
|
||||
|
||||
|
|
@@ -314,19 +332,26 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
self.stream_context_pool = queue.Queue()
|
||||
for _ in range(10):
|
||||
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
||||
|
||||
self.is_cuda_available = torch.cuda.is_available()
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
||||
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
finalize=finalize)
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append hift cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
|
|
@@ -340,8 +365,8 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||
else:
|
||||
if speed != 1.0:
|
||||
|
|
@@ -358,54 +383,84 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
token_offset = 0
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=token_offset,
|
||||
finalize=False)
|
||||
token_offset += self.token_hop_len
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=token_offset,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=0,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
torch.cuda.empty_cache()
|
||||
self.synchronize_stream()
|
||||
stream_context = self.stream_context_pool.get()
|
||||
with stream_context:
|
||||
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
token_offset = 0
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=token_offset,
|
||||
finalize=False)
|
||||
token_offset += self.token_hop_len
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=token_offset,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
token_offset=0,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
|
||||
self.synchronize_stream()
|
||||
self.stream_context_pool.put(stream_context)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
class VllmCosyVoice2Model(CosyVoice2Model):
|
||||
def __init__(self,
|
||||
model_dir: str,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool):
|
||||
try:
|
||||
from cosyvoice.llm.llm_vllm import VllmQwen2LM
|
||||
except Exception as e:
|
||||
raise e
|
||||
llm = VllmQwen2LM(model_dir)
|
||||
super().__init__(llm,flow,hift,fp16)
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in
|
||||
torch.load(hift_model, weights_only=True, map_location=self.device).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
|
|
|||
|
|
@@ -15,7 +15,26 @@ import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
+import queue
+
+
+class EstimatorWrapper:
+    def __init__(self, estimator_engine, estimator_count=2,):
+        self.estimators = queue.Queue()
+        self.estimator_engine = estimator_engine
+        for _ in range(estimator_count):
+            estimator = estimator_engine.create_execution_context()
+            if estimator is not None:
+                self.estimators.put(estimator)
+
+        if self.estimators.empty():
+            raise Exception("No available estimator")
+
+    def acquire_estimator(self):
+        return self.estimators.get(), self.estimator_engine
+
+    def release_estimator(self, estimator):
+        self.estimators.put(estimator)
+        return
+

 class ConditionalCFM(BASECFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):

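The wrapper above hands out TensorRT execution contexts from a `queue.Queue` so concurrent flow-decoder calls do not share a single context. A small illustrative acquire/release pattern; the `try/finally` guard is my own sketch, not part of the patch:

```python
# Sketch only: `wrapper` would be the EstimatorWrapper stored in flow.decoder.estimator.
def run_with_pooled_estimator(wrapper, run_fn):
    estimator, engine = wrapper.acquire_estimator()  # blocks until a context is free
    try:
        return run_fn(estimator, engine)             # e.g. set shapes, bind tensors, execute
    finally:
        wrapper.release_estimator(estimator)         # always return the context to the pool
```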
@@ -125,22 +144,50 @@ class ConditionalCFM(BASECFM):
|
|||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
if isinstance(self.estimator, EstimatorWrapper):
|
||||
estimator, engine = self.estimator.acquire_estimator()
|
||||
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
|
||||
for idx, data_ptr in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
|
||||
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x
|
||||
estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator)
|
||||
return x
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,212 @@
|
|||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
import queue
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import List, Generator, AsyncGenerator
|
||||
import torch
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.llm.llm import Qwen2LM
|
||||
|
||||
# Enable the vLLM V1 engine
|
||||
import os
|
||||
os.environ["VLLM_USE_V1"] = '1'
|
||||
from vllm import ModelRegistry
|
||||
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
|
||||
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
|
||||
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
|
||||
|
||||
# EngineArgs
|
||||
ENGINE_ARGS = {
|
||||
"block_size": 16,
|
||||
"swap_space": 0,
|
||||
# "enforce_eager": True,
|
||||
"gpu_memory_utilization": 0.4,
|
||||
"max_num_batched_tokens": 1024,
|
||||
"max_model_len": 1024,
|
||||
"max_num_seqs": 256,
|
||||
"disable_log_requests": True,
|
||||
"disable_log_stats": True,
|
||||
"dtype": "float16"
|
||||
}
|
||||
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
# SamplingParams
|
||||
SAMPLING_PARAMS = {
|
||||
"temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
|
||||
"top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
|
||||
"top_k": 25,
|
||||
# "min_tokens": 80, # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动
|
||||
# "presence_penalty": 1.0, # 不支持设置
|
||||
# "frequency_penalty": 0.0, # 不支持设置
|
||||
"max_tokens": 1024,
|
||||
"detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算
|
||||
"ignore_eos": False,
|
||||
"output_kind": RequestOutputKind.DELTA # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码
|
||||
}
|
||||
|
||||
def tensor_to_list(tensor: torch.tensor):
|
||||
return tensor.view(-1).cpu().numpy().tolist()
|
||||
|
||||
class VllmQwen2LM(Qwen2LM):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir,
|
||||
mix_ratio: List[int] = [5, 15],
|
||||
):
|
||||
self.fp16 = False
|
||||
self.half = lambda: None
|
||||
self.mix_ratio = mix_ratio
|
||||
# ---------------------------------------------
|
||||
# vllm engine configuration
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_dir,
|
||||
**ENGINE_ARGS,
|
||||
)
|
||||
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
self.speech_token_size = 6564 # 6561 + 3
|
||||
self.llm_token_size = 151936 # llm vocab_size
|
||||
self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
|
||||
self.task_token_id = self.sos_eos_token_id + 1
|
||||
self.zero_token_id = self.task_token_id + 1
|
||||
|
||||
# vllm inference must run inside a single fixed event loop, so start a dedicated background thread for it
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
|
||||
self.loop_thread.start()
|
||||
|
||||
def _run_event_loop(self):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
self.loop.run_forever()
|
||||
|
||||
async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
|
||||
sampling_params = SamplingParams(**SAMPLING_PARAMS)
|
||||
sampling_params.stop_token_ids = stop_token_ids or [6561]
|
||||
if max_tokens:
|
||||
sampling_params.max_tokens = max_tokens
|
||||
async for output in self.llm_engine.generate(
|
||||
{
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id or f"{time.time()}",
|
||||
):
|
||||
out_queue.put((output.outputs[0], output.finished))
|
||||
|
||||
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
|
||||
out_queue = queue.Queue()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
|
||||
)
|
||||
# consume the results returned via out_queue
|
||||
finished = False
|
||||
while not finished:
|
||||
(output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
|
||||
yield output
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_len: torch.Tensor,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
) -> Generator[torch.Tensor|int, None, None]:
|
||||
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
|
||||
prompt_speech_token = tensor_to_list(prompt_speech_token)
|
||||
|
||||
text = tensor_to_list(text + torch.tensor(6564))
|
||||
prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
|
||||
[self.task_token_id] + prompt_speech_token
|
||||
max_tokens = len(text) * 20
|
||||
for output in self.llm_inference(
|
||||
prompt_token_ids,
|
||||
stop_token_ids=[6561],
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
if output.token_ids[-1] == 6561:
|
||||
need_add_tokens = output.token_ids[:-1]
|
||||
else:
|
||||
need_add_tokens = output.token_ids
|
||||
for token in need_add_tokens:
|
||||
yield token
|
||||
|
||||
def inference_bistream(
|
||||
self,
|
||||
text: Generator,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
|
||||
prompt_speech_token = tensor_to_list(prompt_speech_token)
|
||||
|
||||
last_tokens = []
|
||||
prompt_token_ids = [self.sos_eos_token_id]
|
||||
text_tokens_cache = prompt_text
|
||||
for this_text in text:
|
||||
this_text = tensor_to_list(this_text + torch.tensor(6564))
|
||||
# text need tokens
|
||||
assert isinstance(this_text, list), "text need token ids List[int]."
|
||||
text_tokens_cache += this_text
|
||||
while len(prompt_speech_token) != 0:
|
||||
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
||||
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
|
||||
speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
|
||||
prompt_token_ids += text_input_token + speech_input_token
|
||||
# reset the last cache
|
||||
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
||||
prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
|
||||
else:
|
||||
break
|
||||
if len(prompt_speech_token) == 0:
|
||||
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
|
||||
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
||||
text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
|
||||
prompt_token_ids += text_tokens_temp
|
||||
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
||||
else:
|
||||
continue
|
||||
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
|
||||
last_tokens = output.token_ids
|
||||
if last_tokens[-1] == 6563:
|
||||
need_add_tokens = last_tokens[:-1]
|
||||
else:
|
||||
need_add_tokens = last_tokens
|
||||
for token in need_add_tokens:
|
||||
yield token
|
||||
prompt_token_ids.extend(need_add_tokens)
|
||||
prompt_token_ids += text_tokens_cache + [self.task_token_id]
|
||||
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
|
||||
if output.token_ids[-1] == 6561:
|
||||
need_add_tokens = output.token_ids[:-1]
|
||||
else:
|
||||
need_add_tokens = output.token_ids
|
||||
for token in need_add_tokens:
|
||||
yield token
|
||||
|
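`VllmQwen2LM` above bridges vLLM's async generator to the synchronous token stream CosyVoice expects by running a private event loop on a daemon thread and pumping results through a `queue.Queue`. A stripped-down sketch of that pattern, independent of vLLM; the class and names here are illustrative, not part of the patch:

```python
import asyncio
import queue
import threading


class AsyncToSyncBridge:
    """Run async producers on a background event loop and consume them synchronously."""

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def stream(self, agen_factory):
        out = queue.Queue()

        async def pump():
            async for item in agen_factory():
                out.put((item, False))
            out.put((None, True))  # sentinel: producer finished

        # Schedule the async producer on the background loop, then drain synchronously.
        asyncio.run_coroutine_threadsafe(pump(), self.loop)
        while True:
            item, finished = out.get()
            if finished:
                break
            yield item
```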
|
@@ -0,0 +1,263 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.interfaces import T
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_ID = -1
|
||||
|
||||
|
||||
class CosyVoice2Model(nn.Module):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.llm_input_size = 896
|
||||
self.llm_output_size = 896
|
||||
|
||||
self.speech_token_size = 6561+3
|
||||
self.llm_token_size = config.vocab_size
|
||||
|
||||
# 2. build speech token language model related modules
|
||||
self.sos_eos = 0
|
||||
self.task_id = 1
|
||||
self.fill_token = 2
|
||||
|
||||
|
||||
self.allow_patterns_overrides = ["llm.*"]
|
||||
self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
|
||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
# self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
|
||||
self.llm_decoder = ParallelLMHead(self.speech_token_size,
|
||||
self.llm_output_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "llm_decoder"))
|
||||
self.logits_processor = LogitsProcessor(self.speech_token_size)
|
||||
|
||||
# length_normalized_loss: bool = True,
|
||||
# lsm_weight: float = 0.0,
|
||||
# self.criterion_ce = LabelSmoothingLoss(
|
||||
# size=self.speech_token_size,
|
||||
# padding_idx=IGNORE_ID,
|
||||
# smoothing=lsm_weight,
|
||||
# normalize_length=length_normalized_loss,
|
||||
# )
|
||||
|
||||
# 3. [Optional] build speech token related modules
|
||||
self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
|
||||
|
||||
# 4. sampling method
|
||||
## use vllm sampling method
|
||||
self.sampler = get_sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
self.mix_ratio: List[int] = [5, 15]
|
||||
|
||||
# special token constants
|
||||
self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
|
||||
self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404
|
||||
self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405
|
||||
self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
|
||||
|
||||
self.zero_embed_buffer = torch.zeros(
|
||||
(vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
|
||||
dtype=self.llm_embedding.weight.dtype,
|
||||
device=self.llm_embedding.weight.device
|
||||
)
|
||||
self.inputs_embed_buffer = torch.zeros(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
|
||||
dtype=self.llm_embedding.weight.dtype,
|
||||
device=self.llm_embedding.weight.device,
|
||||
)
|
||||
|
||||
def get_sos_eos_emb(self):
|
||||
return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
|
||||
def get_task_id_emb(self):
|
||||
return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[T] = None,
|
||||
attn_metadata: Optional["AttentionMetadata"] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the input embeddings merged from the text embeddings from
|
||||
input_ids and the multimodal embeddings generated from multimodal
|
||||
kwargs.
|
||||
"""
|
||||
# build a mask marking which token ids are speech tokens
|
||||
mask = input_ids < self.speech_token_size
|
||||
|
||||
# keep the original shape of input_ids
input_shape = input_ids.shape
# flatten input_ids and the mask for uniform handling
flat_input_ids = input_ids.view(-1)
flat_mask = mask.view(-1)
|
||||
|
||||
inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
|
||||
inputs_embeds.zero_()
|
||||
|
||||
# Process speech tokens
|
||||
if flat_mask.any():
|
||||
speech_token_ids = flat_input_ids[flat_mask]
|
||||
inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
|
||||
|
||||
# handle token ids above the speech-token range (delta)
|
||||
if (~flat_mask).any():
|
||||
llm_token_ids = flat_input_ids[~flat_mask]
|
||||
llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
|
||||
|
||||
sos_eos_mask = llm_token_ids == self.sos_eos_token_id
|
||||
task_mask = llm_token_ids == self.task_token_id
|
||||
zero_mask = llm_token_ids == self.zero_token_id
|
||||
normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
|
||||
|
||||
# tiered handling logic
# first priority: SOS/EOS tokens
|
||||
if sos_eos_mask.any():
|
||||
llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
|
||||
|
||||
# second priority: task token
|
||||
if task_mask.any():
|
||||
llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
|
||||
|
||||
# third priority: blank-audio (zero) token
|
||||
if zero_mask.any():
|
||||
llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
|
||||
|
||||
# regular LLM tokens
|
||||
if normal_mask.any():
|
||||
original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
|
||||
# print('original_ids: ',original_ids)
|
||||
llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
|
||||
|
||||
inputs_embeds[~flat_mask] = llm_embeds
|
||||
|
||||
inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
|
||||
|
||||
# merge multimodal embeddings (if any)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.config.audio_token_index
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.llm_decoder, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
@staticmethod
|
||||
def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
for name, param in weights:
|
||||
# map the core Qwen2Model parameters
|
||||
if name.startswith("llm."):
|
||||
if name.startswith("llm.model.model."):
|
||||
name = name.replace("llm.model.model.", "model.")
|
||||
else:
|
||||
continue
|
||||
# print('weights name: ', name)
|
||||
yield name, param
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
weights = self.convert_weights(weights)
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
|
|
@@ -61,7 +61,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
     network = builder.create_network(network_flags)
     parser = trt.OnnxParser(network, logger)
     config = builder.create_builder_config()
-    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB
     if fp16:
         config.set_flag(trt.BuilderFlag.FP16)
     profile = builder.create_optimization_profile()
|
|||
|
|
@@ -0,0 +1,40 @@
|
|||
vllm==0.7.3
|
||||
pydantic==2.10.6
|
||||
torch==2.5.1
|
||||
torchaudio==2.5.1
|
||||
|
||||
conformer==0.3.2
|
||||
|
||||
diffusers==0.32.2
|
||||
gdown==5.1.0
|
||||
grpcio==1.57.0
|
||||
grpcio-tools==1.57.0
|
||||
hydra-core==1.3.2
|
||||
HyperPyYAML==1.2.2
|
||||
inflect==7.3.1
|
||||
librosa==0.10.2
|
||||
|
||||
lightning==2.5.0.post0
|
||||
matplotlib==3.7.5
|
||||
modelscope==1.15.0
|
||||
|
||||
networkx==3.4.2
|
||||
omegaconf==2.3.0
|
||||
onnx==1.17.0
|
||||
|
||||
onnxruntime-gpu==1.19.0; sys_platform == 'linux'
|
||||
|
||||
#openai-whisper==20231117
|
||||
openai-whisper==20240930
|
||||
protobuf==4.25
|
||||
pyworld==0.3.4
|
||||
rich==13.7.1
|
||||
soundfile==0.12.1
|
||||
tensorboard==2.14.0
|
||||
wget==3.2
|
||||
WeTextProcessing==1.0.3
|
||||
|
||||
# trt use
|
||||
tensorrt-cu12==10.0.1
|
||||
tensorrt-cu12-bindings==10.0.1
|
||||
tensorrt-cu12-libs==10.0.1
|
||||
|
|
@@ -0,0 +1,486 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 测试效果\n",
|
||||
"\n",
|
||||
"- 测试代码: [speed_test.ipynb](speed_test.ipynb)\n",
|
||||
"- 测试环境: Intel i5-12400 CPU, 48GB RAM, 1x NVIDIA GeForce RTX 4070\n",
|
||||
"- 运行环境: Ubuntu 24.04.1 LTS, cuda 12.4, python 3.10.16\n",
|
||||
"- 测试说明: 单任务执行的数据(非并发测试)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 默认情况下使用"
|
||||
]
|
||||
},
|
||||
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import asyncio\n",
    "import torchaudio\n",
    "\n",
    "import sys\n",
    "sys.path.append('third_party/Matcha-TTS')\n",
    "\n",
    "from cosyvoice.cli.cosyvoice import CosyVoice2\n",
    "from cosyvoice.utils.file_utils import load_wav\n",
    "\n",
    "prompt_text = '希望你以后能够做得比我还好哟'\n",
    "prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)\n",
    "\n",
    "# cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=True)\n",
    "cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, fp16=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Accelerating LLM inference with vLLM\n",
    "\n",
    "#### 1. **Install dependencies**\n",
    "\n",
    "(The original CosyVoice2 code still runs under this dependency set.)\n",
    "```bash\n",
    "pip install -r requirements_vllm.txt\n",
    "```\n",
    "\n",
    "#### 2. **Copy model files**\n",
    "Copy the files below from the pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN folder into the downloaded CosyVoice2-0.5B model folder, then replace Qwen2ForCausalLM with CosyVoice2Model in config.json.\n",
    "```bash\n",
    "cp pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN/{config.json,tokenizer_config.json,vocab.json,merges.txt} pretrained_models/CosyVoice2-0.5B/\n",
    "sed -i 's/Qwen2ForCausalLM/CosyVoice2Model/' pretrained_models/CosyVoice2-0.5B/config.json\n",
    "```\n",
    "\n",
    "#### **Notes:**\n",
    "\n",
    "- With load_trt enabled, run more than 10 **warm-up** inferences; warming up with streaming inference works best.\n",
    "- To run the code below with **vllm** inside a Jupyter notebook, vllm_use_cosyvoice2_model.py must be copied into the vllm package and registered in the _VLLM_MODELS dict. The next code cell does this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "# locate the installed vllm package\n",
    "try:\n",
    "    import vllm\n",
    "except ImportError:\n",
    "    raise ImportError(\"vllm package not installed\")\n",
    "\n",
    "\n",
    "vllm_path = os.path.dirname(vllm.__file__)\n",
    "print(f\"vllm package path: {vllm_path}\")\n",
    "\n",
    "# target paths inside the vllm package\n",
    "target_dir = os.path.join(vllm_path, \"model_executor\", \"models\")\n",
    "target_file = os.path.join(target_dir, \"cosyvoice2.py\")\n",
    "\n",
    "# copy the model file\n",
    "source_file = \"./cosyvoice/llm/vllm_use_cosyvoice2_model.py\"\n",
    "if not os.path.exists(source_file):\n",
    "    raise FileNotFoundError(f\"Source file {source_file} not found\")\n",
    "\n",
    "shutil.copy(source_file, target_file)\n",
    "print(f\"Copied {source_file} to {target_file}\")\n",
    "\n",
    "# patch registry.py\n",
    "registry_path = os.path.join(target_dir, \"registry.py\")\n",
    "new_entry = '    \"CosyVoice2Model\": (\"cosyvoice2\", \"CosyVoice2Model\"),  # noqa: E501\\n'\n",
    "\n",
    "# read and modify the file contents\n",
    "with open(registry_path, \"r\") as f:\n",
    "    lines = f.readlines()\n",
    "\n",
    "# check whether the entry already exists\n",
    "entry_exists = any(\"CosyVoice2Model\" in line for line in lines)\n",
    "\n",
    "if not entry_exists:\n",
    "    # find the insertion point\n",
    "    insert_pos = None\n",
    "    for i, line in enumerate(lines):\n",
    "        if line.strip().startswith(\"**_FALLBACK_MODEL\"):\n",
    "            insert_pos = i + 1\n",
    "            break\n",
    "\n",
    "    if insert_pos is None:\n",
    "        raise ValueError(\"Could not find insertion point in registry.py\")\n",
    "\n",
    "    # insert the new entry\n",
    "    lines.insert(insert_pos, new_entry)\n",
    "\n",
    "    # write the file back\n",
    "    with open(registry_path, \"w\") as f:\n",
    "        f.writelines(lines)\n",
    "    print(\"Successfully updated registry.py\")\n",
    "else:\n",
    "    print(\"Entry already exists in registry.py, skipping modification\")\n",
    "\n",
    "print(\"All operations completed successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "failed to import ttsfrd, use WeTextProcessing instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n",
      "/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n",
      " deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n",
      "2025-03-08 00:37:04,867 INFO input frame rate=25\n",
      "/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:115: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n",
      " warnings.warn(\n",
      "2025-03-08 00:37:06,103 WETEXT INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_tagger.fst\n",
      "2025-03-08 00:37:06,103 INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_tagger.fst\n",
      "2025-03-08 00:37:06,104 WETEXT INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_verbalizer.fst\n",
      "2025-03-08 00:37:06,104 INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_verbalizer.fst\n",
      "2025-03-08 00:37:06,104 WETEXT INFO skip building fst for zh_normalizer ...\n",
      "2025-03-08 00:37:06,104 INFO skip building fst for zh_normalizer ...\n",
      "2025-03-08 00:37:06,313 WETEXT INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_tagger.fst\n",
      "2025-03-08 00:37:06,313 INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_tagger.fst\n",
      "2025-03-08 00:37:06,314 WETEXT INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_verbalizer.fst\n",
      "2025-03-08 00:37:06,314 INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_verbalizer.fst\n",
      "2025-03-08 00:37:06,314 WETEXT INFO skip building fst for en_normalizer ...\n",
      "2025-03-08 00:37:06,314 INFO skip building fst for en_normalizer ...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 03-08 00:37:07 __init__.py:207] Automatically detected platform cuda.\n",
      "WARNING 03-08 00:37:07 registry.py:352] Model architecture CosyVoice2Model is already registered, and will be overwritten by the new model class <class 'cosyvoice.llm.vllm_use_cosyvoice2_model.CosyVoice2Model'>.\n",
      "WARNING 03-08 00:37:07 config.py:2517] Casting torch.bfloat16 to torch.float16.\n",
      "INFO 03-08 00:37:07 config.py:560] This model supports multiple tasks: {'embed', 'classify', 'reward', 'generate', 'score'}. Defaulting to 'generate'.\n",
      "INFO 03-08 00:37:07 config.py:1624] Chunked prefill is enabled with max_num_batched_tokens=1024.\n",
      "WARNING 03-08 00:37:08 utils.py:2164] CUDA was previously initialized. We must use the `spawn` multiprocessing start method. Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing for more information.\n",
      "INFO 03-08 00:37:10 __init__.py:207] Automatically detected platform cuda.\n",
      "INFO 03-08 00:37:11 core.py:50] Initializing a V1 LLM engine (v0.7.3.dev213+gede41bc7.d20250219) with config: model='./pretrained_models/CosyVoice2-0.5B', speculative_config=None, tokenizer='./pretrained_models/CosyVoice2-0.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=./pretrained_models/CosyVoice2-0.5B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={\"level\":3,\"custom_ops\":[\"none\"],\"splitting_ops\":[\"vllm.unified_attention\",\"vllm.unified_attention_with_output\"],\"use_inductor\":true,\"compile_sizes\":[],\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"max_capture_size\":512}\n",
      "WARNING 03-08 00:37:11 utils.py:2298] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,list_loras,load_config,pin_lora,remove_lora,scheduler_config not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x771e56fb9a50>\n",
      "INFO 03-08 00:37:11 parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0\n",
      "INFO 03-08 00:37:11 gpu_model_runner.py:1055] Starting to load model ./pretrained_models/CosyVoice2-0.5B...\n",
      "INFO 03-08 00:37:11 cuda.py:157] Using Flash Attention backend on V1 engine.\n",
      "WARNING 03-08 00:37:11 topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.\n",
      "WARNING 03-08 00:37:11 rejection_sampler.py:47] FlashInfer is not available. Falling back to the PyTorch-native implementation of rejection sampling. For the best performance, please install FlashInfer.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/torch/utils/_device.py:106: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      " return func(*args, **kwargs)\n",
      "Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]\n",
      "Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 1.12it/s]\n",
      "Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 1.12it/s]\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 03-08 00:37:12 gpu_model_runner.py:1068] Loading model weights took 0.9532 GB and 1.023026 seconds\n",
      "INFO 03-08 00:37:16 backends.py:408] Using cache directory: /home/qihua/.cache/vllm/torch_compile_cache/29f70599cb/rank_0 for vLLM's torch.compile\n",
      "INFO 03-08 00:37:16 backends.py:418] Dynamo bytecode transform time: 3.62 s\n",
      "INFO 03-08 00:37:16 backends.py:115] Directly load the compiled graph for shape None from the cache\n",
      "INFO 03-08 00:37:19 monitor.py:33] torch.compile takes 3.62 s in total\n",
      "INFO 03-08 00:37:20 kv_cache_utils.py:524] GPU KV cache size: 216,560 tokens\n",
      "INFO 03-08 00:37:20 kv_cache_utils.py:527] Maximum concurrency for 1,024 tokens per request: 211.48x\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-08 00:37:30,767 DEBUG Using selector: EpollSelector\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 03-08 00:37:30 gpu_model_runner.py:1375] Graph capturing finished in 11 secs, took 0.37 GiB\n",
      "INFO 03-08 00:37:30 core.py:116] init engine (profile, create kv cache, warmup model) took 17.82 seconds\n",
      "inference_processor\n",
      "[03/08/2025-00:37:31] [TRT] [I] Loaded engine size: 158 MiB\n",
      "[03/08/2025-00:37:31] [TRT] [I] [MS] Running engine with multi stream info\n",
      "[03/08/2025-00:37:31] [TRT] [I] [MS] Number of aux streams is 1\n",
      "[03/08/2025-00:37:31] [TRT] [I] [MS] Number of total worker streams is 2\n",
      "[03/08/2025-00:37:31] [TRT] [I] [MS] The main stream provided by execute/enqueue calls is the first worker stream\n",
      "[03/08/2025-00:37:32] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +4545, now: CPU 0, GPU 4681 (MiB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n",
      "inference_processor\n"
     ]
    }
   ],
"source": [
|
||||
"import time\n",
|
||||
"import asyncio\n",
|
||||
"import torchaudio\n",
|
||||
"\n",
|
||||
"import sys\n",
|
||||
"sys.path.append('third_party/Matcha-TTS')\n",
|
||||
"\n",
|
||||
"from cosyvoice.cli.cosyvoice import CosyVoice2\n",
|
||||
"from cosyvoice.utils.file_utils import load_wav\n",
|
||||
"\n",
|
||||
"prompt_text = '希望你以后能够做得比我还好哟'\n",
|
||||
"prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)\n",
|
||||
"\n",
|
||||
"# cosyvoice = CosyVoice2(\n",
|
||||
"# './pretrained_models/CosyVoice2-0.5B', \n",
|
||||
"# load_jit=False, \n",
|
||||
"# load_trt=False, \n",
|
||||
"# fp16=True, \n",
|
||||
"# use_vllm=True,\n",
|
||||
"# )\n",
|
||||
"cosyvoice = CosyVoice2(\n",
|
||||
" './pretrained_models/CosyVoice2-0.5B', \n",
|
||||
" load_jit=True, \n",
|
||||
" load_trt=True, \n",
|
||||
" fp16=True, \n",
|
||||
" use_vllm=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:38:59,777 INFO synthesis text 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。\n",
      "2025-03-08 00:39:00,917 INFO yield speech len 11.68, rtf 0.09757431402598342\n",
      "100%|██████████| 1/1 [00:01<00:00, 1.47s/it]\n"
     ]
    }
   ],
   "source": [
    "for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', prompt_text, prompt_speech_16k, stream=False)):\n",
    "    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:01,208 INFO synthesis text 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。\n",
      "2025-03-08 00:39:01,587 INFO yield speech len 1.84, rtf 0.20591642545617145\n",
      "2025-03-08 00:39:01,790 INFO yield speech len 2.0, rtf 0.10057318210601807\n",
      "2025-03-08 00:39:02,116 INFO yield speech len 2.0, rtf 0.16271138191223145\n",
      "2025-03-08 00:39:02,367 INFO yield speech len 2.0, rtf 0.1247786283493042\n",
      "2025-03-08 00:39:02,640 INFO yield speech len 2.0, rtf 0.13561689853668213\n",
      "2025-03-08 00:39:02,980 INFO yield speech len 1.88, rtf 0.1803158445561186\n",
      "100%|██████████| 1/1 [00:02<00:00, 2.05s/it]\n"
     ]
    }
   ],
   "source": [
    "for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', prompt_text, prompt_speech_16k, stream=True)):\n",
    "    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-08 00:39:02,990 INFO get tts_text generator, will skip text_normalize!\n",
      " 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:02,991 INFO get tts_text generator, will return _extract_text_token_generator!\n",
      "2025-03-08 00:39:03,236 INFO synthesis text <generator object text_generator at 0x79c694dae340>\n",
      "2025-03-08 00:39:03,237 INFO not enough text token to decode, wait for more\n",
      "2025-03-08 00:39:03,252 INFO get fill token, need to append more text token\n",
      "2025-03-08 00:39:03,253 INFO append 5 text token\n",
      "2025-03-08 00:39:03,311 INFO get fill token, need to append more text token\n",
      "2025-03-08 00:39:03,312 INFO append 5 text token\n",
      "2025-03-08 00:39:03,456 INFO no more text token, decode until met eos\n",
      "2025-03-08 00:39:04,861 INFO yield speech len 15.16, rtf 0.1072180145334128\n",
      "100%|██████████| 1/1 [00:01<00:00, 1.88s/it]\n"
     ]
    }
   ],
   "source": [
    "def text_generator():\n",
    "    yield '收到好友从远方寄来的生日礼物,'\n",
    "    yield '那份意外的惊喜与深深的祝福'\n",
    "    yield '让我心中充满了甜蜜的快乐,'\n",
    "    yield '笑容如花儿般绽放。'\n",
    "\n",
    "\n",
    "for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), prompt_text, prompt_speech_16k, stream=False)):\n",
    "    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-08 00:39:04,878 INFO get tts_text generator, will skip text_normalize!\n",
      " 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:04,880 INFO get tts_text generator, will return _extract_text_token_generator!\n",
      "2025-03-08 00:39:05,151 INFO synthesis text <generator object text_generator at 0x79c694dad690>\n",
      "2025-03-08 00:39:05,152 INFO not enough text token to decode, wait for more\n",
      "2025-03-08 00:39:05,169 INFO get fill token, need to append more text token\n",
      "2025-03-08 00:39:05,169 INFO append 5 text token\n",
      "2025-03-08 00:39:05,292 INFO get fill token, need to append more text token\n",
      "2025-03-08 00:39:05,293 INFO append 5 text token\n",
      "2025-03-08 00:39:05,438 INFO no more text token, decode until met eos\n",
      "2025-03-08 00:39:05,638 INFO yield speech len 1.84, rtf 0.26492670826289966\n",
      "2025-03-08 00:39:05,841 INFO yield speech len 2.0, rtf 0.10065567493438721\n",
      "2025-03-08 00:39:06,164 INFO yield speech len 2.0, rtf 0.16065263748168945\n",
      "2025-03-08 00:39:06,422 INFO yield speech len 2.0, rtf 0.12791669368743896\n",
      "2025-03-08 00:39:06,697 INFO yield speech len 2.0, rtf 0.13690149784088135\n",
      "2025-03-08 00:39:06,998 INFO yield speech len 2.0, rtf 0.14957869052886963\n",
      "2025-03-08 00:39:07,335 INFO yield speech len 1.0, rtf 0.3356931209564209\n",
      "100%|██████████| 1/1 [00:02<00:00, 2.46s/it]\n"
     ]
    }
   ],
   "source": [
    "def text_generator():\n",
    "    yield '收到好友从远方寄来的生日礼物,'\n",
    "    yield '那份意外的惊喜与深深的祝福'\n",
    "    yield '让我心中充满了甜蜜的快乐,'\n",
    "    yield '笑容如花儿般绽放。'\n",
    "for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), prompt_text, prompt_speech_16k, stream=True)):\n",
    "    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 0%| | 0/1 [00:00<?, ?it/s]2025-03-08 00:39:07,592 INFO synthesis text 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。\n",
      "2025-03-08 00:39:08,925 INFO yield speech len 11.24, rtf 0.11861237342671567\n",
      "100%|██████████| 1/1 [00:01<00:00, 1.58s/it]\n"
     ]
    }
   ],
   "source": [
    "# instruct usage\n",
    "for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):\n",
    "    torchaudio.save('instruct2_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cosyvoice",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
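The notebook's note above recommends more than 10 warm-up inferences when load_trt is enabled, with streaming warming up best. A minimal warm-up sketch using the same API as the notebook; the warm-up text and loop count are placeholders chosen to follow that note:

```python
# Hedged warm-up sketch: exercise the streaming path >10 times so the TensorRT
# flow estimator and runtime caches settle before measuring or serving.
warmup_text = '这是一段用于预热的测试文本。'  # arbitrary placeholder text
for _ in range(12):
    for chunk in cosyvoice.inference_zero_shot(warmup_text, prompt_text, prompt_speech_16k, stream=True):
        pass  # discard the audio; only the compute path matters here
```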