import os
import torch
torch.set_printoptions(precision=4, sci_mode=False)
from .model_mapper import ModelMapper
from .transformers import Rotary, Embedding, Decoder
from .token2wav import Qwen2_5OmniToken2Wav
from .spinner import spinner_run

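# Base class for a talker (speech-token decoder) attached to an Omni model.
# It wraps the HF talker module together with the token2wav vocoder and defines
# the load/forward/export hooks that model-specific subclasses implement.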
class Talker(torch.nn.Module):
    def __init__(self, talker, token2wav, base):
        super().__init__()
        self.model_type = base.model_type
        self.thinker_embed = base.embed
        self.args = base.args
        self.talker = talker.float()
        self.token2wav = Qwen2_5OmniToken2Wav(token2wav, base)
        self.config = base.config
        self.hidden_size = base.hidden_size
        self.llm_config = base.llm_config
        self.rope_ratio = 1.0
        self.quant_bit = 4
        if self.hidden_size <= 2048:
            # Qwen2.5-Omni-3B uses 8-bit quantization
            self.quant_bit = 8
        self.init_config()
        self.load()

    @staticmethod
    def get_talker(model_type):
        audio_models = {
            'qwen2_5_omni': Qwen2_5OmniTalker,
        }
        if model_type in audio_models:
            return audio_models[model_type]
        return None

    def init_config(self):
        self.llm_config['has_talker'] = True

    def load(self):
        raise NotImplementedError

    def add_token_embeds(self, thinker_embeds):
        raise NotImplementedError

    def add_hidden_states(self, thinker_hidden_states):
        raise NotImplementedError

    def add_generate_ids(self, token_id):
        raise NotImplementedError

    def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values=None):
        raise NotImplementedError

    def export(self, onnx_path):
        raise NotImplementedError

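    # Dump the talker embedding table as raw bfloat16 bytes (2 bytes per element),
    # reading the tensor storage directly through ctypes.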
    def export_embed(self):
        import ctypes
        tensor_data = self.embed.weight.data.bfloat16()
        data_ptr = tensor_data.untyped_storage().data_ptr()
        buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
        embedding_file = f'{self.args.dst_path}/talker_embeddings_bf16.bin'
        with open(embedding_file, 'wb') as f:
            f.write(buffer)
        return embedding_file

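# Rotary embedding with mRoPE-style sections: position_ids carries three rows,
# each scaled by its own slice of self.theta (split according to mrope_section)
# before the cos/sin tables are stacked.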
class OmniRotary(Rotary):
    def __init__(self, model):
        super().__init__(model)
        self.mrope_section = model.mrope_section
        self.theta_sections = self.theta.unsqueeze(0).split(self.mrope_section, dim=-1)

    def forward(self, position_ids):
        position_ids = position_ids.float().unsqueeze(-1)
        idx_theta = torch.concat([
            position_ids[0] * self.theta_sections[0],
            position_ids[1] * self.theta_sections[1],
            position_ids[2] * self.theta_sections[2]
        ], dim=-1)
        rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
        rotary_pos_emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        rotary_pos_emb = rotary_pos_emb.unsqueeze(3)
        return rotary_pos_emb

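# Qwen2.5-Omni talker: maps the HF config and decoder layers onto the exporter's
# Decoder blocks and generates codec tokens autoregressively from thinker embeddings.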
class Qwen2_5OmniTalker(Talker):
    def __init__(self, talker, token2wav, base):
        super().__init__(talker, token2wav, base)
        self.input_hidden_size = base.hidden_size
        self.seq_len = 0
        self.token_len = 0
        self.talker_embeds = []

    def load(self):
        # load talker model
        self.model_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'head_dim': 'head_dim',
                'num_attention_heads': 'num_attention_heads',
                'num_hidden_layers': 'num_hidden_layers',
                'num_key_value_heads': 'num_key_value_heads',
                'rope_theta': 'rope_theta',
                'rope_scaling': 'rope_scaling'
            },
            'decoder': {
                'self_attn': 'self_attn',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'q_proj': 'q_proj',
                'k_proj': 'k_proj',
                'v_proj': 'v_proj',
                'o_proj': 'o_proj'
            }
        }
        ModelMapper.do_map(self, self.talker.config, self.model_map['config'])
        self.mrope_section = self.rope_scaling['mrope_section']
        self.embed = self.talker.model.embed_tokens
        self.rotary = OmniRotary(self)
        # self.rotary = Rotary(self)
        self.blocks = []
        for block in self.talker.model.layers:
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))

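    # One decoder step: project the thinker-sized embeddings into the talker's
    # hidden size, run all decoder blocks with the rotary embeddings and attention
    # mask, and return the codec logits plus the stacked KV cache.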
    def forward(self, inputs_embeds, attention_mask, position_ids, past_key_values=None):
        hidden_states = self.talker.thinker_to_talker_proj(inputs_embeds)
        rotary_pos_emb = self.rotary(position_ids)
        presents = [None for i in range(self.num_hidden_layers)]

        for i in range(self.num_hidden_layers):
            hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
            presents[i] = kv

        hidden_states = hidden_states[:, -1, :]
        hidden_states = self.talker.model.norm(hidden_states)
        logits = self.talker.codec_head(hidden_states)
        presents = torch.stack(presents)
        return logits, presents

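    # Position ids are stacked to shape [3, 1, L] to match the three mRoPE sections:
    # the full range during prefill, a single (seq_len - 1) index while decoding.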
    def get_position_ids(self) -> torch.Tensor:
        if self.token_len:
            position_ids = torch.tensor([[self.seq_len - 1]], dtype=torch.int)
        else:
            position_ids = torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
        position_ids = torch.stack([position_ids] * 3)
        return position_ids

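    # Causal (lower-triangular) additive mask for prefill; while decoding, a single
    # all-zero row lets the new token attend to the entire cache.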
    def get_attention_mask(self) -> torch.Tensor:
        if self.token_len:
            return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
        return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min

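    # Greedy codec-token generation: splice the collected thinker embeddings with the
    # talker bos/pad embeddings, decode up to 256 tokens (stopping on an id in
    # stop_ids), then run token2wav and write the result to output.wav at 24 kHz.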
    def generate(self):
        talker_text_bos_token = 151872
        talker_inputs_embeds = torch.cat(
            [
                self.talker_embeds[0],
                self.thinker_embed(torch.tensor([[talker_text_bos_token]], dtype=torch.long)) + \
                self.embed(torch.LongTensor([self.talker.codec_pad_token])),
                self.talker_embeds[1] + self.embed(torch.LongTensor([self.talker.codec_bos_token])),
            ],
            dim=1,
        )
        thinker_reply_part = torch.cat(self.talker_embeds[2:], dim=1)
        thinker_reply_part = torch.cat(
            [
                thinker_reply_part,
                self.thinker_embed(
                    torch.tensor([[self.talker.text_eos_token]], dtype=torch.long)
                ),
                self.thinker_embed(
                    torch.tensor([[self.talker.text_pad_token]], dtype=torch.long)
                ),
            ],
            dim=1,
        )

        _, self.seq_len, _ = talker_inputs_embeds.shape
        _, reply_len, _ = thinker_reply_part.shape
        past_key_values = [None for i in range(self.num_hidden_layers)]

        inputs_embeds = talker_inputs_embeds.float()
        self.token_len = 0
        self.stop_ids = [8292, 8294]
        token_id = None
        tokens = []
        while self.token_len < 256:
            attention_mask = self.get_attention_mask()
            position_ids = self.get_position_ids()
            if self.token_len > 0:
                inputs_embeds = self.embed(token_id)
                if self.token_len <= reply_len:
                    inputs_embeds = inputs_embeds + thinker_reply_part[:, self.token_len - 1, :]
                else:
                    inputs_embeds = inputs_embeds + thinker_reply_part[:, -1, :]
            logits, past_key_values = self.forward(inputs_embeds=inputs_embeds,
                                                    attention_mask=attention_mask,
                                                    position_ids=position_ids,
                                                    past_key_values=past_key_values)
            token_id = torch.argmax(logits)
            self.token_len += 1
            self.seq_len += 1
            tokens.append(int(token_id))
            if int(token_id) in self.stop_ids:
                break
        talker_generate_codes = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
        # Generate wav from the codec codes
        wav = self.token2wav.generate(talker_generate_codes)
        import soundfile as sf
        sf.write(
            "output.wav",
            wav.reshape(-1).detach().cpu().numpy(),
            samplerate=24000,
        )

    def add_talker_embeds(self, talker_embed):
        self.talker_embeds.append(talker_embed)

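    # ONNX export: trace the talker with a dummy 3-token prefill and an empty KV
    # cache, marking the sequence dimensions as dynamic axes.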
    @spinner_run(f'export talker to ')
    def export(self, onnx_path):
        self.export_embed()
        self.seq_len = 3
        self.token_len = 0
        inputs_embeds = torch.randn([1, self.seq_len, self.input_hidden_size])
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        past_key_values = torch.zeros([self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim])
        talker_onnx = f'{onnx_path}/talker.onnx'
        torch.onnx.export(self, (inputs_embeds, attention_mask, position_ids, past_key_values),
                        talker_onnx,
                        input_names=['inputs_embeds', 'attention_mask', 'position_ids', 'past_key_values'],
                        output_names=['logits'],
                        dynamic_axes={
                            "inputs_embeds": { 1: "size" },
                            "attention_mask": { 2: "size", 3: "size" },
                            "position_ids": { 2: "size" },
                            "past_key_values": { 3: "size" }
                        },
                        do_constant_folding=True,
                        verbose=False,
                        opset_version=15)
        return talker_onnx