import os
import json
import glob
import base64
import warnings
import argparse

warnings.filterwarnings("ignore")
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
import onnx
import torch
from typing import Optional, List
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer

from utils.spinner import spinner_run
from utils.custom_op import FakeLinear
from utils.onnx_rebuilder import OnnxRebuilder
from utils.mnn_converter import MNNConveter
from utils.awq_quantizer import AwqQuantizer
from utils.model_mapper import ModelMapper
from utils.transformers import Embedding, Rotary, Decoder, Lm

class LlmExporter(torch.nn.Module):
    '''
    Base class for all llm model export. Inherits from [`torch.nn.Module`].
    '''
    def __init__(self, args):
        super().__init__()
        self.init_from_args(args)
        self.load_model(args.path)
    def init_from_args(self, args):
        self.visual = None
        self.audio = None
        self.talker = None
        self.mtp = None
        self.args = args
        self.max_length = 1024
        self.stop_ids = []
        self.dst_name = 'llm'
        # load config from args
        self.onnx_path = os.path.join(self.args.dst_path, 'onnx')
        if self.args.tokenizer_path is None:
            self.args.tokenizer_path = self.args.path
        if args.lm_quant_bit is None:
            self.args.lm_quant_bit = self.args.quant_bit
        # init export dst dir
        if not os.path.exists(self.args.dst_path):
            os.makedirs(self.args.dst_path)
        if not os.path.exists(self.onnx_path):
            os.makedirs(self.onnx_path)
    def load_pretrained(self, model_path: str):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_path, trust_remote_code=True, use_fast=False)
        except:
            self.tokenizer = None
        if self.tokenizer is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_path, trust_remote_code=True)
            except:
                self.tokenizer = None
        if self.tokenizer is None:
            print("Default tokenizer load failed for", model_path)
        if 'Qwen2.5-Omni' in model_path:
            from transformers import Qwen2_5OmniForConditionalGeneration
            self.model = Qwen2_5OmniForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto").eval()
        elif 'Qwen2.5-VL' in model_path or 'Qwen2___5-VL' in model_path:
            from transformers import Qwen2_5_VLForConditionalGeneration
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
        elif 'Qwen2-VL' in model_path:
            from transformers import Qwen2VLForConditionalGeneration
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
        elif 'deepseek-vl' in model_path:
            from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
            vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
            self.tokenizer = vl_chat_processor.tokenizer
            self.processor = vl_chat_processor
            vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).eval()
            self.model = vl_gpt
            self.model.config.model_type = "deepseek-vl"
        elif 'Qwen2-Audio' in model_path:
            from transformers import Qwen2AudioForConditionalGeneration
            self.audio = Qwen2AudioForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
            self.model = self.audio.language_model
        elif 'Llama-3.2' in model_path and 'Vision' in model_path:
            from transformers import MllamaForConditionalGeneration
            self.model = MllamaForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
        elif 'Llama' in model_path or 'Yi' in model_path:
            from transformers import LlamaForCausalLM
            self.model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
        elif 'InternVL' in model_path:
            self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).eval()
        elif 'gemma-3-4b-it' in model_path:
            from transformers import Gemma3ForConditionalGeneration
            self.model = Gemma3ForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
        elif 'gemma-3-1b-it' in model_path:
            from transformers import Gemma3ForCausalLM
            self.model = Gemma3ForCausalLM.from_pretrained(model_path, torch_dtype='auto').eval()
        elif 'SmolVLM2' in model_path:
            from transformers import AutoModelForImageTextToText
            self.model = AutoModelForImageTextToText.from_pretrained(model_path, torch_dtype='auto').eval()
        elif 'SmolVLM' in model_path or 'SmolDocling' in model_path:
            from transformers import AutoModelForVision2Seq
            self.model = AutoModelForVision2Seq.from_pretrained(model_path, torch_dtype='auto').eval()
        else:
            try:
                self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
            except:
                self.model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
        self.config = self.model.config
        if self.args.lora_path is not None and not self.args.lora_split:
            from peft import PeftModel
            adapter = PeftModel.from_pretrained(self.model, model_id=self.args.lora_path)
            self.model = adapter.merge_and_unload(progressbar=True)
    @staticmethod
    def has_attr(obj, attr):
        return hasattr(obj, attr) and getattr(obj, attr) is not None
    @spinner_run(f'load pretrained model ', True)
    def load_model(self, model_path):
        self.load_pretrained(model_path)
        self.attention_mask_type = 'float'
        # load tokenizer info
        self.stop_ids.append(self.tokenizer.eos_token_id)
        if hasattr(self.tokenizer, 'im_end_id'):
            self.stop_ids.append(self.tokenizer.im_end_id)
        try:
            eot_id = self.tokenizer.encode('<|eot_id|>')
            if len(eot_id) == 1:
                self.stop_ids.append(eot_id[0])
            # gemma/gemma-2
            eot_id = self.tokenizer.encode('<end_of_turn>')
            if len(eot_id) == 2 and eot_id[0] == 2:
                self.stop_ids.append(eot_id[1])
        except:
            pass
        if hasattr(self.model, 'generation_config') and self.model.generation_config is not None:
            eos_token_id = self.model.generation_config.eos_token_id
            from collections.abc import Iterable
            if isinstance(eos_token_id, int):
                self.stop_ids.append(eos_token_id)
            elif isinstance(eos_token_id, Iterable):
                for id in eos_token_id:
                    self.stop_ids.append(id)
        self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None]
        self.stop_ids = list(set(self.stop_ids))
        model_mapper = ModelMapper()
        self.model_type, self.model_map = model_mapper.get_map(self.config)
        if self.args.awq:
            self.model.float()
        if self.args.export is not None:
            # set norm's weight as float for export
            def visit_module(module):
                if not isinstance(module, torch.nn.Linear) and hasattr(module, 'weight'):
                    module.float()
                for name, child in module.named_children():
                    visit_module(child)
            visit_module(self.model)
        # print(self.config, self.model_type, self.model_map, self.model)
        # load config info
        ModelMapper.do_map(self, self.config, self.model_map['config'])
        if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        if not hasattr(self, 'rope_theta') or self.rope_theta is None:
            self.rope_theta = 10000.0
        if not hasattr(self, 'rope_ratio') or self.rope_ratio is None:
            self.rope_ratio = 1.0
        if not hasattr(self, 'head_dim') or self.head_dim is None:
            if isinstance(self.num_attention_heads, list):
                self.head_dim = [self.hidden_size // atten_head for atten_head in self.num_attention_heads]
            else:
                self.head_dim = self.hidden_size // self.num_attention_heads
        # some export info
        if isinstance(self.num_attention_heads, list):
            self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads[0], self.head_dim]
        else:
            self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim]
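        # Layout note (descriptive, matching the dynamic axes below): past_kv_shape is
        # [num_layers, 2 (key/value), batch, history_len, num_kv_heads, head_dim];
        # history_len starts at 0 and grows along the "history_len" dynamic axis.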
        self.block_dynamic_axes = {
            "inputs_embeds": {0: "seq_len"},
            "attention_mask": {2: "seq_len", 3: "seq_len"},
            "position_ids": {0: "seq_len"},
            "past_key_values": {1: "history_len"}
        }
        self.model_dynamic_axes = {
            "input_ids": {0: "seq_len"},
            "attention_mask": {2: "seq_len", 3: "seq_len"},
            "position_ids": {1: "seq_len"},
            "past_key_values": {3: "history_len"}
        }
        prompt_template = self.build_prompt_template()
        self.llm_config = {
            'hidden_size': self.hidden_size,
            'layer_nums': self.num_hidden_layers,
            'attention_mask': self.attention_mask_type,
            'key_value_shape': self.past_kv_shape[1:],
            "bos": prompt_template['bos'],
            "system_prompt_template": prompt_template['system'].format(content='%s'),
            'user_prompt_template': prompt_template['user'].format(content='%s'),
            'assistant_prompt_template': prompt_template['assistant'].format(content='%s'),
            'is_visual': False
        }
        if 'jinja' in prompt_template:
            self.llm_config['jinja'] = prompt_template['jinja']
        # load modules
        ModelMapper.do_map(self, self.model, self.model_map['model'])
        # guard against models without an lm_head; those are tied to the embedding below
        self.tie_word_embeddings = not self.args.seperate_embed and \
            (self.lm_ is None or self.lm_.weight.equal(self.embed_.weight))
        if self.tie_word_embeddings:
            print("Word embeddings are tied in lm_head, setting lm quant bit to 8")
            self.args.lm_quant_bit = 8
        # rebuild modules
        if self.lm_ is None:
            out_features, in_features = self.embed_.weight.shape
            self.lm_ = torch.nn.Linear(in_features, out_features)
            self.lm_.weight = self.embed_.weight
        elif not isinstance(self.lm_, torch.nn.Linear):
            # for Baichuan2
            weight = self.lm_.weight
            out_features, in_features = weight.shape
            self.lm_ = torch.nn.Linear(in_features, out_features)
            self.lm_.weight = weight
            self.lm_.bias.data = torch.zeros(out_features, dtype=weight.dtype)
        if self.embed_.weight is self.lm_.weight:
            import copy
            embed_copy = copy.deepcopy(self.embed_)
            self.embed = Embedding(embed_copy, self)
        else:
            self.embed = Embedding(self.embed_, self)
        # Rotary
        self.rotary = Rotary(self)
        self.blocks = []
        for block in self.blocks_.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.lm = Lm(self.lm_)
        # visual model
        if self.visual is not None:
            if self.args.export is not None:
                self.visual.float()
            from utils.vision import Vision
            self.visual = Vision.get_vision(self.model_type)(self.visual, self)
        if hasattr(self, 'audio') and self.audio is not None:
            from utils.audio import Audio
            self.audio = Audio.get_audio(self.audio.config.model_type)(self.audio, self)
        else:
            self.audio = None
        # talker model
        if hasattr(self, 'talker') and self.talker is not None and \
                hasattr(self, 'token2wav') and self.token2wav is not None:
            from utils.talker import Talker
            self.talker = Talker.get_talker(self.model_type)(self.talker, self.token2wav, self)
        # MTP model
        if self.model_type == 'poi_qwen2_mtp':
            self.mtp = [self.mtp1, self.mtp2]
        if self.mtp is not None:
            if self.args.export is not None:
                for mtp_model in self.mtp:
                    mtp_model.float()
            from utils.mtp import Mtp
            self.mtp = Mtp.get_mtp(self.model_type)(self.mtp, self)
        return model_path
    def get_attention_mask(self) -> torch.Tensor:
        if self.model_type == 'chatglm':
            return self.chatglm_attention_mask()
        if self.token_len:
            return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
        return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
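    # Note: for prefill (token_len == 0) this builds the usual additive causal mask,
    # e.g. for seq_len=3 (writing m = torch.finfo(torch.float32).min):
    #     [[0, m, m],
    #      [0, 0, m],
    #      [0, 0, 0]]
    # During decode (token_len > 0) a zero mask of shape [1, 1, 1, seq_len] keeps
    # every cached position visible.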
    def get_position_ids(self, input_ids=None) -> torch.Tensor:
        if self.visual is not None and hasattr(self.visual, 'get_position_ids') and callable(getattr(self.visual, 'get_position_ids')):
            return self.visual.get_position_ids(input_ids, self.seq_len, self.token_len)
        if self.model_type == 'chatglm':
            return self.chatglm_position_ids()
        if self.token_len:
            return torch.tensor([[self.seq_len - 1]], dtype=torch.int)
        return torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
    def chatglm_attention_mask(self):
        if self.token_len:
            return torch.zeros([1]).bool().reshape([1, 1, 1, 1])
        attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool)
        for i in range(self.seq_len - 1):
            attention_mask[i][-1] = True
        attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])
        return attention_mask
    def chatglm_position_ids(self):
        if self.token_len:
            return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1])
        position_ids_0 = torch.arange(self.seq_len, dtype=torch.int)
        position_ids_1 = torch.zeros(self.seq_len, dtype=torch.int)
        position_ids_0[-1] = position_ids_0[-2]
        position_ids_1[-1] = 1
        position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
        return position_ids
    def visual_embed(self, input_ids):
        return self.visual.embed(input_ids)
    def audio_embed(self, input_ids):
        return self.audio.embed(input_ids)
    def embedding(self, input_ids):
        if self.visual is not None and self.token_len == 0:
            input_embeds = self.visual_embed(input_ids)
        elif self.audio is not None and self.token_len == 0:
            input_embeds = self.audio_embed(input_ids)
        else:
            input_embeds = self.embed(input_ids)
        return input_embeds
    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                position_ids: torch.Tensor,
                past_key_values: Optional[List[torch.Tensor]] = None,
                logits_index: int = -1,
                cross_attention_states: Optional[torch.Tensor] = None,
                cross_attention_mask: Optional[torch.Tensor] = None,
                ):
        hidden_states = input_ids # llm forward without embedding
        if self.model_type == 'gemma':
            normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=hidden_states.dtype)
            hidden_states = hidden_states * normalizer
        if self.model_type == 'gemma3' or self.model_type == 'gemma3_text': # if running with --test, comment these lines out
            hidden_states = hidden_states * self.embed.embed_scale
        presents = [None for i in range(len(self.blocks))]
        rotary_pos_emb = self.rotary(position_ids)
        if self.args.test and rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.type(hidden_states.dtype)
        for i in range(len(self.blocks)):
            if self.blocks[i].cross_decoder and cross_attention_states is None:
                continue
            hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
            presents[i] = kv
        talker_embeds = None
        if hasattr(self, 'talker') and self.talker is not None:
            talker_embeds = self.final_layernorm_(hidden_states) + input_ids.permute([1, 0, 2])
            self.talker.add_talker_embeds(talker_embeds)
        final_layernorm = hidden_states
        if self.mtp is None:
            hidden_states = hidden_states[:, logits_index:, :]
            hidden_states = self.final_layernorm_(hidden_states)
        else:
            # for MTP the final layernorm must be computed over all positions
            if self.model_type == 'mimo':
                final_layernorm = hidden_states # mimo
            hidden_states = self.final_layernorm_(hidden_states)
            if self.model_type == 'poi_qwen2_mtp':
                final_layernorm = hidden_states # poi
            hidden_states = hidden_states[:, logits_index:, :]
        logits = self.lm(hidden_states)
        if None not in presents and presents[0].shape == presents[-1].shape:
            presents = torch.stack(presents)
        self.seq_len += 1
        self.token_len += 1
        return logits, final_layernorm, presents, talker_embeds
    # some test functions
    def build_prompt_template(self):
        template = {
            'bos': '',
            'system': '{content}',
            'user': '{content}',
            'assistant': '{content}'
        }
        if hasattr(self.tokenizer, 'get_chat_template'):
            template['jinja'] = {}
            template['jinja']['chat_template'] = self.tokenizer.get_chat_template()
            if self.tokenizer.bos_token is not None:
                template['jinja']['bos'] = self.tokenizer.bos_token
            if self.tokenizer.eos_token is not None:
                template['jinja']['eos'] = self.tokenizer.eos_token
        if self.model_type == 'baichuan':
            template['user'] = '<reserved_106>{content}'
            template['assistant'] = '<reserved_107>{content}'
        if self.model_type == 'chatglm':
            template['user'] = '{content}[gMASK]<sop>'
        if self.model_type == 'chatglm2' and 'codegeex' not in self.args.path:
            template['user'] = '[Round 1]\n\n问:{content}\n\n'
            template['assistant'] = '答:{content}\n\n'
        if 'chatglm3' in self.args.path or 'glm-4' in self.args.path:
            template['bos'] = '[gMASK]<sop>'
            template['system'] = '<|system|>\n{content}\n'
            template['user'] = '<|user|>\n{content}\n'
            template['assistant'] = '<|assistant|>\n{content}\n'
        if self.model_type == 'llama':
            if 'Llama-2' in self.args.path:
                template['bos'] = '[INST] '
                template['system'] = "<<SYS>>\n{content}\n<</SYS>>\n\n"
                template['user'] = '{content} [/INST]'
                template['assistant'] = "{content}</s>"
            if 'Llama-3' in self.args.path:
                template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>'
                template['user'] = '<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>'
                template['assistant'] = '<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>'
            if 'TinyLlama' in self.args.path:
                template['bos'] = '<s>'
                template['system'] = '<|system|>\n{content}</s>\n'
                template['user'] = '<|user|>\n{content}</s>\n'
                template['assistant'] = '<|assistant|>\n{content}</s>\n'
            if 'Yi' in self.args.path or 'SmolLM2' in self.args.path:
                template['system'] = '<|im_start|>system\n{content}<|im_end|>\n'
                template['user'] = '<|im_start|>user\n{content}<|im_end|>\n'
                template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n'
        if self.model_type == 'gemma2':
            template['bos'] = '<bos>'
            template['system'] = '<start_of_turn>system\n{content}<end_of_turn>\n'
            template['user'] = '<start_of_turn>user\n{content}<end_of_turn>\n'
            template['assistant'] = '<start_of_turn>model\n{content}<end_of_turn>\n'
        if self.model_type == 'gemma':
            template['bos'] = '<bos>'
        if self.model_type == 'internlm':
            template['user'] = '<|User|>:{content}<eoh>\n'
            template['assistant'] = '<|Bot|>:{content}<eoh>\n'
        if self.model_type == 'phi-msft':
            template['user'] = 'Instruct: {content}\n'
            template['assistant'] = 'Output: {content}\n'
        if self.model_type == 'openelm':
            template['bos'] = '<s>'
        if self.model_type == 'internvl_chat':
            if 'Qwen' in self.config.llm_config._name_or_path:
                print("[DEBUG] Use qwen prompt template")
                template['system'] = '<|im_start|>system\n{content}<|im_end|>\n'
                template['user'] = '<|im_start|>user\n{content}<|im_end|>\n'
                template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n'
        if self.model_type == 'phi3':
            template['system'] = '<|im_start|>system<|im_sep|>{content}<|im_end|>'
            template['user'] = '<|im_start|>user<|im_sep|>{content}<|im_end|>'
            template['assistant'] = '<|im_start|>assistant<|im_sep|>{content}<|im_end|>'
        if self.model_type in ('gemma3', 'gemma3_text'):
            template['bos'] = '<bos><start_of_turn>user\n'
            template['system'] = '{content}\n\n'
            template['user'] = '{content}<end_of_turn>\n'
            template['assistant'] = '<start_of_turn>model\n{content}<end_of_turn>\n<start_of_turn>user\n'
        if self.model_type in ['idefics3', 'smolvlm']:
            template['bos'] = '<|im_start|>'
            template['system'] = 'System: {content}<end_of_utterance>\n'
            template['user'] = 'User: {content}<end_of_utterance>\n'
            template['assistant'] = 'Assistant: {content}<end_of_utterance>\n'
        if 'qwen' in self.model_type or 'mimo' in self.model_type:
            template['system'] = '<|im_start|>system\n{content}<|im_end|>\n'
            template['user'] = '<|im_start|>user\n{content}<|im_end|>\n'
            template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n'
        if 'DeepSeek' in self.args.path or 'deepseek' in self.args.path:
            template['bos'] = '<|begin_of_sentence|>'
            template['system'] = '{content}\n'
            template['user'] = '\nUser: {content}\n'
            template['assistant'] = '\nAssistant: {content}\n<|end_of_sentence|>'
        return template
    def build_prompt(self, messages):
        template = self.build_prompt_template()
        prompt = template['bos']
        for item in messages:
            role, content = item['role'], item['content']
            if '{content}' in template[role]:
                prompt += template[role].format(content=content)
            else:
                prompt += role + '\n' + content + '\n'
        assistant_prefix = template['assistant'].split('{content}')[0]
        return prompt + assistant_prefix
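    # Illustrative example (qwen-style template): the messages
    #   [{'role': 'system', 'content': 'You are a helpful assistant.'},
    #    {'role': 'user', 'content': 'Hello'}]
    # build the prompt
    #   '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n'
    #   '<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n'
    # i.e. the assistant prefix (everything before {content}) is appended so the
    # model continues with the assistant reply.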
    def str_to_ids(self, prompt):
        if self.visual is not None:
            return self.visual.str_to_ids(prompt)
        if self.audio is not None:
            return self.audio.str_to_ids(prompt)
        input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
        return input_ids
    def id_to_str(self, token_id):
        try:
            word = self.tokenizer.decode(int(token_id))
        except:
            def contains_replacement(text): return '\uFFFD' in text
            def decode_id(token_id):
                return self.tokenizer.convert_tokens_to_string(
                    self.tokenizer._convert_id_to_token(int(token_id)))
            def decode_ids(token_ids):
                return self.tokenizer.convert_tokens_to_string(
                    self.tokenizer.convert_ids_to_tokens(token_ids))
            word = decode_id(int(token_id))
            # The SmolLM tokenizer can emit half of a multi-byte (e.g. Chinese) character;
            # buffer token ids until they decode without a replacement character.
            if contains_replacement(word):
                self.decode_buffer.append(token_id)
                buffer_txt = decode_ids(self.decode_buffer)
                if not contains_replacement(buffer_txt):
                    word = buffer_txt
                    self.decode_buffer.clear()
                else:
                    word = ''
        return word
    @torch.no_grad()
    def response(self, query):
        # self.imitate_quant()
        self.decode_buffer = []
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": query}
        ]
        prompt = self.build_prompt(messages)
        input_ids = self.str_to_ids(prompt)
        if self.visual is not None:
            cross_attention_states = self.visual.cross_attention_states
            cross_attention_mask = self.visual.cross_attention_mask
        else:
            cross_attention_states = None
            cross_attention_mask = None
        self.seq_len = input_ids.numel()
        self.context_len = self.seq_len - 2
        self.token_len = 0
        past_key_values = [None for i in range(self.num_hidden_layers)]
        token_id = input_ids
        while self.token_len < self.max_length:
            attention_mask = self.get_attention_mask()
            position_ids = self.get_position_ids(token_id)
            input_ids = self.embedding(token_id)
            # forward returns (logits, final_layernorm, presents, talker_embeds)
            logits, _, past_key_values, _ = self.forward(input_ids,
                                                         attention_mask,
                                                         position_ids,
                                                         past_key_values,
                                                         cross_attention_states=cross_attention_states,
                                                         cross_attention_mask=cross_attention_mask)
            token_id = torch.argmax(logits[:, -1, :])
            if token_id in self.stop_ids:
                print("", end='\n')
                break
            word = self.id_to_str(token_id)
            print(word, end="", flush=True)
        if hasattr(self, 'talker') and self.talker is not None:
            self.talker.generate()
    def export_mtp(self):
        if self.mtp is None:
            return
        mtp_onnx = self.mtp.export(self.onnx_path)
        if self.mnn_converter:
            self.mtp.unloaded_ops['/lm/lm_head/Linear'] = self.unloaded_ops['/lm/lm_head/Linear']
            MNNConveter(self, self.mtp.unloaded_ops).export(mtp_onnx)
    @spinner_run(f'export embedding to ')
    def export_embed(self):
        import ctypes
        if hasattr(self, 'word_embeddings'):
            # embedding model's embed
            tensor_data = self.word_embeddings.weight.data.bfloat16()
        else:
            tensor_data = self.embed.embed.weight.data.bfloat16()
        data_ptr = tensor_data.untyped_storage().data_ptr()
        # 2 bytes per bfloat16 element
        buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
        embedding_file = f'{self.args.dst_path}/embeddings_bf16.bin'
        with open(embedding_file, 'wb') as f:
            f.write(buffer)
        return embedding_file
    @spinner_run(f'export config to ')
    def export_config(self, mnn_config=False):
        config_json = f'{self.args.dst_path}/llm_config.json'
        with open(config_json, 'w', encoding='utf-8') as f:
            json.dump(self.llm_config, f, ensure_ascii=False, indent=4)
        if not mnn_config:
            return config_json
        with open(f'{self.args.dst_path}/config.json', 'w', encoding='utf-8') as f:
            config = {
                "llm_model": f"{self.dst_name}.mnn",
                "llm_weight": f"{self.dst_name}.mnn.weight",
                "backend_type": "cpu",
                "thread_num": 4,
                "precision": "low",
                "memory": "low",
                # "system_prompt": "You are a helpful assistant.",
                "sampler_type": 'penalty',
                "penalty": 1.1
            }
            if self.talker is not None:
                config['system_prompt'] = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
                config['talker_max_new_tokens'] = 2048
                config['talker_speaker'] = "Chelsie"
                config['dit_steps'] = 5
                config['dit_solver'] = 1
            if self.model_type == "gemma3":
                config.update({'precision': "normal"})
            if self.visual is not None or self.audio is not None:
                config['mllm'] = {
                    'backend_type': "cpu",
                    "thread_num": 4,
                    "precision": "normal",
                    "memory": "low"
                }
            json.dump(config, f, ensure_ascii=False, indent=4)
        return config_json
    def imitate_quant(self):
        # Simulate MNN's asymmetric block-wise quantization in torch: quantize each
        # weight block to `quant_bit` integers, dequantize back, and keep the result,
        # so `--test` inference reflects the quantization error.
        def quant_dequant(linear, quant_bit=self.args.quant_bit, quant_block=self.args.quant_block):
            weight = linear.weight.data
            oc, ic = weight.shape
            if quant_block == 0:
                block_size = ic
            else:
                block_size = quant_block
            block_num = ic // block_size
            weight = weight.reshape(oc, block_num, block_size)
            max_val = torch.max(weight, axis=-1, keepdims=True).values
            min_val = torch.min(weight, axis=-1, keepdims=True).values
            offset = 1 << (quant_bit - 1)
            clip_max = offset - 1
            clip_min = -offset
            scale = (max_val - min_val) / (clip_max - clip_min)
            q_weight = torch.round((weight - min_val) / scale) + clip_min
            q_weight = torch.clip(q_weight, clip_min, clip_max)
            dq_weight = (q_weight - clip_min) * scale + min_val
            dq_weight = dq_weight.reshape(oc, ic).float()
            linear.weight.data = dq_weight
            return linear
        with torch.no_grad():
            for i in range(self.num_hidden_layers):
                for name, child in self.blocks[i].self_attn.named_children():
                    if isinstance(child, torch.nn.Linear):
                        setattr(self.blocks[i].self_attn, name, quant_dequant(child))
                for name, child in self.blocks[i].mlp.named_children():
                    if isinstance(child, torch.nn.Linear):
                        setattr(self.blocks[i].mlp, name, quant_dequant(child))
            self.lm.lm = quant_dequant(self.lm.lm)
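    # Worked example for quant_dequant (illustrative numbers, not part of the export
    # path): with quant_bit=4 the integer range is [clip_min, clip_max] = [-8, 7].
    # For a block with max_val=0.7 and min_val=-0.5:
    #   scale = (0.7 - (-0.5)) / 15 = 0.08
    #   a weight of 0.3 quantizes to round((0.3 - (-0.5)) / 0.08) + (-8) = 2
    #   and dequantizes back to (2 - (-8)) * 0.08 + (-0.5) = 0.3
    # (exact here, lossy in general).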
    def unload_param(self):
        self.unloaded_ops = {}
        self.experts = []
        def build_faker(real, name):
            faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name)
            self.unloaded_ops[name] = real
            return faker
        # replace linear with fakelinear to save export memory and time
        with torch.no_grad():
            for i in range(len(self.blocks)):
                # different kv cache shape in different layers
                if isinstance(self.num_attention_heads, list):
                    self.blocks[i].self_attn.export_fused_attn = True
                is_moe = hasattr(self.blocks[i].mlp, 'is_moe') and self.blocks[i].mlp.is_moe
                if is_moe:
                    self.blocks[i].mlp.export_moe = True
                for name, child in self.blocks[i].self_attn.named_children():
                    if isinstance(child, torch.nn.Linear):
                        setattr(self.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear'))
                for name, child in self.blocks[i].mlp.named_children():
                    if isinstance(child, torch.nn.Linear):
                        setattr(self.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear'))
                    if is_moe and isinstance(child, torch.nn.ModuleList): # experts
                        self.experts.append(child)
                        for j in range(len(child)):
                            for expert_name, cchild in child[j].named_children():
                                if isinstance(cchild, torch.nn.Linear):
                                    setattr(self.blocks[i].mlp.experts[j], expert_name, build_faker(cchild, f'/expert/{i}_{j}/{expert_name}'))
            self.lm.lm = build_faker(self.lm.lm, f'/lm/lm_head/Linear')
    @spinner_run(f'export model weight to ')
    def onnx_load_param(self, onnx_path):
        return OnnxRebuilder(onnx_path, self.unloaded_ops).rebuild()
    @spinner_run(f'slim the graph of ')
    def slim_onnx(self, onnx_model):
        import onnxslim
        model = onnxslim.slim(onnx_model)
        onnx.save(model, onnx_model)
        return onnx_model
    @spinner_run(f'export onnx model to ')
    def export_onnx(self):
        # unload linear weight to save export memory
        self.unload_param()
        model = self
        self.seq_len = 3
        self.token_len = 0
        input_ids = torch.arange(3, dtype=torch.long)
        attention_mask = self.get_attention_mask()
        position_ids = self.get_position_ids(input_ids)
        onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx'
        # for onnx export the image/audio embeddings are not needed, use the text embed only
        input_ids = self.embed(input_ids)
        past_key_values = torch.zeros(self.past_kv_shape)
        logits_index = torch.tensor([-1], dtype=torch.int32)
        if hasattr(self, 'talker') and self.talker is not None:
            output_names = ['logits', 'hidden_states', 'presents', 'talker_embeds']
        else:
            output_names = ['logits', 'hidden_states', 'presents']
        # export to onnx
        torch.onnx.export(
            model, (input_ids, attention_mask, position_ids, past_key_values, logits_index),
            onnx_model,
            input_names=[
                'input_ids', 'attention_mask', 'position_ids', 'past_key_values', 'logits_index'
            ],
            output_names=output_names,
            dynamic_axes=self.model_dynamic_axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15)
        return onnx_model
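    # Note: the graph exported above contains FakeLinear placeholders instead of the
    # real Linear weights (see unload_param); the weights are restored afterwards by
    # onnx_load_param() for pure onnx export, or handed to the MNN converter through
    # unloaded_ops (see export_language).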
    def awq_quant(self):
        self.awq_quantizer = AwqQuantizer(self)
        self.awq_quantizer.quantize()
        self.is_awq_quantized = True
    def export_vision(self):
        if self.visual is None:
            return
        vision_onnx = self.visual.export(self.onnx_path)
        if self.mnn_converter:
            self.mnn_converter.export(vision_onnx, self.visual.quant_bit,
                                      self.visual.quant_block,
                                      transformer_fuse=self.visual.transformer_fuse)
    def export_audio(self):
        if self.audio is None:
            return
        audio_onnx = self.audio.export(self.onnx_path)
        if self.mnn_converter: self.mnn_converter.export(audio_onnx, self.audio.quant_bit)
    def export_talker(self):
        if self.talker is None:
            return
        talker_onnx = self.talker.export(self.onnx_path)
        predit_onnx, dit_onnx, bigvgan_onnx = self.talker.token2wav.export(self.onnx_path)
        if self.mnn_converter:
            self.mnn_converter.export(talker_onnx, self.talker.quant_bit)
            self.mnn_converter.export(predit_onnx, self.talker.token2wav.quant_bit)
            self.mnn_converter.export(dit_onnx, self.talker.token2wav.quant_bit)
            self.mnn_converter.export(bigvgan_onnx, self.talker.token2wav.quant_bit)
    def export_language(self):
        # export_embedding
        if self.mnn_converter and self.tie_word_embeddings:
            pass # with tied word embeddings, MNN needn't export a separate embedding file
        else:
            self.export_embed()
        # export transformer
        onnx_model = self.export_onnx()
        if self.args.onnx_slim:
            self.slim_onnx(onnx_model)
        if self.mnn_converter:
            MNNConveter(self, self.unloaded_ops).export(onnx_model)
        else:
            self.onnx_load_param(onnx_model)
    def export(self, export_type):
        if self.args.awq:
            self.awq_quant()
        export_mnn = export_type == 'mnn'
        self.mnn_converter = MNNConveter(self) if export_mnn else None
        self.export_talker()
        self.export_vision()
        self.export_audio()
        self.export_language()
        self.export_mtp()
        self.export_tokenizer()
        self.export_config(export_mnn)
        if export_mnn:
            # delete onnx file
            try:
                for file in glob.glob(f'{self.onnx_path}/*'):
                    os.remove(file)
                os.rmdir(self.onnx_path)
            except Exception as e:
                print(f"remove onnx error: {e}")
    @spinner_run(f'export tokenizer to ')
    def export_tokenizer(self):
        # load tokenizer file
        tokenizer_model = os.path.join(self.args.tokenizer_path, 'tokenizer.model')
        ice_text_model = os.path.join(self.args.tokenizer_path, 'ice_text.model')
        try:
            import sentencepiece as spm
            if os.path.exists(tokenizer_model):
                self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
            elif os.path.exists(ice_text_model):
                self.sp_model = spm.SentencePieceProcessor(ice_text_model)
            else:
                self.sp_model = None
        except:
            self.sp_model = None
        merge_file = os.path.join(self.args.path, 'merges.txt')
        if os.path.exists(merge_file):
            self.merge_txt = merge_file
        else:
            self.merge_txt = None
        # TOKENIZER MAGIC NUMBER
        MAGIC_NUMBER = 430
        # TOKENIZER TYPE
        SENTENCEPIECE = 0; TIKTOKEN = 1; BERT = 2; HUGGINGFACE = 3
        def write_line(fp, *args):
            for arg in args:
                for token in arg:
                    fp.write(str(token) + ' ')
            fp.write('\n')
        def write_header(fp, type, specials, prefix=[]):
            fp.write(f'{MAGIC_NUMBER} {type}\n')
            fp.write(f'{len(specials)} {len(self.stop_ids)} {len(prefix)}\n')
            write_line(fp, specials, self.stop_ids, prefix)
        file_path = os.path.join(self.args.dst_path, "tokenizer.txt")
        special_list = list(self.tokenizer.added_tokens_decoder.keys())
        if hasattr(self.tokenizer, 'special_tokens'):
            for k, v in self.tokenizer.special_tokens.items():
                special_list.append(v)
        if hasattr(self.tokenizer, 'all_special_ids'): # gemma3
            special_list.extend(self.tokenizer.all_special_ids)
        if hasattr(self.tokenizer, 'gmask_token_id'):
            special_list.append(self.tokenizer.gmask_token_id)
        if hasattr(self.model, 'generation_config') and self.model.generation_config is not None:
            generation_config = self.model.generation_config
            if hasattr(generation_config, 'user_token_id'):
                special_list.append(generation_config.user_token_id)
            if hasattr(generation_config, 'assistant_token_id'):
                special_list.append(generation_config.assistant_token_id)
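        # tokenizer.txt layout, as produced by write_header/write_line above:
        #   line 1: "<MAGIC_NUMBER> <tokenizer type>"
        #   line 2: "<#specials> <#stop_ids> <#prefix>"
        #   line 3: the special ids, stop ids and prefix ids, space separated
        # followed by a vocabulary count line and one vocab entry per line.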
        vocab_list = []
        prefix_list = []
        if hasattr(self.tokenizer, 'get_prefix_tokens'):
            prefix_list = self.tokenizer.get_prefix_tokens()
        if len(prefix_list) == 0:
            try:
                test_txt = 'A'
                ids = self.tokenizer.encode(test_txt)
                get_txt = self.tokenizer.decode(ids[-1])
                if len(ids) > 1 and get_txt == test_txt:
                    prefix_list += ids[:-1]
            except:
                pass
        if self.sp_model is not None:
            # sentencepiece
            NORMAL = 1; UNKNOWN = 2; CONTROL = 3
            USER_DEFINED = 4; UNUSED = 5; BYTE = 6
            for i in range(self.sp_model.GetPieceSize()):
                token = self.sp_model.IdToPiece(i)
                score = self.sp_model.GetScore(i)
                token_type = NORMAL
                if self.sp_model.IsUnknown(i):
                    token_type = UNKNOWN
                elif self.sp_model.IsControl(i):
                    token_type = CONTROL
                elif self.sp_model.IsUnused(i):
                    token_type = UNUSED
                elif self.sp_model.IsByte(i):
                    token_type = BYTE
                if self.args.path == 'Chatglm_6b':
                    if '<n>' in token: token = '\n'
                    if '<|tab|>' in token: token = '\t'
                    if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')])
                if '▁' in token: token = token.replace('▁', ' ')
                token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
                vocab_list.append(f'{token_encode} {score} {token_type}\n')
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, SENTENCEPIECE, special_list, prefix_list)
                if self.model_type in ("gemma3", "gemma3_text"):
                    # len(vocab_list) == 262144, and id 262144 ('image_soft_token') is a special token beyond the base vocab
                    fp.write(f'{len(vocab_list) + 1}\n')
                else:
                    fp.write(f'{len(vocab_list)}\n')
                for vocab in vocab_list:
                    fp.write(vocab)
        elif hasattr(self.tokenizer, 'mergeable_ranks'):
            # tiktoken
            vocab_list = []
            for k, v in self.tokenizer.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + "\n"
                vocab_list.append(line)
            if hasattr(self.tokenizer, 'special_tokens'):
                for k, v in self.tokenizer.special_tokens.items():
                    line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n"
                    vocab_list.append(line)
            if hasattr(self.tokenizer, 'added_tokens_decoder'):
                for k, v in self.tokenizer.added_tokens_decoder.items():
                    line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n"
                    vocab_list.append(line)
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, TIKTOKEN, special_list, prefix_list)
                fp.write(f'{len(vocab_list)}\n')
                for vocab in vocab_list:
                    fp.write(vocab)
        elif self.merge_txt is not None:
            # huggingface tokenizer
            merge_list = []
            vocab = self.tokenizer.get_vocab()
            special_list = list(self.tokenizer.added_tokens_decoder.keys())
            vocab_list = ['<unk>' for i in range(len(vocab))]
            # load vocab
            for k, v in vocab.items():
                vocab_list[int(v)] = k
            # load merge
            with open(self.merge_txt, 'rt') as merge:
                for line in merge.readlines():
                    merge_list.append(line)
            # write to tokenizer.txt
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, HUGGINGFACE, special_list)
                fp.write(f'{len(vocab_list)} {len(merge_list)}\n')
                for v in vocab_list:
                    fp.write(v + '\n')
                for m in merge_list:
                    fp.write(m)
        else:
            # tiktoken or bert
            if 'bert' in type(self.tokenizer).__name__.lower():
                tokenizer_type = BERT
            else:
                tokenizer_type = TIKTOKEN
            # map the byte-level BPE unicode alphabet back to raw bytes
            def unicode_to_byte(u: int):
                if u >= 256 and u <= 288:
                    return u - 256
                if u >= 289 and u <= 322:
                    return u - 162
                if u == 323:
                    return 173
                if u == 65372: # ｜ (fullwidth vertical line)
                    return 124
                if u == 9601: # ▁ (lower one eighth block)
                    return 95
                return u
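            # Illustrative mapping (byte-level BPE alphabet): 'Ġ' is U+0120 = 288,
            # so unicode_to_byte(288) == 32 (space); 'Ċ' is U+010A = 266 -> 10 ('\n').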
            vocab = self.tokenizer.get_vocab()
            vocab_list = ['<unk>' for i in range(len(vocab))]
            for k, v in vocab.items():
                try:
                    vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k])
                except:
                    vocab_list[int(v)] = k.encode('utf-8')
            special_list = list(self.tokenizer.added_tokens_decoder.keys())
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, tokenizer_type, special_list)
                fp.write(f'{len(vocab_list)}\n')
                for v in vocab_list:
                    line = base64.b64encode(v).decode("utf8") + "\n"
                    fp.write(line)
        return file_path
class EmbeddingExporter(LlmExporter):
    def __init__(self, args):
        super().__init__(args)
        self.dst_name = 'embedding'
    def word_embed(self, input_ids):
        return self.word_embeddings(input_ids.view(1, -1))
    def bge_forward(self, inputs_embeds, position_ids, attention_mask):
        # bert absolute position
        inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings
        hidden_states = self.embedding_layernorm(embeddings)
        for i in range(self.num_hidden_layers):
            hidden_states = self.blocks[i](hidden_states, attention_mask)[0]
        sentence_embeddings = hidden_states[:, 0]
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
    def gte_forward(self, inputs_embeds, position_ids, attention_mask):
        # rope position
        inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
        freqs = position_ids.float().reshape(-1, 1) * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1)
        attention_bias = 1 - attention_mask.float()
        hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings)
        for i in range(self.num_hidden_layers):
            hidden_states = self.blocks[i](hidden_states, attention_bias, rope_embeds)[0]
        sentence_embeddings = hidden_states[:, 0]
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
    def forward(self, inputs_embeds, position_ids, attention_mask):
        if self.model_type == 'bert':
            return self.bge_forward(inputs_embeds, position_ids, attention_mask)
        if self.model_type == 'new':
            return self.gte_forward(inputs_embeds, position_ids, attention_mask)
        raise RuntimeError(f'Not support embedding model: {self.model_type}!')
    def response(self, query):
        self.eval()
        input_ids = self.tokenizer(query)['input_ids']
        self.seq_len = len(input_ids)
        input_ids = torch.tensor(input_ids)
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        inputs_embeds = self.word_embed(input_ids)
        res = self.forward(inputs_embeds, position_ids, attention_mask)
        # print(res)
        return res
    @spinner_run(f'load pretrained model ')
    def load_model(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.config = AutoConfig.from_pretrained(model_path)
        self.config._attn_implementation = 'eager'
        self.model = AutoModel.from_config(self.config)
        transformer = self.model.encoder
        self.model_type = self.config.model_type
        self.lm_ = self.model.pooler
        self.embed_ = self.model.embeddings
        self.word_embeddings = self.embed_.word_embeddings
        self.token_type_embeddings = self.embed_.token_type_embeddings.weight.data[0]
        self.embedding_layernorm = self.embed_.LayerNorm
        if hasattr(self.embed_, 'position_embeddings'):
            self.position_embeddings = self.embed_.position_embeddings
        self.hidden_size = self.word_embeddings.weight.shape[-1]
        self.blocks = transformer.layer
        if self.model_type == 'new':
            self.inv_freq = self.embed_.rotary_emb.inv_freq
        # some wrapper
        self.stop_ids = []
        self.num_hidden_layers = len(self.blocks)
        self.embed = self.embed_
        self.lm = self.lm_
        # some config for export
        self.model_dynamic_axes = {
            "input_ids": {1: "seq_len"},
            "position_ids": {1: "seq_len"},
            "attention_mask": {3: "seq_len"}
        }
        self.attention_mask_type = 'int'
        self.llm_config = {
            'hidden_size': self.hidden_size,
            'layer_nums': self.num_hidden_layers,
            'attention_mask': self.attention_mask_type,
            'key_value_shape': [],
            "prompt_template": self.build_prompt('%s'),
            'is_visual': False
        }
        return model_path
    @spinner_run(f'export onnx model to ')
    def export_onnx(self):
        model = self.eval()
        self.seq_len = 3
        input_ids = torch.arange(3, dtype=torch.long)
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        inputs_embeds = self.word_embed(input_ids)
        onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx'
        torch.onnx.export(
            model, (inputs_embeds, position_ids, attention_mask),
            onnx_model,
            input_names=[
                'input_ids',
                'position_ids',
                'attention_mask'
            ],
            output_names=['sentence_embeddings'],
            dynamic_axes=self.model_dynamic_axes,
            do_constant_folding=True,
            opset_version=15)
        return onnx_model
    def export(self, export_type):
        export_mnn = 'mnn' in export_type
        self.export_tokenizer()
        self.export_config(export_mnn)
        self.export_embed()
        onnx_model = self.export_onnx()
        if self.args.onnx_slim:
            self.slim_onnx(onnx_model)
        if export_mnn:
            # same MNNConveter usage as LlmExporter.export_language above
            MNNConveter(self).export(onnx_model)
    def build_prompt(self, content):
        if self.model_type == 'bert':
            return f'[CLS]{content}[SEP]'
        if self.model_type == 'new':
            return f'<s>{content}</s>'
    def get_position_ids(self) -> torch.Tensor:
        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
    def get_attention_mask(self) -> torch.Tensor:
        return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)
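# Note: EmbeddingExporter is selected when the model path contains 'bge' (or
# 'gte' via the CLI); it exports a sentence-embedding graph instead of a causal
# decoder, reusing the tokenizer/config/embedding export helpers above.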

def export(path,
           type=None,
           tokenizer_path=None,
           lora_path=None,
           gptq_path=None,
           dst_path='./model',
           export='onnx',
           onnx_slim=False,
           quant_bit=4,
           quant_block=64,
           lm_quant_bit=None,
           mnnconvert=None,
           ppl=False,
           awq=False,
           sym=False,
           seperate_embed=False,
           lora_split=False):
    args = argparse.Namespace()
    for k, v in {
        'path': path,
        'type': type,
        'tokenizer_path': tokenizer_path,
        'lora_path': lora_path,
        'gptq_path': gptq_path,
        'dst_path': dst_path,
        'export': export,
        'onnx_slim': onnx_slim,
        'quant_bit': quant_bit,
        'quant_block': quant_block,
        'lm_quant_bit': lm_quant_bit,
        'mnnconvert': mnnconvert,
        'ppl': ppl,
        'awq': awq,
        'sym': sym,
        'seperate_embed': seperate_embed,
        'lora_split': lora_split
    }.items():
        setattr(args, k, v)
    if 'bge' in path:
        llm_exporter = EmbeddingExporter(args)
    else:
        llm_exporter = LlmExporter(args)
    # export
    llm_exporter.export(export)
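
# Example programmatic usage (the model directory is illustrative):
#   export('../Qwen2-1.5B-Instruct', dst_path='./model', export='mnn',
#          quant_bit=4, quant_block=64)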

def main():
    parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--path', type=str, required=True,
                        help='path(`str` or `os.PathLike`):\nCan be either:'
                        '\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]'
                        '\n\t- A path to a *directory* clone from repo like `../chatglm-6b`.')
    parser.add_argument('--type', type=str, default=None,
                        help='type(`str`, *optional*):'
                        '\n\tThe pretrain llm model type.'
                        )
    parser.add_argument('--tokenizer_path', type=str, default=None, help='tokenizer path, default is `None`, meaning the `--path` value is used.')
    parser.add_argument('--lora_path', type=str, default=None, help='lora path, default is `None`, meaning lora is not applied.')
    parser.add_argument('--gptq_path', type=str, default=None, help='gptq path, default is `None`, meaning gptq is not applied.')
    parser.add_argument('--dst_path', type=str, default='./model', help='export onnx/mnn model to path, default is `./model`.')
    parser.add_argument('--verbose', action='store_true', help='Whether or not to print verbose.')
    parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
    parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
    parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.')
    parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
    parser.add_argument('--quant_block', type=int, default=64, help='mnn quant block, 0 means channel-wise, default is 64.')
    parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
    parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.')
    parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
    parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
    parser.add_argument('--sym', action='store_true', help='Whether or not to use symmetric quant (without zeropoint), default is False.')
    parser.add_argument('--seperate_embed', action='store_true', help='For models where lm_head and embedding share weights, whether or not to separate the embedding to avoid quantizing it; default is False. If True, the embedding weight is written separately to embeddings_bf16.bin.')
    parser.add_argument('--lora_split', action='store_true', help='Whether or not to export lora split, default is False.')
    args = parser.parse_args()
    model_path = args.path
    if 'gte' in model_path or 'bge' in model_path:
        llm_exporter = EmbeddingExporter(args)
    else:
        llm_exporter = LlmExporter(args)
    # some actions
    if args.test is not None:
        llm_exporter.response(args.test)
    if args.export is not None:
        llm_exporter.export(args.export)

if __name__ == '__main__':
    main()
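
# Example CLI usage (script name and model path are illustrative):
#   python llmexport.py --path ../Qwen2-1.5B-Instruct --export mnn \
#       --quant_bit 4 --quant_block 64 --dst_path ./model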