# MNN/transformers/llm/export/utils/talker.py

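# Export helpers for the speech "talker" of omni models (currently Qwen2.5-Omni):
# wraps the HuggingFace talker module, autoregressively generates speech codec
# tokens from the thinker's embeddings, vocodes them via Qwen2_5OmniToken2Wav,
# and exports the talker decoder to ONNX for MNN conversion.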
import os
import torch
torch.set_printoptions(precision=4, sci_mode=False)
from .model_mapper import ModelMapper
from .transformers import Rotary, Embedding, Decoder
from .token2wav import Qwen2_5OmniToken2Wav
from .spinner import spinner_run

class Talker(torch.nn.Module):
    def __init__(self, talker, token2wav, base):
        super().__init__()
        self.model_type = base.model_type
        self.thinker_embed = base.embed
        self.args = base.args
        self.talker = talker.float()
        self.token2wav = Qwen2_5OmniToken2Wav(token2wav, base)
        self.config = base.config
        self.hidden_size = base.hidden_size
        self.llm_config = base.llm_config
        self.rope_ratio = 1.0
        self.quant_bit = 4
        if self.hidden_size <= 2048:
            # Qwen2.5-Omni-3B uses 8-bit quantization
            self.quant_bit = 8
        self.init_config()
        self.load()

    @staticmethod
    def get_talker(model_type):
        audio_models = {
            'qwen2_5_omni': Qwen2_5OmniTalker,
        }
        if model_type in audio_models:
            return audio_models[model_type]
        return None

    def init_config(self):
        self.llm_config['has_talker'] = True

    def load(self):
        raise NotImplementedError

    def add_token_embeds(self, thinker_embeds):
        raise NotImplementedError

    def add_hidden_states(self, thinker_hidden_states):
        raise NotImplementedError

    def add_generate_ids(self, token_id):
        raise NotImplementedError

    def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values = None):
        raise NotImplementedError

    def export(self, onnx_path):
        raise NotImplementedError

    def export_embed(self):
        import ctypes
        tensor_data = self.embed.weight.data.bfloat16()
        data_ptr = tensor_data.untyped_storage().data_ptr()
        buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
        embedding_file = f'{self.args.dst_path}/talker_embeddings_bf16.bin'
        with open(embedding_file, 'wb') as f:
            f.write(buffer)
        return embedding_file
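
# A minimal read-back sketch (not part of the exporter) for sanity-checking the
# raw bf16 dump written by export_embed(); `vocab_size` and `hidden_size` are
# placeholders for the talker embedding table's shape:
#
#   import torch
#   with open('talker_embeddings_bf16.bin', 'rb') as f:
#       buf = bytearray(f.read())
#   table = torch.frombuffer(buf, dtype=torch.bfloat16).reshape(vocab_size, hidden_size)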

class OmniRotary(Rotary):
    def __init__(self, model):
        super().__init__(model)
        self.mrope_section = model.mrope_section
        self.theta_sections = self.theta.unsqueeze(0).split(self.mrope_section, dim=-1)
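
    # M-RoPE: position_ids carries three rows (temporal / height / width); each
    # row is scaled by its own slice of the inverse frequencies (mrope_section)
    # and the slices are concatenated back into head_dim/2 angles per position.
    # Assuming the base Rotary exposes `self.theta` with head_dim/2 entries, the
    # output shape is [2 (cos/sin), 1, seq_len, 1, head_dim].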
    def forward(self, position_ids):
        position_ids = position_ids.float().unsqueeze(-1)
        idx_theta = torch.concat([
            position_ids[0] * self.theta_sections[0],
            position_ids[1] * self.theta_sections[1],
            position_ids[2] * self.theta_sections[2]
        ], dim=-1)
        rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
        rotary_pos_emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        rotary_pos_emb = rotary_pos_emb.unsqueeze(3)
        return rotary_pos_emb

class Qwen2_5OmniTalker(Talker):
    def __init__(self, talker, token2wav, base):
        super().__init__(talker, token2wav, base)
        self.input_hidden_size = base.hidden_size
        self.seq_len = 0
        self.token_len = 0
        self.talker_embeds = []

    def load(self):
        # load talker model
        self.model_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'head_dim': 'head_dim',
                'num_attention_heads': 'num_attention_heads',
                'num_hidden_layers': 'num_hidden_layers',
                'num_key_value_heads': 'num_key_value_heads',
                'rope_theta': 'rope_theta',
                'rope_scaling': 'rope_scaling'
            },
            'decoder': {
                'self_attn': 'self_attn',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'q_proj': 'q_proj',
                'k_proj': 'k_proj',
                'v_proj': 'v_proj',
                'o_proj': 'o_proj'
            }
        }
        ModelMapper.do_map(self, self.talker.config, self.model_map['config'])
        self.mrope_section = self.rope_scaling['mrope_section']
        self.embed = self.talker.model.embed_tokens
        self.rotary = OmniRotary(self)
        # self.rotary = Rotary(self)
        self.blocks = []
        for block in self.talker.model.layers:
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))

    def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values = None):
        hidden_states = self.talker.thinker_to_talker_proj(inputs_embeds)
        rotary_pos_emb = self.rotary(position_ids)
        presents = [None for i in range(self.num_hidden_layers)]
        for i in range(self.num_hidden_layers):
            hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
            presents[i] = kv
        hidden_states = hidden_states[:, -1, :]
        hidden_states = self.talker.model.norm(hidden_states)
        logits = self.talker.codec_head(hidden_states)
        presents = torch.stack(presents)
        return logits, presents

    def get_position_ids(self) -> torch.Tensor:
        if self.token_len:
            position_ids = torch.tensor([[self.seq_len - 1]], dtype=torch.int)
        else:
            position_ids = torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
        position_ids = torch.stack([position_ids] * 3)
        return position_ids

    def get_attention_mask(self) -> torch.Tensor:
        if self.token_len:
            return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
        return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
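
    # Autoregressive codec-token generation: build the prompt embeddings from the
    # thinker outputs collected via add_talker_embeds(), greedily decode up to 256
    # codec tokens, stop once a token in self.stop_ids is produced, then vocode
    # the token sequence into a 24 kHz waveform with token2wav and write output.wav.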
    def generate(self):
        talker_text_bos_token = 151872
        talker_inputs_embeds = torch.cat(
            [
                self.talker_embeds[0],
                self.thinker_embed(torch.tensor([[talker_text_bos_token]], dtype=torch.long)) + \
                self.embed(torch.LongTensor([self.talker.codec_pad_token])),
                self.talker_embeds[1] + self.embed(torch.LongTensor([self.talker.codec_bos_token])),
            ],
            dim=1,
        )
        thinker_reply_part = torch.cat(self.talker_embeds[2:], dim=1)
        thinker_reply_part = torch.cat(
            [
                thinker_reply_part,
                self.thinker_embed(
                    torch.tensor([[self.talker.text_eos_token]], dtype=torch.long)
                ),
                self.thinker_embed(
                    torch.tensor([[self.talker.text_pad_token]], dtype=torch.long)
                ),
            ],
            dim=1,
        )
        _, self.seq_len, _ = talker_inputs_embeds.shape
        _, reply_len, _ = thinker_reply_part.shape
        past_key_values = [None for i in range(self.num_hidden_layers)]
        inputs_embeds = talker_inputs_embeds.float()
        self.token_len = 0
        self.stop_ids = [8292, 8294]
        token_id = None
        tokens = []
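
        # Decode loop: the first step consumes the full prompt embeddings; later
        # steps feed the embedding of the previous codec token plus the matching
        # (or last) frame of the thinker reply, growing the KV cache by one per step.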
        while self.token_len < 256:
            attention_mask = self.get_attention_mask()
            position_ids = self.get_position_ids()
            if self.token_len > 0:
                inputs_embeds = self.embed(token_id)
                if self.token_len <= reply_len:
                    inputs_embeds = inputs_embeds + thinker_reply_part[:, self.token_len - 1, :]
                else:
                    inputs_embeds = inputs_embeds + thinker_reply_part[:, -1, :]
            logits, past_key_values = self.forward(inputs_embeds=inputs_embeds,
                                                   attention_mask=attention_mask,
                                                   position_ids=position_ids,
                                                   past_key_values=past_key_values)
            token_id = torch.argmax(logits)
            self.token_len += 1
            self.seq_len += 1
            tokens.append(int(token_id))
            if int(token_id) in self.stop_ids:
                break
        talker_generate_codes = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
        # generate the waveform from the codec tokens
        wav = self.token2wav.generate(talker_generate_codes)
        import soundfile as sf
        sf.write(
            "output.wav",
            wav.reshape(-1).detach().cpu().numpy(),
            samplerate=24000,
        )

    def add_talker_embeds(self, talker_embed):
        self.talker_embeds.append(talker_embed)
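
    # ONNX export: first dump the embedding table via export_embed(), then trace
    # the talker with a 3-token dummy prefill, a zero-length KV cache, and dynamic
    # axes ("size") on every sequence dimension, so the same graph serves both
    # prefill and per-token decode after MNN conversion.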
    @spinner_run(f'export talker to ')
    def export(self, onnx_path):
        self.export_embed()
        self.seq_len = 3
        self.token_len = 0
        inputs_embeds = torch.randn([1, self.seq_len, self.input_hidden_size])
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        past_key_values = torch.zeros([self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim])
        talker_onnx = f'{onnx_path}/talker.onnx'
        torch.onnx.export(self, (inputs_embeds, attention_mask, position_ids, past_key_values),
                          talker_onnx,
                          input_names=['inputs_embeds', 'attention_mask', 'position_ids', 'past_key_values'],
                          output_names=['logits'],
                          dynamic_axes={
                              "inputs_embeds": { 1: "size" },
                              "attention_mask": { 2: "size", 3: "size" },
                              "position_ids": { 2: "size" },
                              "past_key_values": { 3: "size" }
                          },
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return talker_onnx
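
# A rough usage sketch (assuming `base` is the thinker-side exporter that owns
# `model_type`, `embed`, `args`, `config`, `hidden_size` and `llm_config`, and
# that `hf_talker` / `hf_token2wav` are the corresponding HuggingFace submodules):
#
#   talker_cls = Talker.get_talker(base.model_type)
#   if talker_cls is not None:
#       talker = talker_cls(hf_talker, hf_token2wav, base)
#       # per-segment thinker embeddings are pushed in via talker.add_talker_embeds(...)
#       talker.export(onnx_dir)   # writes talker.onnx and talker_embeddings_bf16.bin
#       talker.generate()         # optional: synthesize output.wav as a sanity check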