import os
import gc
import sys
import math
import copy
import json
import time
import base64
import logging
import inspect
import warnings
import argparse
import functools
import traceback
from collections import defaultdict
from typing import Optional, Tuple, List, Union, Dict

from tqdm import tqdm
from yaspin import yaspin
import onnx
import torch
import numpy as np
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer

RESET = "\033[0m"
GREEN = "\033[32;1m"
YELLOW = "\033[33;4m"
EXPORT_LOG = '.export.log'

# ignore warning info
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

def spinner_run(text='Processing...'):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with yaspin(text=text, color="cyan") as spinner:
                start = time.time()
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    spinner.fail("💥 Failed")
                    traceback.print_exc()
                    exit(1)
                end = time.time()
                during = f'[{end - start:05.2f} s]'.replace('[0', '[ ')
                padding = ' ' * (64 - len(spinner.text) - len(result))
                spinner.text = f'{spinner.text}{YELLOW}{result}{RESET}{padding}{GREEN}{during}{RESET}'
                spinner.ok("✅ Done")
            return result
        return wrapper
    return decorator
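# Usage sketch (illustrative only, `export_onnx` is a placeholder name): functions
# decorated with @spinner_run show a spinner while running and must return a short
# string, which is printed next to the elapsed time, e.g.
#
#   @spinner_run('export model to ')
#   def export_onnx(path):
#       ...  # do the export
#       return path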
class ModelMapper:
    def __init__(self):
        self.attrs = []
        self.mapper = dict()
        self.regist_models()
    def get_map(self, config):
        model_type = config.model_type
        if model_type == 'chatglm':
            if hasattr(config, 'vocab_size') and config.vocab_size == 130528:
                model_type = 'chatglm'
            else:
                model_type = 'chatglm2'
        if model_type in self.mapper:
            return model_type, self.mapper[model_type]
        return model_type, self.default_map
    def regist(self, model_type, model_map):
        assert('config' in model_map and
               'decoder' in model_map and
               'attention' in model_map)
        self.mapper[model_type] = model_map
    def regist_models(self):
        self.defualt_map()
        # regist models
        self.regist_llama()
        self.regist_mllama()
        self.regist_qwen()
        self.regist_glm()
        self.regist_glm2()
        self.regist_phi()
        self.regist_gemma2()
        self.register_openelm()
    def regist_llama(self):
        llama_map = self.default_map
        self.regist('llama', llama_map)
        self.regist('qwen2', llama_map)
        self.regist('internlm', llama_map)
        self.regist('mobilellm', llama_map)
        # baichuan
        baichuan_map = copy.deepcopy(self.default_map)
        baichuan_map[self.attention_key] = {
            'qkv_proj': 'W_pack',
            'o_proj': 'o_proj'
        }
        self.regist('baichuan', baichuan_map)
    def regist_mllama(self):
        mllama_map = {
            'config': {
                'hidden_size': 'text_config.hidden_size',
                'num_attention_heads': 'text_config.num_attention_heads',
                'num_hidden_layers': 'text_config.num_hidden_layers',
                'num_key_value_heads': 'text_config.num_key_value_heads',
                'rope_theta': 'text_config.rope_theta'
            },
            'model': {
                'lm_': 'language_model.lm_head',
                'embed_': 'language_model.model.embed_tokens',
                'blocks_': 'language_model.model.layers',
                'final_layernorm_': 'language_model.model.norm',
                'visual': 'vision_model'
            },
            'decoder': {
                'self_attn': 'self_attn',
                'cross_attn': 'cross_attn',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'q_proj': 'q_proj',
                'k_proj': 'k_proj',
                'v_proj': 'v_proj',
                'o_proj': 'o_proj',
                'q_norm': 'q_norm',
                'k_norm': 'k_norm',
                'cross_attn_attn_gate': 'cross_attn_attn_gate',
                'cross_attn_mlp_gate': 'cross_attn_mlp_gate'
            }
        }
        self.regist('mllama', mllama_map)
    def regist_qwen(self):
        qwen_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'num_attention_heads': 'num_attention_heads',
                'num_hidden_layers': 'num_hidden_layers',
                'rope_theta': 'rotary_emb_base',
            },
            'model': {
                'lm_': 'lm_head',
                'embed_': 'transformer.wte',
                'blocks_': 'transformer.h',
                'final_layernorm_': 'transformer.ln_f',
                'visual': 'transformer.visual'
            },
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'ln_1',
                'post_attention_layernorm': 'ln_2'
            },
            'attention': {
                'qkv_proj': 'c_attn',
                'o_proj': 'c_proj'
            }
        }
        self.regist('qwen', qwen_map)
    def regist_glm(self):
        glm_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'num_attention_heads': 'num_attention_heads',
                'num_hidden_layers': 'num_layers'
            },
            'model': {
                'lm_': 'lm_head',
                'embed_': 'transformer.word_embeddings',
                'blocks_': 'transformer.layers',
                'final_layernorm_': 'transformer.final_layernorm',
            },
            'decoder': {
                'self_attn': 'attention',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'qkv_proj': 'query_key_value',
                'o_proj': 'dense'
            }
        }
        self.regist('chatglm', glm_map)
    def regist_glm2(self):
        glm2_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'num_attention_heads': 'num_attention_heads',
                'num_key_value_heads': 'multi_query_group_num',
                'num_hidden_layers': 'num_layers',
            },
            'model': {
                'lm_': 'transformer.output_layer',
                'embed_': 'transformer.embedding.word_embeddings',
                'blocks_': 'transformer.encoder.layers',
                'final_layernorm_': 'transformer.encoder.final_layernorm',
            },
            'decoder': {
                'self_attn': 'self_attention',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'qkv_proj': 'query_key_value',
                'o_proj': 'dense'
            }
        }
        self.regist('chatglm2', glm2_map)
    def regist_phi(self):
        phi_map = {
            'config': {
                'hidden_size': 'n_embd',
                'num_attention_heads': 'n_head',
                'num_hidden_layers': 'n_layer',
                'rotary_dim': 'rotary_dim'
            },
            'model': {
                'lm_': 'lm_head.linear',
                'embed_': 'transformer.embd.wte',
                'blocks_': 'transformer.h',
                'final_layernorm_': 'lm_head.ln',
            },
            'decoder': {
                'self_attn': 'mixer',
                'mlp': 'mlp',
                'input_layernorm': 'ln',
            },
            'attention': {
                'qkv_proj': 'Wqkv',
                'o_proj': 'out_proj'
            }
        }
        self.regist('phi-msft', phi_map)
    def regist_gemma2(self):
        gemma2_config = copy.deepcopy(self.default_config)
        gemma2_config['head_dim'] = 'head_dim'
        gemma2_decoder = copy.deepcopy(self.default_decoder)
        gemma2_decoder['pre_feedforward_layernorm'] = 'pre_feedforward_layernorm'
        gemma2_decoder['post_feedforward_layernorm'] = 'post_feedforward_layernorm'
        gemma2_map = {
            'config': gemma2_config,
            'model': self.defualt_model,
            'decoder': gemma2_decoder,
            'attention': self.default_attention
        }
        self.regist('gemma2', gemma2_map)
    def register_openelm(self):
        openelm_config = {
            'hidden_size': 'model_dim',
            'head_dim': 'head_dim',
            'num_attention_heads': 'num_query_heads',
            'num_hidden_layers': 'num_transformer_layers',
            'num_key_value_heads': 'num_kv_heads',
            'rope_theta': 'rope_freq_constant'
        }
        openelm_model = {
            'lm_': 'lm_head',
            'embed_': 'transformer.token_embeddings',
            'blocks_': 'transformer.layers',
            'final_layernorm_': 'transformer.norm'
        }
        openelm_decoder = {
            'self_attn': 'attn',
            'mlp': 'ffn',
            'input_layernorm': 'attn_norm',
            'post_attention_layernorm': 'ffn_norm'
        }
        openelm_attention = {
            'qkv_proj': 'qkv_proj',
            'o_proj': 'out_proj',
            'q_norm': 'q_norm',
            'k_norm': 'k_norm'
        }
        openelm_map = {
            'config': openelm_config,
            'model': openelm_model,
            'decoder': openelm_decoder,
            'attention': openelm_attention
        }
        self.regist('openelm', openelm_map)
    def defualt_map(self):
        # default map is `LlamaForCausalLM`
        self.config_key = 'config'
        self.model_key = 'model'
        self.decoder_key = 'decoder'
        self.attention_key = 'attention'
        self.default_config = {
            'hidden_size': 'hidden_size',
            'num_attention_heads': 'num_attention_heads',
            'num_hidden_layers': 'num_hidden_layers',
            'num_key_value_heads': 'num_key_value_heads',
            'rope_theta': 'rope_theta'
        }
        self.defualt_model = {
            'lm_': 'lm_head',
            'embed_': 'model.embed_tokens',
            'blocks_': 'model.layers',
            'final_layernorm_': 'model.norm',
            'visual': 'visual'
        }
        self.default_decoder = {
            'self_attn': 'self_attn',
            'mlp': 'mlp',
            'input_layernorm': 'input_layernorm',
            'post_attention_layernorm': 'post_attention_layernorm'
        }
        self.default_attention = {
            'q_proj': 'q_proj',
            'k_proj': 'k_proj',
            'v_proj': 'v_proj',
            'o_proj': 'o_proj'
        }
        self.default_map = {
            'config': self.default_config,
            'model': self.defualt_model,
            'decoder': self.default_decoder,
            'attention': self.default_attention
        }
    @staticmethod
    def do_map(dst, src, map):
        for dst_attr, src_attr in map.items():
            attributes = src_attr.split('.')
            obj = src
            for attr in attributes:
                if hasattr(obj, attr):
                    obj = getattr(obj, attr)
                else:
                    obj = None
                    break
            setattr(dst, dst_attr, obj)
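# Usage sketch (illustrative; `hf_config` and `dst` are placeholder names):
# ModelMapper.do_map copies attributes from a HuggingFace config or module onto a
# destination object by following dotted source paths, e.g.
#
#   mapper = ModelMapper()
#   model_type, model_map = mapper.get_map(hf_config)
#   ModelMapper.do_map(dst, hf_config, model_map['config'])
#   # afterwards dst.hidden_size, dst.num_attention_heads, ... are populated,
#   # or set to None when the source lacks the corresponding attribute.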
# Quant class
# awq quantizer start
class AwqQuantizer:
    def __init__(
        self,
        model,
        modules_to_not_convert=None,
        apply_clip=True,
        n_parallel_calib_samples=None,
        max_calib_samples=128,
        max_calib_seq_len=512,
        max_chunk_memory=1024 * 1024 * 1024,
    ) -> None:
        self.awq_model = model
        self.model = model
        self.tokenizer = model.tokenizer
        self.w_bit = model.quant_bit
        self.group_size = model.quant_block
        self.zeropoint = not model.symmetric
        self.calib_data = 'ag_news'
        self.split = 'test'
        self.duo_scaling = True
        self.apply_clip = apply_clip
        self.n_parallel_calib_samples = n_parallel_calib_samples
        self.max_calib_samples = max_calib_samples
        self.max_calib_seq_len = max_calib_seq_len
        self.max_chunk_memory = max_chunk_memory
        self.modules_to_not_convert = (
            modules_to_not_convert if modules_to_not_convert is not None else []
        )
        self.modules, self.module_kwargs, self.inps = self.init_quant(
            n_samples=self.max_calib_samples, max_seq_len=self.max_calib_seq_len
        )
    def pseudo_quantize_tensor(self, w: torch.Tensor):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0
        # zero point quantization
        if self.zeropoint:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
            offset = 1 << (self.w_bit - 1)
            clip_max = offset - 1
            clip_min = -offset
            scales = (max_val - min_val) / (clip_max - clip_min)
            zeros = -torch.round(min_val / scales) + clip_min
            qw = torch.round(w / scales) + zeros
            qw = torch.clamp(qw, clip_min, clip_max)
            w = (qw - zeros) * scales
            zeros = min_val.view(org_w_shape[0], -1)
        else:
            abs_max = w.abs().amax(dim=1, keepdim=True)
            offset = 1 << (self.w_bit - 1)
            clip_max = offset - 1
            clip_min = -clip_max
            scales = abs_max / clip_max
            w = torch.clamp(torch.round(w / scales), clip_min, clip_max) * scales
            zeros = None
        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0
        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)
        return w, scales, zeros
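    # Note on the fake quantization above (added explanation, not original code):
    # with w_bit = 4 the zero-point branch maps each group onto integers in
    # [-8, 7] using scale = (max - min) / 15 plus a per-group zero point, while
    # the symmetric branch uses scale = abs_max / 7 and no zero point.
    # Example for a 4-bit symmetric group [-0.30, 0.10, 0.45]:
    # scale = 0.45 / 7 ≈ 0.0643, so round(w / scale) * scale ≈ [-0.32, 0.13, 0.45].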
    def quantize(self):
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # if i > 0: break
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                best_device = AwqQuantizer.get_best_device()
                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device
            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)
            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)
            self.inps = self.inps.to(common_device)
            # print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')
            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = AwqQuantizer.get_named_linears(self.modules[i])
            # Filter out the linear layers we don't want to exclude
            named_linears = AwqQuantizer.exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            input_feat = self._get_input_feat(self.modules[i], named_linears)
            AwqQuantizer.clear_memory()
            # [STEP 2]: Compute and apply scale list
            module_config = []
            # q, k, v proj
            module_config.append(
                dict(
                    prev_op=self.modules[i].input_layernorm,
                    layers=[
                        self.modules[i].self_attn.q_proj,
                        self.modules[i].self_attn.k_proj,
                        self.modules[i].self_attn.v_proj,
                    ],
                    inp=input_feat["self_attn.q_proj"],
                    module2inspect=self.modules[i].self_attn,
                    kwargs=self.module_kwargs,
                )
            )
            # o_proj
            if self.modules[i].self_attn.v_proj.weight.shape == self.modules[i].self_attn.o_proj.weight.shape:
                module_config.append(
                    dict(
                        prev_op=self.modules[i].self_attn.v_proj,
                        layers=[self.modules[i].self_attn.o_proj],
                        inp=input_feat["self_attn.o_proj"],
                    )
                )
            # mlp gate
            module_config.append(
                dict(
                    prev_op=self.modules[i].post_attention_layernorm,
                    layers=[self.modules[i].mlp.gate_proj, self.modules[i].mlp.up_proj],
                    inp=input_feat["mlp.gate_proj"],
                    module2inspect=self.modules[i].mlp,
                )
            )
            # mlp down
            module_config.append(
                dict(
                    prev_op=self.modules[i].mlp.up_proj,
                    layers=[self.modules[i].mlp.down_proj],
                    inp=input_feat["mlp.down_proj"],
                )
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            # print(scales_list); exit(0)
            AwqQuantizer.apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            # [STEP 3]: Compute and apply clipping list
            if self.apply_clip:
                clip_list = self._search_best_clip(
                    self.modules[i], named_linears, input_feat
                )
                AwqQuantizer.apply_clip(self.modules[i], clip_list)
            AwqQuantizer.clear_memory()
    @torch.no_grad()
    def _module_forward(
        self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
    ) -> torch.Tensor:
        if self.n_parallel_calib_samples is None:
            # runs through all samples at once
            # print(module, x, module_kwargs); exit(0)
            module_output = module(x, **module_kwargs)
            if isinstance(module_output, tuple):
                module_output = module_output[0]
        else:
            # memory efficiently runs through all calibration samples
            # but only n_parallel_calib_samples at a time
            module_output = []
            partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
            for x_partial in partitioned_inputs:
                partial_output = module(x_partial, **module_kwargs)
                if isinstance(partial_output, tuple):
                    partial_output = partial_output[0]
                module_output.append(partial_output.cpu())
            module_output = torch.cat(module_output, dim=0)
        return module_output
    @torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[torch.nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs={},
    ):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]
        if "use_cache" in kwargs:
            kwargs.pop("use_cache")
        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)
        # [STEP 1]: Compute per-channel mean of normalised weights
        # All layer weights are concatted together
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # The weights are reshaped to be organised by quantization group
        weight = weight.view(-1, self.group_size)
        # Calculates the relative magnitude of the weights within each of the quantization groups,
        # and rescales each group individually so that each group has weights on a 0-1 scale.
        w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
        # Resizes the rescaled weight matrix back up to its original dimensions
        w_scale = w_scale.view(org_shape)
        # Gets the average rescaled magnitude for each output channel
        w_mean = w_scale.mean(0)
        AwqQuantizer.clear_memory(weight)
        # [STEP 2]: Compute per-channel mean of the input activation with chunking
        # move inp to cpu to avoid memory leak
        inp_flat = inp.cpu().abs().view(-1, inp.shape[-1])
        num_elements = inp_flat.size(0)
        num_channels = inp_flat.size(1)
        element_size_bytes = inp_flat.element_size() * 2  # multiplied by 2 for FP32
        # Calculate chunk size dynamically based on max_chunk_memory
        chunk_size = int(self.max_chunk_memory // (element_size_bytes * num_channels))
        chunk_size = min(chunk_size, num_elements)
        # Use float32 for sum calculation
        x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device)
        for i in range(0, num_elements, chunk_size):
            end = min(i + chunk_size, num_elements)
            chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0)
            x_sum += chunk_sum.to(inp.device)
        x_mean = (x_sum / num_elements).to(inp.dtype)
        AwqQuantizer.clear_memory(x_sum)
        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
            fp16_output = self._module_forward(inp, module2inspect, module_kwargs)
        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
        )
        return (
            AwqQuantizer.get_op_name(module, prev_op),
            tuple([AwqQuantizer.get_op_name(module, m) for m in layers]),
            best_scales,
        )
    def _compute_best_scale(
        self,
        x: torch.Tensor,
        w_mean: torch.Tensor,
        x_mean: torch.Tensor,
        module2inspect: torch.nn.Module,
        linears2scale: List[torch.nn.Linear],
        fp16_output: torch.Tensor,
        kwargs: Dict = {},
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")
        device = x.device
        x_mean = x_mean.view(-1).to(device)
        w_mean = w_mean.view(-1).to(device)
        ord_weights = []
        for fc in linears2scale:
            ord_weights.append(fc.weight.data.clone())
        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid
            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)
            # avoid scaling values that overflow
            scales[torch.isinf(scales)] = 1
            scales[torch.isnan(scales)] = 1
            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )
            # W * X
            int_w_output = self._module_forward(x, module2inspect, kwargs)
            # compute mean squared error (L2 norm)
            loss = self._compute_loss(fp16_output, int_w_output, device)
            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            # restore the original weights before trying the next ratio
            for fc, ord_weight in zip(linears2scale, ord_weights):
                fc.weight.data = ord_weight.clone()
        del ord_weights
        if best_ratio == -1:
            logging.debug(history)
            raise Exception
        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach().cpu()
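    # Added note: the loop above grid-searches ratio over {0, 1/20, ..., 19/20}
    # and builds per-channel scales as s = x_mean^ratio / (w_mean^(1-ratio) + 1e-4),
    # i.e. activation-aware scaling: channels with large activations get larger
    # scales, and the candidate with the lowest fake-quantized output MSE wins.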
    @torch.no_grad()
    def _compute_loss(
        self,
        fp16_output: torch.Tensor,
        int_w_output: torch.Tensor,
        device: torch.device,
    ):
        loss = 0.0
        fp16_output_flat = fp16_output.view(-1)
        int_w_output_flat = int_w_output.view(-1)
        num_elements = fp16_output_flat.size(0)
        element_size_bytes = fp16_output.element_size()
        # Calculate chunk size dynamically based on max_chunk_memory
        # Divide the max_chunk_memory by twice the element size
        chunk_size = self.max_chunk_memory // (element_size_bytes * 2)
        chunk_size = min(chunk_size, num_elements)
        # Split the computation into chunks
        fp16_chunks = torch.split(fp16_output_flat, chunk_size)
        int_w_chunks = torch.split(int_w_output_flat, chunk_size)
        # Compute the loss for each chunk
        for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks):
            chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item()
            loss += chunk_loss
        # Normalize the loss by the total number of elements
        loss /= num_elements
        return loss
    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue
            named_linears[name].to(AwqQuantizer.get_best_device())
            max_val = self._compute_best_clip(
                named_linears[name].weight, input_feat[name]
            )
            clip_list.append((name, max_val))
            named_linears[name].cpu()
        return clip_list
    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
        assert w.dim() == 2
        org_w_shape = w.shape
        # w [co, ci] -> [co, 1, n_group, group size]
        # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        # Compute input feature step size (minimum 1)
        step_size = max(1, input_feat.shape[1] // n_sample_token)
        input_feat = input_feat[:, ::step_size]
        w = w.reshape(org_w_shape[0], 1, -1, group_size)
        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []
        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)
                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)
        best_max_val = torch.cat(best_max_val_all, dim=0)
        AwqQuantizer.clear_memory(input_feat)
        AwqQuantizer.clear_memory(org_out)
        return best_max_val.squeeze(1)
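    # Added note: the clipping search above shrinks each group's absolute maximum
    # in steps of 1/20 (down to 50% of the original, per max_shrink) and keeps the
    # threshold whose fake-quantized output best matches the original output in
    # mean squared error; q/k projections are skipped by _search_best_clip.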
    @staticmethod
    @torch.no_grad()
    def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
        for name, max_val in clip_list:
            layer: torch.nn.Linear = AwqQuantizer.get_op_by_name(module, name)
            layer.to(AwqQuantizer.get_best_device())
            max_val = max_val.to(layer.weight.device)
            org_shape = layer.weight.shape
            layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
            layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
            layer.weight.data = layer.weight.data.reshape(org_shape)
            layer.cpu()
    @staticmethod
    @torch.no_grad()
    def scale_fc_fcs(fc1: torch.nn.Linear, fcs: List[torch.nn.Linear], scales: torch.Tensor):
        if not isinstance(fcs, list):
            fcs = [fcs]
        scales = scales.to(fc1.weight.device)
        fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
        if fc1.bias is not None:
            fc1.bias.div_(scales.view(-1))
        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))
        for p in fc1.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0
    @staticmethod
    def is_allowed_act_fns(op):
        from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
        allowed_act_fns = [
            torch.nn.GELU,
            NewGELUActivation,
            PytorchGELUTanh,
            GELUActivation,
        ]
        return (op in allowed_act_fns)
    @staticmethod
    def is_allowed_norms(op):
        if isinstance(op, torch.nn.LayerNorm):
            return True
        if any(t in str(type(op)) for t in ['LlamaRMSNorm', 'GemmaRMSNorm', 'CohereLayerNorm']):
            return True
        return False
    @staticmethod
    @torch.no_grad()
    def scale_fc_fc(fc1: torch.nn.Linear, fc2: torch.nn.Linear, scales: torch.Tensor):
        assert isinstance(fc1, torch.nn.Linear)
        assert isinstance(fc2, torch.nn.Linear)
        scales = scales.to(fc1.weight.device)
        fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
        if fc1.bias is not None:
            fc1.bias.div_(scales.view(-1))
        fc2.weight.mul_(scales.view(1, -1))
        for p in fc1.parameters():
            assert torch.isnan(p).sum() == 0
        for p in fc2.parameters():
            assert torch.isnan(p).sum() == 0
    @staticmethod
    @torch.no_grad()
    def scale_ln_fcs(ln: torch.nn.Linear, fcs: List[torch.nn.Linear], scales: torch.Tensor):
        if not isinstance(fcs, list):
            fcs = [fcs]
        scales = scales.to(ln.weight.device)
        # GemmaRMSNorm is different from Llama's in that it multiplies
        # (1 + weight) to the output, instead of just weight.
        if 'GemmaRMSNorm' in str(type(ln)):
            ln.weight += 1
            ln.weight.div_(scales)
            ln.weight -= 1
        else:
            ln.weight.div_(scales)
        if hasattr(ln, "bias") and ln.bias is not None:
            ln.bias.div_(scales)
        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))
        for p in ln.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0
    @staticmethod
    @torch.no_grad()
    def scale_gelu_fc(gelu, fc: torch.nn.Linear, scales: torch.Tensor):
        assert AwqQuantizer.is_allowed_act_fns(gelu)
        assert isinstance(fc, torch.nn.Linear)
        fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0
    @staticmethod
    def apply_scale(module, scales_list, input_feat_dict=None):
        for prev_op_name, layer_names, scales in scales_list:
            prev_op = AwqQuantizer.get_op_by_name(module, prev_op_name)
            layers = [AwqQuantizer.get_op_by_name(module, name) for name in layer_names]
            best_device = AwqQuantizer.get_best_device()
            prev_op.to(best_device)
            for layer in layers:
                layer.to(best_device)
            scales.to(best_device)
            if (
                isinstance(prev_op, torch.nn.Linear)
                and type(layers) == list
                and isinstance(layers[0], torch.nn.Linear)
            ):
                if len(layers) == 1:
                    AwqQuantizer.scale_fc_fc(prev_op, layers[0], scales)
                else:
                    AwqQuantizer.scale_fc_fcs(prev_op, layers, scales)
            elif (
                AwqQuantizer.is_allowed_norms(prev_op)
                or "rmsnorm" in str(prev_op.__class__).lower()
            ):
                AwqQuantizer.scale_ln_fcs(prev_op, layers, scales)
            elif AwqQuantizer.is_allowed_act_fns(prev_op):
                # new_module = ScaledActivation(prev_op, scales)
                # set_op_by_name(module, prev_op_name, new_module)
                AwqQuantizer.scale_gelu_fc(prev_op, layers[0], scales)
            else:
                raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")
            # apply the scaling to input feat if given; prepare it for clipping
            if input_feat_dict is not None:
                for layer_name in layer_names:
                    # Skip the modules that are not quantized
                    if layer_name in input_feat_dict:
                        inp = input_feat_dict[layer_name]
                        inp.div_(scales.view(1, -1).to(inp.device))
            prev_op.cpu()
            for layer in layers:
                layer.cpu()
            scales.cpu()
    @staticmethod
    def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
        if modules_to_not_convert is None:
            return linear_layers
        filtered_layers = {}
        for name, linear_layer in linear_layers.items():
            if not any(key in name for key in modules_to_not_convert):
                filtered_layers[name] = linear_layer
        return filtered_layers
    @staticmethod
    def get_named_linears(module):
        return {name: m for name, m in module.named_modules() if isinstance(m, torch.nn.Linear)}
    @staticmethod
    def get_op_by_name(module, op_name):
        # get the op by its name relative to the module
        for name, m in module.named_modules():
            if name == op_name:
                return m
        raise ValueError(f"Cannot find op {op_name} in module {module}")
    @staticmethod
    def get_calib_dataset(
        data: Union[str, List[str], List[List[int]]] = "pileval",
        tokenizer=None,
        n_samples=128,
        max_seq_len=512,
        split="train",
        text_column="text",
    ):
        if isinstance(data, str):
            from datasets import load_dataset
            if data == "pileval":
                dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
            else:
                dataset = load_dataset(data, split=split)
            # dataset = dataset.shuffle(seed=42)
        elif isinstance(data, list):
            if isinstance(data[0], str):
                dataset = [{text_column: text} for text in data]
            elif isinstance(data[0][0], int):
                dataset = data
            else:
                raise NotImplementedError(
                    "Either pass a string to a huggingface dataset or a list"
                    " that is preprocessed with one sample of text per element"
                    " or a list of list of int for tokenized words."
                )
        else:
            raise NotImplementedError(
                "Either pass a string to a huggingface dataset or a list"
                " that is preprocessed with one sample of text per element"
                " or a list of list of int for tokenized words."
            )
        samples = []
        n_run = 0
        for data in dataset:
            if isinstance(data, list):
                line_encoded = data
            else:
                line = data[text_column]
                line = line.strip()
                line_encoded = tokenizer.encode(line)
            if len(line_encoded) > max_seq_len:
                continue
            sample = torch.tensor([line_encoded])
            if sample.numel() == 0:
                continue
            samples.append(sample)
            n_run += 1
            if n_run == n_samples:
                break
        # now concatenate all samples and split according to max sequence length
        cat_samples = torch.cat(samples, dim=1)
        n_split = cat_samples.shape[1] // max_seq_len
        logging.debug(f" * Split into {n_split} blocks")
        return [
            cat_samples[:, i * max_seq_len: (i + 1) * max_seq_len] for i in range(n_split)
        ]
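    # Usage sketch (illustrative; `tok` is a placeholder for a HuggingFace tokenizer):
    #
    #   blocks = AwqQuantizer.get_calib_dataset('ag_news', tokenizer=tok,
    #                                           n_samples=128, max_seq_len=512,
    #                                           split='test')
    #   # blocks is a list of [1, 512] token-id tensors used for calibration,
    #   # matching the calib_data/split defaults set in __init__.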
    @staticmethod
    def get_best_device():
        if torch.backends.mps.is_available():
            return "mps"
        elif torch.cuda.is_available():
            return "cuda:0"
        else:
            return "cpu"
    @staticmethod
    def clear_memory(weight=None):
        if weight is not None:
            del weight
        gc.collect()
        torch.cuda.empty_cache()
    @staticmethod
    def get_op_name(module, op):
        # get the name of the op relative to the module
        for name, m in module.named_modules():
            if m is op:
                return name
        raise ValueError(f"Cannot find op {op} in module {module}")
    @staticmethod
    def append_str_prefix(x, prefix):
        if isinstance(x, str):
            return prefix + x
        elif isinstance(x, tuple):
            return tuple([AwqQuantizer.append_str_prefix(y, prefix) for y in x])
        elif isinstance(x, list):
            return [AwqQuantizer.append_str_prefix(y, prefix) for y in x]
        else:
            return x
    def init_quant(self, n_samples=128, max_seq_len=512):
        modules = self.awq_model.blocks
        samples = AwqQuantizer.get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            max_seq_len=max_seq_len,
            split=self.split
        )
        # samples = torch.cat(samples, dim=0)
        samples = torch.cat(samples[:1], dim=0)  # just using 1 batch
        inps = []
        layer_kwargs = {}
        # build inps
        self.model.seq_len = samples.numel()
        self.model.context_len = samples.numel() - 2
        self.model.token_len = 0
        best_device = AwqQuantizer.get_best_device()
        inps = self.model.embedding(samples).to(best_device)
        position_ids = self.model.get_position_ids()
        rotary_pos_emb = self.model.rotary(position_ids)
        attention_mask = self.model.get_attention_mask()
        layer_kwargs["rotary_pos_emb"] = rotary_pos_emb.to(best_device)
        layer_kwargs["attention_mask"] = attention_mask.to(best_device)
        del samples
        AwqQuantizer.clear_memory()
        return modules, layer_kwargs, inps
    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)
        input_feat = defaultdict(list)
        handles = []
        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input
        # Sanitize the kwargs in case we use transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
        self.inps = self._module_forward(self.inps, layer, module_kwargs)
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
        return input_feat
    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs
# awq quantizer end
# Export class
# custom op start
class FakeLinearOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, input, in_features, out_features, has_bias, name):
        # These become the operator attributes.
        kwargs = {
            "in_features_i": in_features,
            "out_features_i": out_features,
            "has_bias_i": has_bias,
            "name_s": name
        }
        from torch.onnx.symbolic_helper import _get_tensor_sizes
        out_sizes = _get_tensor_sizes(input)[:-1] + [out_features]
        output_type = input.type().with_sizes(out_sizes)
        return g.op("LlmExporter::FakeLinear", input, **kwargs).setType(output_type)
    @staticmethod
    def forward(ctx, input, in_features, out_features, has_bias, name):
        out_shape = list(input.shape)[:-1] + [out_features]
        return input.new_zeros(out_shape)
class FakeLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, has_bias, name):
        super(FakeLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = has_bias
        self.name = name
    def forward(self, x):
        return FakeLinearOp.apply(x, self.in_features, self.out_features, self.has_bias, self.name)
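# Added note: FakeLinear stands in for a real torch.nn.Linear during ONNX export.
# It emits a custom `LlmExporter::FakeLinear` node carrying only the layer name
# and shape, so the exported graph stays small; the real weights are written back
# later, either as ONNX external data (OnnxRebuilder) or as quantized MNN
# convolutions (MNNConveter.rebuild_linear).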
class FusedAttentionOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, query, key, value, attention_mask, hidden_size, name):
        # These become the operator attributes.
        kwargs = {
            "hidden_size_i": hidden_size,
            "name_s": name
        }
        from torch.onnx.symbolic_helper import _get_tensor_sizes
        out_sizes = _get_tensor_sizes(query)
        output_type = query.type().with_sizes(out_sizes)
        return g.op("LlmExporter::FusedAttention", query, key, value, attention_mask, **kwargs).setType(output_type)
    @staticmethod
    def forward(ctx, query, key, value, attention_mask, hidden_size, name):
        out_shape = list(query.shape)[:2] + [hidden_size]
        return query.new_zeros(out_shape)
class FusedAttention(torch.nn.Module):
    def __init__(self, hidden_size, name):
        super(FusedAttention, self).__init__()
        self.hidden_size = hidden_size
        self.name = name
    def forward(self, query, key, value, attention_mask):
        return FusedAttentionOp.apply(query, key, value, attention_mask, self.hidden_size, self.name)
# custom op end
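# Added note: FusedAttention plays the same placeholder role for attention. The
# exported `LlmExporter::FusedAttention` node is later rewritten into MNN's
# built-in `Attention` op with kv_cache enabled (see MNNConveter.rebuild_attnention),
# so KV caching is handled by the runtime rather than the exported graph.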
class OnnxRebuilder:
    def __init__(self, onnx_path, weight_ops):
        self.weight_ops = weight_ops
        self.onnx_model = onnx.load(onnx_path)
        self.dst_path = onnx_path
        self.onnx_weight_path = f'{onnx_path}.data'
        self.onnx_weight_offset = 0
    def make_external(self, name, data, shape):
        # write to external weight
        length = self.onnx_weight.write(data.tobytes())
        location = os.path.basename(self.onnx_weight_path)
        offset = self.onnx_weight_offset
        self.onnx_weight_offset += length
        tensor = onnx.TensorProto()
        tensor.name = name
        tensor.data_type = onnx.TensorProto.FLOAT
        tensor.dims.extend(shape)
        # external info
        tensor.data_location = onnx.TensorProto.EXTERNAL
        for k, v in {"location": location, "offset": offset, "length": length}.items():
            entry = tensor.external_data.add()
            entry.key = k
            entry.value = str(v)
        self.onnx_model.graph.initializer.append(tensor)
    def build_weight(self, name, has_bias, ic, oc):
        assert(name in self.weight_ops)
        linear = self.weight_ops[name]
        assert(linear.in_features == ic and
               linear.out_features == oc and
               (linear.bias is not None) == has_bias)
        weight_name, bias_name = f'{name}_weight', f'{name}_bias'
        weight = linear.weight.data.transpose(1, 0).flatten().numpy()
        self.make_external(weight_name, weight, [ic, oc])
        if has_bias:
            bias = linear.bias.data.flatten().numpy()
            self.make_external(bias_name, bias, [oc])
        return weight_name, bias_name
    def rebuild(self):
        from onnx import helper
        new_nodes = []
        self.onnx_weight = open(self.onnx_weight_path, 'wb')
        for node in self.onnx_model.graph.node:
            if node.op_type == 'FakeLinear':
                attributes = {a.name: a for a in node.attribute}
                name = attributes.get('name').s.decode('utf-8')
                has_bias = attributes.get('has_bias').i
                ic = attributes.get('in_features').i
                oc = attributes.get('out_features').i
                weight, bias = self.build_weight(name, has_bias, ic, oc)
                if has_bias:
                    # fakelinear -> matmul + add
                    middle_tensor = f'{name}_matmul'
                    new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], [middle_tensor], name))
                    new_nodes.append(helper.make_node('Add', [middle_tensor, bias], node.output, f'{name}/Add'))
                else:
                    # fakelinear -> matmul
                    new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], node.output, name))
            else:
                new_nodes.append(node)
        self.onnx_weight.close()
        del self.onnx_model.graph.node[:]
        self.onnx_model.graph.node.extend(new_nodes)
        onnx.save(self.onnx_model, self.dst_path)
        return self.onnx_weight_path
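# Usage sketch (illustrative; `weight_ops` is the FakeLinear name -> torch.nn.Linear
# map collected during export, and 'llm.onnx' is a placeholder path): OnnxRebuilder
# turns every FakeLinear node back into MatMul(+Add) with weights stored externally:
#
#   rebuilder = OnnxRebuilder('llm.onnx', weight_ops)
#   weight_file = rebuilder.rebuild()   # rewrites llm.onnx and writes llm.onnx.data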
class MNNConveter:
    def __init__(self, onnx_path, weight_ops, config):
        self.weight_ops = weight_ops
        self.config = config
        self.quant_block = config.quant_block
        self.quant_bit = config.quant_bit
        self.lm_quant_bit = config.lm_quant_bit
        self.symmetric = config.symmetric
        self.mnn_weight_offset = 0
        self.onnx_model_path = onnx_path
        self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
        self.mnn_model_path = os.path.join(config.dst_path, self.mnn_name)
        self.mnn_weight_path = f'{self.mnn_model_path}.weight'
        if os.path.exists(config.mnnconvert):
            self.mnnconvert = config.mnnconvert
        else:
            self.mnnconvert = None
    def convert(self, convert_args):
        sfd = os.dup(1)
        log_fp = open(EXPORT_LOG, "a")
        log_fd = log_fp.fileno()
        # mnnconvert ... > .export.log
        os.dup2(log_fd, 1)
        try:
            sys.argv = convert_args
            sys.argc = len(convert_args)
            if self.mnnconvert is None:
                from MNN.tools import mnnconvert
                mnnconvert.main()
            else:
                convert_args[0] = self.mnnconvert
                cmd = ' '.join(convert_args)
                message = os.popen(cmd).read()
                print(message)
            sys.argv = []
        finally:
            os.dup2(sfd, 1)
            os.close(log_fd)
    @spinner_run(f'convert onnx model to ')
    def onnx2mnn(self, onnx_path, mnn_path, args=[]):
        convert_args = [
            '',
            '-f',
            'ONNX',
            '--modelFile',
            str(onnx_path),
            '--MNNModel',
            str(mnn_path),
            '--transformerFuse',
            '--allowCustomOp'
        ]
        convert_args += args
        self.convert(convert_args)
        return mnn_path
    def mnn2json(self, mnn_path, json_path):
        convert_args = [
            '',
            '-f',
            'MNN',
            '--modelFile',
            str(mnn_path),
            '--JsonFile',
            str(json_path)
        ]
        self.convert(convert_args)
        return json_path
    def json2mnn(self, json_path, mnn_path):
        convert_args = [
            '',
            '-f',
            'JSON',
            '--modelFile',
            str(json_path),
            '--MNNModel',
            str(mnn_path)
        ]
        self.convert(convert_args)
        return mnn_path
    def export(self, quant_bit=None, quant_block=None):
        if self.weight_ops is None:
            if quant_bit is None:
                quant_bit = self.quant_bit
            if quant_block is None:
                quant_block = self.quant_block
            if quant_bit == 16:
                quant_args = ['--fp16']
            else:
                quant_args = [
                    '--weightQuantBits',
                    str(quant_bit),
                    '--weightQuantBlock',
                    str(quant_block)
                ]
            self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args)
        else:
            mnn_json = f'{self.mnn_model_path}.json'
            self.onnx2mnn(self.onnx_model_path, self.mnn_model_path)
            self.mnn2json(self.mnn_model_path, mnn_json)
            self.rebuild(mnn_json)
            self.json2mnn(mnn_json, self.mnn_model_path)
    @spinner_run(f'quant model weight to ')
    def rebuild(self, json_path):
        self.mnn_weight = open(self.mnn_weight_path, 'wb')
        mnn_graph = json.load(open(json_path, 'rt'))
        new_ops = []
        for op in mnn_graph['oplists']:
            if op['type'] == 'Extra':
                new_ops += self.rebuild_op(op, mnn_graph)
            else:
                new_ops.append(op)
        mnn_graph['oplists'] = new_ops
        with open(json_path, 'w', encoding='utf-8') as file:
            json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
        return self.mnn_weight_path
    def quant(self, weight, quant_bit, quant_block, symmetric):
        weight = weight.numpy()
        oc, ic = weight.shape
        if quant_block == 0:
            block_size = ic
        else:
            block_size = quant_block
        block_num = ic // block_size
        weight = weight.reshape(oc, block_num, block_size)
        offset = 1 << (quant_bit - 1)
        clip_max = offset - 1
        if symmetric:
            clip_min = -clip_max
            abs_max = np.max(np.abs(weight), axis=-1, keepdims=True)
            scale = abs_max / clip_max
            q_weight = np.round(weight / scale)
            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
            alpha = scale.flatten()
        else:
            clip_min = -offset
            max_val = np.max(weight, axis=-1, keepdims=True)
            min_val = np.min(weight, axis=-1, keepdims=True)
            scale = (max_val - min_val) / (clip_max - clip_min)
            if False:
                q_weight = np.round((weight - min_val) / scale) + clip_min
                zeros = min_val - scale * clip_min
            else:
                q_weight = np.round(weight / scale) - np.round(min_val / scale) + clip_min
                zeros = (np.round(min_val / scale) - clip_min) * scale
            q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
            alpha = np.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()
        q_weight = q_weight.reshape(-1, 2)
        if quant_bit == 4:
            q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
        clip_min = 1
        return q_weight, alpha, clip_min
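    # Added note: for 4-bit weights two consecutive quantized values are packed
    # into one byte (high nibble first). Example: values -3 and 5 become unsigned
    # 5 and 13 after adding offset=8, and are stored as the single byte 5*16 + 13 = 93.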
    def write_npy(self, data):
        return self.mnn_weight.write(data.tobytes())
    def write_header(self, ic, oc, quant_bit):
        dim_num = self.mnn_weight.write(b'\x02')
        shape_dtype = np.int16
        if oc > 65535 or ic > 65535:
            shape_dtype = np.int32
        dim_length = self.write_npy(np.array([oc, ic]).astype(shape_dtype))
        offset = 1 << (quant_bit - 1)
        weight_map = [i for i in range(-offset, offset)]
        if len(weight_map) == 256:
            weight_map.insert(0, 0)
        else:
            weight_map.insert(0, len(weight_map))
        map_length = self.write_npy(np.array(weight_map, dtype=np.int8))
        header_length = dim_num + dim_length + map_length
        return header_length, shape_dtype == np.int32
    def build_weight(self, linear, quant_bit, quant_block, symmetric):
        ic, oc = linear.in_features, linear.out_features
        if quant_bit == 16:
            half_weight = linear.weight.data.half().flatten().numpy()
            weight_len = self.write_npy(half_weight)
            alpha_len, q_min, shape_int32, header_len = 0, 0, False, 0
        else:
            assert(quant_bit in (4, 8))
            q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, symmetric)
            header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
            weight_len = self.write_npy(q_weight) + header_len
            alpha_len = self.write_npy(alpha)
        if linear.bias is not None:
            bias = linear.bias.data.flatten().numpy()
            bias_length = self.write_npy(bias)
        else:
            bias_length = 0
            # bias = np.zeros([oc], dtype=np.float32)
            # bias_length = self.write_npy(bias)
        external = [self.mnn_weight_offset, weight_len, alpha_len, bias_length, 0]
        self.mnn_weight_offset += (weight_len + alpha_len + bias_length)
        return external, q_min, shape_int32, header_len
    def build_tensor(self, graph, tensor_name):
        tensor_idx = [len(graph['tensorName'])]
        graph['tensorName'].append(tensor_name)
        return tensor_idx
    def rebuild_op(self, op, graph):
        op_type = op['main']['type']
        if op_type == 'FakeLinear':
            return self.rebuild_linear(op, graph)
        if op_type == 'FusedAttention':
            return self.rebuild_attnention(op, graph)
    def rebuild_attnention(self, op, graph):
        attrs = op['main']['attr']
        for attr in attrs:
            if attr['key'] == 'name':
                name = attr['s']
        origin_input = op['inputIndexes']
        origin_output = op['outputIndexes']
        fused_attention = {
            "inputIndexes": origin_input,
            "main_type": "AttentionParam",
            "main": {"kv_cache": True},
            "name": name,
            "outputIndexes": origin_output,
            "type": "Attention",
            "defaultDimentionFormat": "NHWC"
        }
        return [fused_attention]
    def rebuild_linear(self, op, graph):
        attrs = op['main']['attr']
        for attr in attrs:
            if attr['key'] == 'name':
                name = attr['s']
            elif attr['key'] == "in_features":
                ic = attr["i"]
            elif attr['key'] == "out_features":
                oc = attr["i"]
            elif attr['key'] == "has_bias":
                has_bias = attr["i"]
        linear = self.weight_ops[name]
        assert(linear.in_features == ic and
               linear.out_features == oc and
               (linear.bias is not None) == has_bias)
        is_lm = 'lm_head' in name
        quant_bit = self.lm_quant_bit if is_lm else self.quant_bit
        block_size = ic if self.quant_block == 0 else self.quant_block
        external, q_min, shape_int32, header_len = self.build_weight(linear, quant_bit, self.quant_block, self.symmetric)
        if is_lm and self.config.tie_word_embeddings:
            weight_offset = external[0] + header_len
            alpha_offset = external[0] + external[1]
            alpha_size = external[2]
            self.config.llm_config['tie_embeddings'] = [weight_offset, alpha_offset, alpha_size, quant_bit, self.quant_block]
        origin_input = op['inputIndexes']
        origin_output = op['outputIndexes']
        # build new tensor
        pre_reshape_name = f'{name}/pre_reshape'
        pre_convert_name = f'{name}/pre_convert'
        conv_name = name
        post_convert_name = f'{name}/post_convert'
        post_reshape_name = f'{name}/post_reshape'
        pre_reshape_output = self.build_tensor(graph, pre_reshape_name)
        pre_convert_output = self.build_tensor(graph, pre_convert_name)
        conv_output = self.build_tensor(graph, conv_name)
        post_convert_output = self.build_tensor(graph, post_convert_name)
        # [batch, seq, hidden_size_i] -[Linear]-> [batch, seq, hidden_size_o]
        # [1, seq, hidden_size_i] -[Reshape]-> [seq, hidden_size_i, 1, 1]
        # -[Convert]-[Convolution]-[Convert]-[Reshape]-> [1, seq, hidden_size_o]
        pre_reshape = {
            "name": pre_reshape_name,
            "type": "Reshape",
            "inputIndexes": origin_input,
            "outputIndexes": pre_reshape_output,
            "main_type": "Reshape",
            "main": {
                "dims": [-1, ic, 1, 1],
                "dimType": "NCHW"
            },
            "defaultDimentionFormat": "NHWC"
        }
        pre_convert = {
            "name": pre_convert_name,
            "inputIndexes": pre_reshape_output,
            "outputIndexes": pre_convert_output,
            "type": "ConvertTensor",
            "main_type": "TensorConvertInfo",
            "main": {
                "source": "NCHW",
                "dest": "NC4HW4"
            },
            "defaultDimentionFormat": "NHWC"
        }
        if quant_bit == 16:
            quanParameter = {"type": 3}
        else:
            if self.symmetric:
                aMin = 0
                readType = 0
            else:
                aMin = q_min
                readType = oc * (ic // block_size)
            quanParameter = {
                "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
                "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
                "type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
            }
        conv_op = {
            "name": conv_name,
            "inputIndexes": pre_convert_output,
            "outputIndexes": conv_output,
            "type": "Convolution",
            "main_type": "Convolution2D",
            "main": {
                'common': {
                    'dilateX': 1, 'dilateY': 1, 'strideX': 1, 'strideY': 1,
                    'kernelX': 1, 'kernelY': 1, 'padX': 0, 'padY': 0, 'group': 1,
                    'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
                    'relu6': False, 'inputCount': ic, 'hasOutputShape': False
                },
                "quanParameter": quanParameter,
                "external": external
            },
            "defaultDimentionFormat": "NHWC"
        }
        post_convert = {
            "name": post_convert_name,
            "inputIndexes": conv_output,
            "outputIndexes": post_convert_output,
            "type": "ConvertTensor",
            "main_type": "TensorConvertInfo",
            "main": {
                "source": "NC4HW4",
                "dest": "NCHW"
            },
            "defaultDimentionFormat": "NHWC"
        }
        post_reshape = {
            "name": post_reshape_name,
            "type": "Reshape",
            "inputIndexes": post_convert_output,
            "outputIndexes": origin_output,
            "main_type": "Reshape",
            "main": {
                "dims": [1, -1, oc],
                "dimType": "NCHW"
            },
            "defaultDimentionFormat": "NHWC"
        }
        return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape]
# some wrapper class for export
class Embedding(torch.nn.Module):
    def __init__(self, embed, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.embed = embed
        if config.model_type == 'gemma2':
            normalizer = torch.tensor(self.hidden_size ** 0.5)
            self.embed.weight.data *= normalizer
    def forward(self, input_ids):
        inputs_embeds = self.embed(input_ids).view(-1, 1, self.hidden_size)
        return inputs_embeds
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
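# Added note: repeat_kv expands grouped-query attention KV heads so they line up
# with the query heads, e.g. a [1, 8, seq, 128] key tensor with n_rep=4 becomes
# [1, 32, seq, 128]; with n_rep == 1 (standard multi-head attention) it is a no-op.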
class Attention ( torch . nn . Module ) :
2024-11-18 14:37:45 +08:00
def __init__ ( self , attn , layer_id , config ) :
2024-09-12 12:57:57 +08:00
super ( ) . __init__ ( )
2024-11-18 14:37:45 +08:00
self . export_fused_attn = False
self . fused_attn = FusedAttention ( config . hidden_size , f ' /layers. { layer_id } /self_attn/FusedAttention ' )
self . layer_id = layer_id
2024-09-12 12:57:57 +08:00
self . hidden_size = config . hidden_size
self . head_dim = config . head_dim
2024-11-18 14:37:45 +08:00
if isinstance ( config . num_attention_heads , list ) :
self . num_heads = config . num_attention_heads [ layer_id ]
self . num_key_value_heads = config . num_key_value_heads [ layer_id ]
else :
self . head_dim = config . head_dim
self . num_heads = config . num_attention_heads
self . num_key_value_heads = config . num_key_value_heads
2024-09-12 12:57:57 +08:00
self . num_key_value_groups = self . num_heads / / self . num_key_value_heads
self . rotary = config . rotary
2024-11-18 14:37:45 +08:00
2024-09-12 12:57:57 +08:00
ModelMapper . do_map ( self , attn , config . model_map [ ' attention ' ] )
2024-11-18 14:37:45 +08:00
2024-09-12 12:57:57 +08:00
if hasattr ( self , ' qkv_proj ' ) and self . qkv_proj is not None :
# split qkv linear to q, k, v
split_sizes = [ self . hidden_size ] * 3
if self . qkv_proj . weight . shape [ 0 ] != self . hidden_size * 3 :
# M/GQA
2024-11-18 14:37:45 +08:00
split_sizes = [
self . num_heads * self . head_dim , # q_size
self . num_key_value_heads * self . head_dim , # k_size
self . num_key_value_heads * self . head_dim # v_size
]
2024-09-12 12:57:57 +08:00
self . q_proj = torch . nn . Linear ( self . hidden_size , split_sizes [ 0 ] )
self . k_proj = torch . nn . Linear ( self . hidden_size , split_sizes [ 1 ] )
self . v_proj = torch . nn . Linear ( self . hidden_size , split_sizes [ 2 ] )
if config . model_type == ' chatglm ' :
# chatglm-6b
qkv_weight = self . qkv_proj . weight . data . view ( self . num_heads , 3 , self . head_dim , self . hidden_size )
self . q_proj . weight . data = qkv_weight [ : , 0 , : , : ] . reshape ( self . hidden_size , self . hidden_size )
self . k_proj . weight . data = qkv_weight [ : , 1 , : , : ] . reshape ( self . hidden_size , self . hidden_size )
self . v_proj . weight . data = qkv_weight [ : , 2 , : , : ] . reshape ( self . hidden_size , self . hidden_size )
qkv_bias = self . qkv_proj . bias . data . view ( self . num_heads , 3 , self . head_dim )
self . q_proj . bias . data = qkv_bias [ : , 0 , : ] . reshape ( self . hidden_size )
self . k_proj . bias . data = qkv_bias [ : , 1 , : ] . reshape ( self . hidden_size )
self . v_proj . bias . data = qkv_bias [ : , 2 , : ] . reshape ( self . hidden_size )
else :
# other
qw , kw , vw = torch . split ( self . qkv_proj . weight , split_sizes )
self . q_proj . weight . data = qw
self . k_proj . weight . data = kw
self . v_proj . weight . data = vw
if self . qkv_proj . bias is not None :
qb , kb , vb = torch . split ( self . qkv_proj . bias , split_sizes )
self . q_proj . bias . data = qb
self . k_proj . bias . data = kb
self . v_proj . bias . data = vb
else :
self . q_proj . bias . data = torch . zeros ( split_sizes [ 0 ] )
self . k_proj . bias . data = torch . zeros ( split_sizes [ 1 ] )
self . v_proj . bias . data = torch . zeros ( split_sizes [ 2 ] )
def forward (
self ,
hidden_states : torch . Tensor ,
attention_mask : Optional [ torch . Tensor ] = None ,
past_key_value : Optional [ Tuple [ torch . Tensor ] ] = None ,
rotary_pos_emb : Optional [ torch . Tensor ] = None ,
cross_attention_states : Optional [ torch . Tensor ] = None ,
) - > Tuple [ torch . Tensor , torch . Tensor ] :
bsz , q_len , _ = hidden_states . size ( )
query_states = self . q_proj ( hidden_states )
if cross_attention_states is not None :
hidden_states = cross_attention_states
key_states = self . k_proj ( hidden_states )
value_states = self . v_proj ( hidden_states )
query_states = query_states . view ( bsz , q_len , self . num_heads , self . head_dim )
key_states = key_states . view ( bsz , q_len , self . num_key_value_heads , self . head_dim )
value_states = value_states . view ( bsz , q_len , self . num_key_value_heads , self . head_dim )
# openelm model has qk_norm
if hasattr ( self , ' q_norm ' ) and self . q_norm is not None and \
hasattr ( self , ' k_norm ' ) and self . k_norm is not None :
query_states = self . q_norm ( query_states )
key_states = self . k_norm ( key_states )
kv_seq_len = key_states . shape [ 1 ]
if past_key_value is not None :
kv_seq_len + = past_key_value [ 0 ] . shape [ 1 ]
# rope
cos , sin = rotary_pos_emb [ 0 ] , rotary_pos_emb [ 1 ]
query_states = self . rotary . apply_rotary_pos ( query_states , cos , sin )
key_states = self . rotary . apply_rotary_pos ( key_states , cos , sin )
if self . export_fused_attn :
attn_output = self . fused_attn ( query_states , key_states , value_states , attention_mask )
attn_output = self . o_proj ( attn_output )
return attn_output , past_key_value
# kv cache
if past_key_value is not None :
past_key , past_value = past_key_value [ 0 ] , past_key_value [ 1 ]
key_states = torch . cat ( ( past_key , key_states ) , dim = 1 )
value_states = torch . cat ( ( past_value , value_states ) , dim = 1 )
past_key_value = torch . stack ( ( key_states , value_states ) )
query_states = query_states . transpose ( 1 , 2 )
key_states = key_states . permute ( [ 0 , 2 , 3 , 1 ] )
value_states = value_states . transpose ( 1 , 2 )
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv ( key_states , self . num_key_value_groups )
value_states = repeat_kv ( value_states , self . num_key_value_groups )
#------- attention ----------
# query_states @ key_states
attn_weights = torch . matmul ( query_states , key_states ) / math . sqrt ( self . head_dim )
# attention_mask
if attention_mask . dtype in ( torch . bool , torch . int32 ) :
# chatglm
attn_weights . masked_fill_ ( attention_mask , - 10000.0 )
else :
attn_weights = attn_weights + attention_mask
# upcast softmax to fp32
attn_weights = torch . nn . functional . softmax ( attn_weights , dim = - 1 , dtype = torch . float32 ) . to ( query_states . dtype )
# attn_weights @ value_states
attn_output = torch . matmul ( attn_weights , value_states )
attn_output = attn_output . transpose ( 1 , 2 ) . contiguous ( )
attn_output = attn_output . reshape ( bsz , q_len , - 1 )
attn_output = self . o_proj ( attn_output )
return attn_output , past_key_value
def rotate_half ( x ) :
x1 = x [ . . . , : x . shape [ - 1 ] / / 2 ]
x2 = x [ . . . , x . shape [ - 1 ] / / 2 : ]
return torch . cat ( ( - x2 , x1 ) , dim = - 1 )
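# Illustrative sketch (not called during export): rotate_half swaps and negates
# the two halves of the last dimension; x*cos + rotate_half(x)*sin is the
# cos/sin form of the RoPE rotation used by llama_rotary_pos below.
def _demo_rotate_half():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    assert torch.equal(rotate_half(x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))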
class Rotary ( torch . nn . Module ) :
def __init__ ( self , config ) :
super ( ) . __init__ ( )
self . rope_theta = config . rope_theta
self . rotary_dim = config . head_dim
self . model_type = config . model_type
if hasattr ( config , ' rotary_dim ' ) :
self . rotary_dim = config . rotary_dim
if self . model_type == ' chatglm ' :
self . rotary_dim = config . head_dim / / 2
def forward ( self , position_ids ) :
theta = 1.0 / ( self . rope_theta * * ( torch . arange ( 0 , self . rotary_dim , 2 , dtype = torch . float32 ) / self . rotary_dim ) )
position_ids = position_ids . float ( ) . reshape ( - 1 , 1 )
idx_theta = position_ids * theta
rotary_pos_emb = torch . stack ( [ torch . cos ( idx_theta ) , torch . sin ( idx_theta ) ] )
if self . model_type != ' chatglm2 ' :
rotary_pos_emb = torch . cat ( ( rotary_pos_emb , rotary_pos_emb ) , dim = - 1 )
rotary_pos_emb = rotary_pos_emb . unsqueeze ( 2 ) . unsqueeze ( 1 )
return rotary_pos_emb
def apply_rotary_pos ( self , x , cos , sin ) :
if self . model_type == ' chatglm ' :
return self . chatglm_rotary_pos ( x , cos , sin )
if self . model_type == ' chatglm2 ' :
return self . chatglm2_rotary_pos ( x , cos , sin )
if self . model_type == ' phi-msft ' :
return self . phi_rotary_pos ( x , cos , sin )
return self . llama_rotary_pos ( x , cos , sin )
def llama_rotary_pos ( self , x , cos , sin ) :
x = ( x * cos ) + ( rotate_half ( x ) * sin )
return x
def phi_rotary_pos ( self , x , cos , sin ) :
x , x_pass = x [ . . . , : self . rotary_dim ] , x [ . . . , self . rotary_dim : ]
x = ( x * cos ) + ( rotate_half ( x ) * sin )
return torch . cat ( ( x , x_pass ) , dim = - 1 )
def chatglm2_rotary_pos ( self , x , cos , sin ) :
x , x_pass = x [ . . . , : self . rotary_dim ] , x [ . . . , self . rotary_dim : ]
b , s , n , h = x . shape
xshaped = x . view ( b , s , n , h / / 2 , 2 )
x = torch . concat (
[
xshaped [ . . . , 0 ] * cos - xshaped [ . . . , 1 ] * sin ,
xshaped [ . . . , 1 ] * cos + xshaped [ . . . , 0 ] * sin ,
] ,
- 1 ,
)
return torch . cat ( ( x , x_pass ) , dim = - 1 )
def chatglm_rotary_pos ( self , x , cos , sin ) :
seq = x . shape [ 1 ]
x1 , x2 = x [ . . . , : self . rotary_dim ] , x [ . . . , self . rotary_dim : ]
cos1 , sin1 = cos [ : , : seq , . . . ] , sin [ : , : seq , . . . ]
cos2 , sin2 = cos [ : , seq : , . . . ] , sin [ : , seq : , . . . ]
x1 = ( x1 * cos1 ) + ( rotate_half ( x1 ) * sin1 )
x2 = ( x2 * cos2 ) + ( rotate_half ( x2 ) * sin2 )
return torch . cat ( ( x1 , x2 ) , dim = - 1 )
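# Note (hedged): Rotary.forward returns a stacked [cos, sin] table; after the
# final unsqueezes each of rotary_pos_emb[0] / rotary_pos_emb[1] has shape
# [1, seq_len, 1, rotary_dim], so it broadcasts over the per-head
# [bsz, seq_len, num_heads, head_dim] query/key tensors in apply_rotary_pos.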
class Decoder ( torch . nn . Module ) :
def __init__ ( self , decoder , layer_id , config ) :
super ( ) . __init__ ( )
self . cross_decoder = False
ModelMapper . do_map ( self , decoder , config . model_map [ ' decoder ' ] )
# mllama has cross_attn
if hasattr ( self , ' cross_attn ' ) and self . cross_attn is not None :
self . cross_decoder = True
self . self_attn = Attention ( self . cross_attn , layer_id , config )
else :
self . self_attn = Attention ( self . self_attn , layer_id , config )
self . hidden_size = config . hidden_size
# chatglm
self . alpha = ( 2 * config . num_hidden_layers ) * * 0.5 if config . model_type == ' chatglm ' else 1.0
def forward (
self ,
hidden_states : torch . Tensor ,
rotary_pos_emb : Optional [ torch . Tensor ] = None ,
attention_mask : Optional [ torch . Tensor ] = None ,
past_key_value : Optional [ Tuple [ torch . Tensor ] ] = None ,
cross_attention_states : Optional [ torch . Tensor ] = None ,
cross_attention_mask : Optional [ torch . Tensor ] = None ,
) - > Tuple [ torch . FloatTensor , Optional [ Tuple [ torch . FloatTensor , torch . FloatTensor ] ] ] :
hidden_states = hidden_states . view ( 1 , - 1 , self . hidden_size )
residual = hidden_states
hidden_states = self . input_layernorm ( hidden_states )
norm_hidden_states = hidden_states
# Self Attention
hidden_states , present_key_value = self . self_attn (
hidden_states = hidden_states ,
rotary_pos_emb = rotary_pos_emb ,
attention_mask = attention_mask ,
past_key_value = past_key_value ,
cross_attention_states = cross_attention_states ,
)
# Fully Connected
if not hasattr ( self , ' post_attention_layernorm ' ) :
# phi
feed_forward_hidden_states = self . mlp ( norm_hidden_states )
hidden_states = hidden_states + feed_forward_hidden_states + residual
elif self . alpha != 1.0 :
# chatglm-6b
hidden_states = norm_hidden_states * self . alpha + hidden_states
mlp_input = self . post_attention_layernorm ( hidden_states )
mlp_output = self . mlp ( mlp_input )
hidden_states = mlp_input * self . alpha + mlp_output
elif hasattr ( self , ' pre_feedforward_layernorm ' ) :
# gemma2
hidden_states = self . post_attention_layernorm ( hidden_states )
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self . pre_feedforward_layernorm ( hidden_states )
hidden_states = self . mlp ( hidden_states )
hidden_states = self . post_feedforward_layernorm ( hidden_states )
hidden_states = residual + hidden_states
elif cross_attention_mask is not None :
hidden_states = residual + self . cross_attn_attn_gate . tanh ( ) * hidden_states
residual = hidden_states
hidden_states = self . post_attention_layernorm ( hidden_states )
hidden_states = self . mlp ( hidden_states )
hidden_states = cross_attention_mask * hidden_states
hidden_states = residual + self . cross_attn_mlp_gate . tanh ( ) * hidden_states
else :
# general
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self . post_attention_layernorm ( hidden_states )
hidden_states = self . mlp ( hidden_states )
hidden_states = residual + hidden_states
return hidden_states , present_key_value
class Lm ( torch . nn . Module ) :
def __init__ ( self , lm_ , final_layernorm_ , config ) :
super ( ) . __init__ ( )
self . final_layernorm = final_layernorm_
self . lm = lm_
self . hidden_size = config . hidden_size
self . ppl = config . ppl
def forward ( self , hidden_states ) :
if not self . ppl :
# only the last logit is needed to predict the next token
hidden_states = hidden_states . view ( - 1 , self . hidden_size ) [ - 1 ] . view ( 1 , 1 , self . hidden_size )
hidden_states = self . final_layernorm ( hidden_states )
m_logits = self . lm ( hidden_states )
return m_logits
class Visual ( torch . nn . Module ) :
def __init__ ( self , visual , base ) :
super ( ) . __init__ ( )
self . visual = visual . eval ( )
self . embed_ = base . embed
self . tokenizer = base . tokenizer
self . config = base . config
self . hidden_size = base . hidden_size
self . llm_config = base . llm_config
# mllama
self . cross_attention_states = None
self . cross_attention_mask = None
self . init_config ( )
self . load ( )
@staticmethod
def get_visual ( model_type ) :
visual_models = {
' qwen ' : QwenVisual ,
' qwen2_vl ' : Qwen2Visual ,
' mllama ' : MllamaVision
}
if model_type in visual_models :
return visual_models [ model_type ]
return None
def init_config ( self ) :
from transformers . image_utils import ( OPENAI_CLIP_MEAN , OPENAI_CLIP_STD )
self . llm_config [ ' is_visual ' ] = True
image_mean = np . array ( OPENAI_CLIP_MEAN ) * 255.0
image_norm = 1 / ( np . array ( OPENAI_CLIP_STD ) * 255.0 )
self . llm_config [ ' image_mean ' ] = image_mean . tolist ( )
self . llm_config [ ' image_norm ' ] = image_norm . tolist ( )
def load ( self ) :
raise NotImplementedError
def str_to_ids ( self , prompt ) :
input_ids = self . tokenizer ( prompt , return_tensors = " pt " ) [ ' input_ids ' ]
return input_ids
def forward ( self , images ) :
raise NotImplementedError
def embed ( self , input_ids , images = None , videos = None ) :
raise NotImplementedError
class QwenVisual ( Visual ) :
def __init__ ( self , visual , base ) :
self . quant_bit = 16
super ( ) . __init__ ( visual , base )
def load ( self ) :
self . image_start_id = self . config . visual [ ' image_start_id ' ]
self . image_size = self . config . visual [ ' image_size ' ]
self . llm_config [ ' is_visual ' ] = True
self . llm_config [ ' image_size ' ] = self . image_size
self . llm_config [ ' vision_start ' ] = self . tokenizer . img_start_id
self . llm_config [ ' vision_end ' ] = self . tokenizer . img_end_id
self . llm_config [ ' image_pad ' ] = self . tokenizer . img_pad_id
def forward ( self , images ) :
return self . visual ( images ) . transpose ( 1 , 0 )
def embed ( self , input_ids , images = None , videos = None ) :
if not torch . any ( input_ids == self . image_start_id ) :
return self . embed_ ( input_ids )
bos_pos = torch . where ( input_ids == self . image_start_id )
eos_pos = torch . where ( input_ids == self . image_start_id + 1 )
img_pos = torch . stack ( ( bos_pos [ 0 ] , bos_pos [ 1 ] , eos_pos [ 1 ] ) , dim = 1 )
images = [ ]
for i , a , b in img_pos :
image = input_ids [ i ] [ a + 1 : b - 1 ] . tolist ( )
image = image [ : image . index ( self . image_start_id + 2 ) ]
images . append ( bytes ( image ) . decode ( ' utf-8 ' ) )
images = self . visual . encode ( images ) . transpose ( 1 , 0 )
hidden_states = self . embed_ ( input_ids )
for idx , ( i , a , b ) in enumerate ( img_pos ) :
hidden_states [ a + 1 : b , i ] = images [ : , idx ]
return hidden_states
class Qwen2Visual ( Visual ) :
def __init__ ( self , visual , base ) :
self . quant_bit = 4
self . temporal_patch_size = 2
self . patch_size = 14
self . merge_size = 2
self . image_size = 420
self . image_embeds = None
super ( ) . __init__ ( visual , base )
def load ( self ) :
self . vision_start_id = self . config . vision_start_token_id
self . vision_end_id = self . config . vision_end_token_id
self . image_pad_id = self . config . image_token_id
self . llm_config [ ' image_size ' ] = self . image_size
self . llm_config [ ' vision_start ' ] = self . vision_start_id
self . llm_config [ ' vision_end ' ] = self . vision_end_id
self . llm_config [ ' image_pad ' ] = self . image_pad_id
def str_to_ids ( self , prompt ) :
if ' <img> ' in prompt and ' </img> ' in prompt :
import re
import requests
from PIL import Image
pattern = r ' (<img>.*?</img>) '
parts = re . split ( pattern , prompt )
txt_prompt = ' '
for part in parts :
if re . match ( pattern , part ) :
img_content = re . search ( r ' <img>(.*?)</img> ' , part ) . group ( 1 )
if img_content . startswith ( ' http:// ' ) or img_content . startswith ( ' https:// ' ) :
image_obj = Image . open ( requests . get ( img_content , stream = True ) . raw )
img_pad_len = self . img_process ( image_obj )
img_pad_str = ' <|image_pad|> ' * img_pad_len
img_str = f ' <|vision_start|> { img_pad_str } <|vision_end|> '
txt_prompt + = img_str
else :
txt_prompt + = part
else :
txt_prompt = prompt
input_ids = self . tokenizer ( txt_prompt , return_tensors = " pt " ) [ ' input_ids ' ]
return input_ids
def forward ( self , images ) :
images = [ images ] * self . temporal_patch_size
patches = torch . concat ( images , axis = 0 )
_ , channel , height , width = patches . shape
grid_t = patches . shape [ 0 ] / / self . temporal_patch_size
grid_h , grid_w = height / / self . patch_size , width / / self . patch_size
patches = patches . reshape (
grid_t ,
self . temporal_patch_size ,
channel ,
grid_h / / self . merge_size ,
self . merge_size ,
self . patch_size ,
grid_w / / self . merge_size ,
self . merge_size ,
self . patch_size ,
)
patches = patches . permute ( 0 , 3 , 6 , 4 , 7 , 2 , 1 , 5 , 8 )
flatten_patches = patches . reshape (
grid_t * grid_h * grid_w , channel * self . temporal_patch_size * self . patch_size * self . patch_size
)
image_grid_thw = torch . tensor ( [ [ grid_t , grid_h , grid_w ] ] )
image_embeds = self . visual ( flatten_patches , image_grid_thw )
image_embeds = image_embeds . unsqueeze ( 1 )
return image_embeds
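# Shape walkthrough with the defaults above (hedged): a 420x420 image and
# patch_size=14 give a 30x30 grid; the single frame is duplicated to fill
# temporal_patch_size=2, so the encoder receives 900 patches of
# 3*2*14*14 = 1176 values each, and after 2x2 spatial merging about
# 900 / 4 = 225 image embeddings come back (that count becomes img_pad_len).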
def img_process ( self , image ) :
resized_height = self . image_size
resized_width = self . image_size
from transformers . image_transforms import (
convert_to_rgb ,
resize ,
rescale ,
normalize
)
from transformers . image_utils import (
OPENAI_CLIP_MEAN ,
OPENAI_CLIP_STD ,
PILImageResampling ,
infer_channel_dimension_format ,
to_numpy_array
)
image = convert_to_rgb ( image )
image = to_numpy_array ( image )
format = infer_channel_dimension_format ( image )
resample = PILImageResampling . BICUBIC
image = resize ( image , size = ( resized_height , resized_width ) , resample = resample , input_data_format = format )
image = rescale ( image , scale = 1 / 255.0 , input_data_format = format )
image = normalize ( image = image , mean = OPENAI_CLIP_MEAN , std = OPENAI_CLIP_STD , input_data_format = format )
image = np . expand_dims ( image , [ 0 ] )
image = image . transpose ( 0 , 3 , 1 , 2 )
image = torch . from_numpy ( image )
self . image_embeds = self . forward ( image )
return self . image_embeds . shape [ 0 ]
def embed ( self , input_ids , images = None , videos = None ) :
input_embeds = self . embed_ ( input_ids )
if self . image_embeds is not None :
image_mask = ( input_ids == self . image_pad_id ) . squeeze ( )
input_embeds [ image_mask ] = self . image_embeds
return input_embeds
class MllamaVision ( Visual ) :
def __init__ ( self , visual , base ) :
super ( ) . __init__ ( visual , base )
self . image_objs = [ ]
def load ( self ) :
self . llm_config [ ' is_visual ' ] = True
self . llm_config [ ' image_size ' ] = self . config . vision_config . image_size
self . image_size = self . config . vision_config . image_size
def str_to_ids ( self , prompt ) :
if ' <img> ' in prompt and ' </img> ' in prompt :
import re
import requests
from PIL import Image
pattern = r ' (<img>.*?</img>) '
parts = re . split ( pattern , prompt )
txt_prompt = ' '
for part in parts :
if re . match ( pattern , part ) :
img_content = re . search ( r ' <img>(.*?)</img> ' , part ) . group ( 1 )
if img_content . startswith ( ' http:// ' ) or img_content . startswith ( ' https:// ' ) :
self . image_objs . append ( Image . open ( requests . get ( img_content , stream = True ) . raw ) )
txt_prompt + = ' <|image|> '
else :
txt_prompt + = part
else :
txt_prompt = prompt
input_ids = self . tokenizer ( txt_prompt , return_tensors = " pt " ) [ ' input_ids ' ]
# image process
for img in self . image_objs :
image_embeds = self . img_process ( img )
print ( image_embeds . shape )
return input_ids
def img_process ( self , image ) :
resized_height = self . image_size
resized_width = self . image_size
from transformers . image_transforms import (
convert_to_rgb ,
resize ,
rescale ,
normalize
)
from transformers . image_utils import (
OPENAI_CLIP_MEAN ,
OPENAI_CLIP_STD ,
PILImageResampling ,
infer_channel_dimension_format ,
to_numpy_array
)
image = convert_to_rgb ( image )
image = to_numpy_array ( image )
format = infer_channel_dimension_format ( image )
resample = PILImageResampling . BICUBIC
image = resize ( image , size = ( resized_height , resized_width ) , resample = resample , input_data_format = format )
image = rescale ( image , scale = 1 / 255.0 , input_data_format = format )
image = normalize ( image = image , mean = OPENAI_CLIP_MEAN , std = OPENAI_CLIP_STD , input_data_format = format )
image = image . transpose ( 2 , 0 , 1 )
image = np . expand_dims ( image , [ 0 , 1 , 2 ] )
pad_val = np . zeros_like ( image )
image = np . concatenate ( [ image , pad_val , pad_val , pad_val ] , axis = 2 )
print ( image . shape )
image = torch . from_numpy ( image )
image_embeds = self . forward ( image )
print ( image_embeds . shape )
return image_embeds
def forward ( self , images ) :
aspect_ratio_ids = torch . tensor ( [ [ 1 ] ] )
aspect_ratio_mask = torch . tensor ( [ [ [ 1 , 0 , 0 , 0 ] ] ] )
return self . visual ( images , aspect_ratio_ids , aspect_ratio_mask )
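# Hedged note: this export path assumes a single resized tile. img_process pads
# the image with three zero tiles (mllama's max_num_tiles = 4), aspect_ratio_ids
# [[1]] selects the 1x1 tiling entry, and aspect_ratio_mask [[[1, 0, 0, 0]]]
# marks only the first tile as valid.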
def embed ( self , input_ids , images = None , videos = None ) :
return self . embed_ ( input_ids )
class LlmExporter ( torch . nn . Module ) :
'''
Base class for all llm model exporters . Inherits from [ ` torch . nn . Module ` ] .
'''
def __init__ ( self , args ) :
super ( ) . __init__ ( )
self . init_from_args ( args )
self . load_model ( args . path )
def init_from_args ( self , args ) :
self . max_length = 128
self . stop_ids = [ ]
self . visual = None
self . dst_name = ' llm '
# load config from args
self . path = args . path
self . dst_path = args . dst_path
self . onnx_path = os . path . join ( self . dst_path , ' onnx ' )
self . tokenizer_path = args . tokenizer_path
self . lora_path = args . lora_path
self . onnx_slim = args . onnx_slim
self . ppl = args . ppl
self . awq = args . awq
self . quant_bit = args . quant_bit
self . quant_block = args . quant_block
self . symmetric = args . sym
self . mnnconvert = args . mnnconvert
if self . tokenizer_path is None :
self . tokenizer_path = self . path
if args . lm_quant_bit is not None :
self . lm_quant_bit = args . lm_quant_bit
else :
self . lm_quant_bit = self . quant_bit
# init export dst dir
if not os . path . exists ( self . dst_path ) :
os . makedirs ( self . dst_path )
if not os . path . exists ( self . onnx_path ) :
os . makedirs ( self . onnx_path )
def load_pretrained ( self , model_path : str ) :
self . tokenizer = AutoTokenizer . from_pretrained ( self . tokenizer_path , trust_remote_code = True , use_fast = False )
if ' Qwen2-VL ' in model_path :
from transformers import Qwen2VLForConditionalGeneration
self . model = Qwen2VLForConditionalGeneration . from_pretrained ( model_path ) . float ( ) . eval ( )
elif ' Llama-3.2 ' in model_path and ' Vision ' in model_path :
from transformers import MllamaForConditionalGeneration
self . model = MllamaForConditionalGeneration . from_pretrained ( model_path ) . float ( ) . eval ( )
else :
try :
self . model = AutoModelForCausalLM . from_pretrained ( model_path , trust_remote_code = True ) . float ( ) . eval ( )
except :
self . model = AutoModel . from_pretrained ( model_path , trust_remote_code = True ) . float ( ) . eval ( )
self . config = self . model . config
if self . lora_path is not None :
from peft import PeftModel
adapter = PeftModel . from_pretrained ( self . model , model_id = self . lora_path )
self . model = adapter . merge_and_unload ( progressbar = True )
@staticmethod
def has_attr ( obj , attr ) :
return hasattr ( obj , attr ) and getattr ( obj , attr ) is not None
@spinner_run ( f ' load pretrained model ' )
def load_model ( self , model_path ) :
self . load_pretrained ( model_path )
self . attention_mask_type = ' float '
# load tokenizer info
self . stop_ids . append ( self . tokenizer . eos_token_id )
if hasattr ( self . tokenizer , ' im_end_id ' ) :
self . stop_ids . append ( self . tokenizer . im_end_id )
eot_id = self . tokenizer . encode ( ' <|eot_id|> ' )
if len ( eot_id ) == 1 :
self . stop_ids . append ( eot_id [ 0 ] )
if hasattr ( self . model , ' generation_config ' ) and self . model . generation_config is not None :
eos_token_id = self . model . generation_config . eos_token_id
from collections . abc import Iterable
if isinstance ( eos_token_id , int ) :
self . stop_ids . append ( eos_token_id )
elif isinstance ( eos_token_id , Iterable ) :
for id in eos_token_id :
self . stop_ids . append ( id )
self . stop_ids = [ stop_id for stop_id in self . stop_ids if stop_id is not None ]
self . stop_ids = list ( set ( self . stop_ids ) )
model_mapper = ModelMapper ( )
self . tie_word_embeddings = ( hasattr ( self . config , ' tie_word_embeddings ' ) and self . config . tie_word_embeddings )
self . model_type , self . model_map = model_mapper . get_map ( self . config )
# print(self.config, self.model_type, self.model_map, self.model)
# load config info
ModelMapper . do_map ( self , self . config , self . model_map [ ' config ' ] )
if not hasattr ( self , ' num_key_value_heads ' ) or self . num_key_value_heads is None :
self . num_key_value_heads = self . num_attention_heads
if not hasattr ( self , ' rope_theta ' ) or self . rope_theta is None :
self . rope_theta = 10000.0
if not hasattr ( self , ' head_dim ' ) or self . head_dim is None :
if isinstance ( self . num_attention_heads , list ) :
self . head_dim = [ self . hidden_size / / atten_head for atten_head in self . num_attention_heads ]
else :
self . head_dim = self . hidden_size / / self . num_attention_heads
# some export info
if isinstance ( self . num_attention_heads , list ) :
self . past_kv_shape = [ self . num_hidden_layers , 2 , 1 , 0 , self . num_key_value_heads [ 0 ] , self . head_dim ]
else :
self . past_kv_shape = [ self . num_hidden_layers , 2 , 1 , 0 , self . num_key_value_heads , self . head_dim ]
self . block_dynamic_axes = {
" inputs_embeds " : { 0 : " seq_len " } ,
" attention_mask " : { 2 : " seq_len " , 3 : " seq_len " } ,
" position_ids " : { 0 : " seq_len " } ,
" past_key_values " : { 1 : " history_len " }
}
self . model_dynamic_axes = {
" input_ids " : { 0 : " seq_len " } ,
" attention_mask " : { 2 : " seq_len " , 3 : " seq_len " } ,
" position_ids " : { 1 : " seq_len " } ,
" past_key_values " : { 3 : " history_len " }
}
self . llm_config = {
' hidden_size ' : self . hidden_size ,
' layer_nums ' : self . num_hidden_layers ,
' attention_mask ' : self . attention_mask_type ,
' key_value_shape ' : self . past_kv_shape [ 1 : ] ,
" prompt_template " : self . build_prompt ( ' %s ' ) ,
' is_visual ' : False
}
# load modules
ModelMapper . do_map ( self , self . model , self . model_map [ ' model ' ] )
# rebuild modules
if self . lm_ is None :
out_features , in_features = self . embed_ . weight . shape
self . lm_ = torch . nn . Linear ( in_features , out_features , bias = False ) # tied lm_head has no bias; avoid adding a randomly initialized one
self . lm_ . weight = self . embed_ . weight
if self . embed_ . weight is self . lm_ . weight :
import copy
embed_copy = copy . deepcopy ( self . embed_ )
self . embed = Embedding ( embed_copy , self )
else :
self . embed = Embedding ( self . embed_ , self )
# Rotary
self . rotary = Rotary ( self )
self . blocks = [ ]
for block in self . blocks_ . children ( ) :
layer_id = len ( self . blocks )
self . blocks . append ( Decoder ( block , layer_id , self ) )
self . lm = Lm ( self . lm_ , self . final_layernorm_ , self )
# visual model
if self . visual is not None :
self . visual = Visual . get_visual ( self . model_type ) ( self . visual , self )
return model_path
def get_attention_mask ( self ) - > torch . Tensor :
if self . model_type == ' chatglm ' :
return self . chatglm_attention_mask ( )
if self . token_len :
return torch . zeros ( [ 1 , 1 , 1 , self . seq_len ] , dtype = torch . float32 )
return ( 1 - torch . tril ( torch . ones ( [ 1 , 1 , self . seq_len , self . seq_len ] ) ) ) * torch . finfo ( torch . float32 ) . min
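# Hedged note: during prefill (token_len == 0) this is the usual additive causal
# mask, e.g. for seq_len = 3 the upper triangle holds float32 min:
#   [[0, m, m],
#    [0, 0, m],
#    [0, 0, 0]]   with m = torch.finfo(torch.float32).min
# During decode a single query may attend to the whole cache, so the mask is all
# zeros with shape [1, 1, 1, seq_len].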
def get_position_ids ( self ) - > torch . Tensor :
if self . model_type == ' chatglm ' :
return self . chatglm_position_ids ( )
if self . token_len :
return torch . tensor ( [ [ self . seq_len - 1 ] ] , dtype = torch . int )
return torch . arange ( self . seq_len , dtype = torch . int ) . unsqueeze ( 0 )
def chatglm_attention_mask ( self ) :
if self . token_len :
return torch . zeros ( [ 1 ] ) . bool ( ) . reshape ( [ 1 , 1 , 1 , 1 ] )
attention_mask = torch . zeros ( [ self . seq_len , self . seq_len ] , dtype = torch . bool )
for i in range ( self . seq_len - 1 ) :
attention_mask [ i ] [ - 1 ] = True
attention_mask = attention_mask . reshape ( [ 1 , 1 , self . seq_len , self . seq_len ] )
return attention_mask
def chatglm_position_ids ( self ) :
if self . token_len :
return torch . tensor ( [ self . context_len , self . token_len + 1 ] ) . reshape ( [ 1 , 2 , 1 ] )
position_ids_0 = torch . arange ( self . seq_len , dtype = torch . int )
position_ids_1 = torch . zeros ( self . seq_len , dtype = torch . int )
position_ids_0 [ - 1 ] = position_ids_0 [ - 2 ]
position_ids_1 [ - 1 ] = 1
position_ids = torch . stack ( [ position_ids_0 , position_ids_1 ] ) . view ( 1 , 2 , - 1 )
return position_ids
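# Hedged note: chatglm-6b uses 2-D positions. Row 0 holds absolute positions
# (with the final [gMASK] slot repeating the previous value) and row 1 is the
# block position, 0 for the prompt and counting from 1 while generating, hence
# the [context_len, token_len + 1] pair during decode.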
def visual_embed ( self , input_ids ) :
return self . visual . embed ( input_ids )
def embedding ( self , input_ids ) :
if self . visual is not None and self . token_len == 0 :
input_embeds = self . visual_embed ( input_ids )
else :
input_embeds = self . embed ( input_ids )
return input_embeds
def forward ( self ,
input_ids : torch . Tensor ,
attention_mask : torch . Tensor ,
position_ids : torch . Tensor ,
past_key_values : Optional [ List [ torch . Tensor ] ] = None ,
cross_attention_states : Optional [ torch . Tensor ] = None ,
cross_attention_mask : Optional [ torch . Tensor ] = None ,
) :
hidden_states = input_ids # llm forward without embedding
presents = [ None for i in range ( self . num_hidden_layers ) ]
rotary_pos_emb = self . rotary ( position_ids )
for i in range ( self . num_hidden_layers ) :
if self . blocks [ i ] . cross_decoder and cross_attention_states is None :
continue
hidden_states , kv = self . blocks [ i ] ( hidden_states , rotary_pos_emb , attention_mask , past_key_values [ i ] )
presents [ i ] = kv
logits = self . lm ( hidden_states )
if not self . ppl :
logits = logits . reshape ( - 1 )
if all ( kv is not None for kv in presents ) and presents [ 0 ] . shape == presents [ - 1 ] . shape :
presents = torch . stack ( presents )
self . seq_len + = 1
self . token_len + = 1
return logits , presents
# some test functions
def build_prompt ( self , query ) :
# just for test
if ' Qwen2 ' in self . path or ' reader ' in self . path :
return f ' <|im_start|>user \n { query } <|im_end|> \n <|im_start|>assistant \n '
if ' Qwen ' in self . path :
return f ' \n <|im_start|>user \n { query } <|im_end|> \n <|im_start|>assistant \n '
if ' Baichuan2 ' in self . path :
return f ' <reserved_106> { query } <reserved_107> '
if ' internlm ' in self . path :
return f ' <|User|>: { query } <eoh> \n <|Bot|>: '
if ' TinyLlama ' in self . path :
return f ' <s><|system|> \n You are a friendly chatbot who always responds in the style of a pirate</s> \n <|user|> \n { query } </s> \n <|assistant|> \n '
if ' Yi ' in self . path :
return f ' <|im_start|> user \n { query } <|im_end|> \n <|im_start|> assistant \n '
if ' deepseek ' in self . path :
return f ' <|begin_of_sentence|>User: { query } \n \n Assistant: '
if ' Llama-3.1 ' in self . path :
return f ' <|start_header_id|>user<|end_header_id|> \n \n { query } <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n '
if ' Llama-3 ' in self . path :
return f ' <|begin_of_text|><|start_header_id|>user<|end_header_id|> \n \n { query } <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n '
if ' Llama-2 ' in self . path :
return f ' [INST] { query } [/INST] '
if ' chatglm2 ' in self . path :
return f ' [Round 1] \n \n 问: { query } \n \n 答: '
if ' chatglm3 ' in self . path or ' glm-4 ' in self . path :
return f ' <|user|> \n { query } \n <|assistant|> \n '
if ' chatglm ' in self . path :
return f ' { query } [gMASK]<sop> '
if ' phi-2 ' in self . path :
return f ' Instruct: { query } \n Output: '
if ' gemma-2 ' in self . path :
return f ' <bos><start_of_turn>user \n { query } <end_of_turn> \n <start_of_turn>model \n '
if ' OpenELM ' in self . path :
return f ' <s> { query } '
if ' SmolLM2 ' in self . path :
return f ' <|im_start|>system \n You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|> \n <|im_start|>user \n { query } <|im_end|> \n <|im_start|>assistant \n '
return query
def str_to_ids ( self , prompt ) :
if self . visual is not None :
return self . visual . str_to_ids ( prompt )
input_ids = self . tokenizer ( prompt , return_tensors = " pt " ) [ ' input_ids ' ]
return input_ids
def id_to_str ( self , token_id ) :
def contains_replacement ( text ) : return ' \uFFFD ' in text
def decode_id ( token_id ) :
return self . tokenizer . convert_tokens_to_string (
self . tokenizer . _convert_id_to_token ( int ( token_id ) ) )
def decode_ids ( token_ids ) :
return self . tokenizer . convert_tokens_to_string (
self . tokenizer . convert_ids_to_tokens ( token_ids ) )
word = decode_id ( int ( token_id ) )
# the SmolLM tokenizer can emit partial multi-byte (e.g. Chinese) characters; buffer token ids until they decode cleanly
if contains_replacement ( word ) :
self . decode_buffer . append ( token_id )
buffer_txt = decode_ids ( self . decode_buffer )
if not contains_replacement ( buffer_txt ) :
word = buffer_txt
self . decode_buffer . clear ( )
else :
word = ' '
return word
def response ( self , query ) :
# self.imitate_quant()
self . decode_buffer = [ ]
prompt = self . build_prompt ( query )
input_ids = self . str_to_ids ( prompt )
if self . visual is not None :
cross_attention_states = self . visual . cross_attention_states
cross_attention_mask = self . visual . cross_attention_mask
else :
cross_attention_states = None
cross_attention_mask = None
self . seq_len = input_ids . numel ( )
self . context_len = self . seq_len - 2
self . token_len = 0
past_key_values = [ None for i in range ( self . num_hidden_layers ) ]
token_id = input_ids
while self . token_len < self . max_length :
attention_mask = self . get_attention_mask ( )
position_ids = self . get_position_ids ( )
input_ids = self . embedding ( token_id )
logits , past_key_values = self . forward ( input_ids ,
attention_mask ,
position_ids ,
past_key_values ,
cross_attention_states ,
cross_attention_mask )
token_id = torch . argmax ( logits )
if token_id in self . stop_ids :
print ( " " , end = ' \n ' )
break
word = self . id_to_str ( token_id )
print ( word , end = " " , flush = True )
@spinner_run ( f ' export visual to ' )
def export_visual ( self ) :
if self . visual is None :
return
input_images = torch . randn ( ( 1 , 3 , self . visual . image_size , self . visual . image_size ) )
model = self . visual
onnx_model = f ' { self . onnx_path } /visual.onnx '
torch . onnx . export ( model , ( input_images ) ,
onnx_model ,
input_names = [ ' input_images ' ] ,
output_names = [ ' image_embeds ' ] ,
dynamic_axes = { " input_images " : {
0 : " size "
} } ,
do_constant_folding = True ,
verbose = False ,
opset_version = 15 )
return onnx_model
@spinner_run ( f ' export embedding to ' )
def export_embed ( self ) :
import ctypes
if hasattr ( self , ' word_embeddings ' ) :
# embedding model's embed
tensor_data = self . word_embeddings . weight . data . bfloat16 ( )
else :
tensor_data = self . embed . embed . weight . data . bfloat16 ( )
data_ptr = tensor_data . untyped_storage ( ) . data_ptr ( )
buffer = ( ctypes . c_byte * ( tensor_data . numel ( ) * 2 ) ) . from_address ( data_ptr )
embedding_file = f ' { self . dst_path } /embeddings_bf16.bin '
with open ( embedding_file , ' wb ' ) as f :
f . write ( buffer )
return embedding_file
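# Hedged note: embeddings_bf16.bin is a raw dump of the embedding table in
# bfloat16 (2 bytes per value, host byte order), laid out row-major as
# [vocab_size, hidden_size]; hidden_size is recorded in llm_config.json.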
@spinner_run ( f ' export config to ' )
def export_config ( self , mnn_config = False ) :
config_json = f ' { self . dst_path } /llm_config.json '
with open ( config_json , ' w ' , encoding = ' utf-8 ' ) as f :
json . dump ( self . llm_config , f , ensure_ascii = False , indent = 4 )
if not mnn_config :
return config_json
with open ( f ' { self . dst_path } /config.json ' , ' w ' , encoding = ' utf-8 ' ) as f :
config = {
" llm_model " : f " { self . dst_name } .mnn " ,
" llm_weight " : f " { self . dst_name } .mnn.weight " ,
" backend_type " : " cpu " ,
" thread_num " : 4 ,
" precision " : " low " ,
" memory " : " low "
}
json . dump ( config , f , ensure_ascii = False , indent = 4 )
return config_json
def imitate_quant ( self ) :
def quant_dequant ( linear , quant_bit = self . quant_bit , quant_block = self . quant_block ) :
weight = linear . weight . data
oc , ic = weight . shape
if quant_block == 0 :
block_size = ic
else :
block_size = quant_block
block_num = ic / / block_size
weight = weight . reshape ( oc , block_num , block_size )
max_val = torch . max ( weight , axis = - 1 , keepdims = True ) . values
min_val = torch . min ( weight , axis = - 1 , keepdims = True ) . values
offset = 1 << ( quant_bit - 1 )
clip_max = offset - 1
clip_min = - offset
scale = ( max_val - min_val ) / ( clip_max - clip_min )
q_weight = torch . round ( ( weight - min_val ) / scale ) + clip_min
q_weight = torch . clip ( q_weight , clip_min , clip_max )
dq_weight = ( q_weight - clip_min ) * scale + min_val
dq_weight = dq_weight . reshape ( oc , ic ) . float ( )
linear . weight . data = dq_weight
return linear
with torch . no_grad ( ) :
for i in range ( self . num_hidden_layers ) :
for name , child in self . blocks [ i ] . self_attn . named_children ( ) :
if isinstance ( child , torch . nn . Linear ) :
setattr ( self . blocks [ i ] . self_attn , name , quant_dequant ( child ) )
for name , child in self . blocks [ i ] . mlp . named_children ( ) :
if isinstance ( child , torch . nn . Linear ) :
setattr ( self . blocks [ i ] . mlp , name , quant_dequant ( child ) )
self . lm . lm = quant_dequant ( self . lm . lm )
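# Worked example of the block-wise asymmetric quant-dequant above (hedged):
# with quant_bit=4, offset=8 so codes live in [-8, 7]. For one block
# w = [0.0, 0.3, 0.6, 0.9]: min=0.0, scale=(0.9-0.0)/15=0.06,
# q = round((w-min)/scale) - 8 = [-8, -3, 2, 7], and dequantizing gives
# (q+8)*0.06 + 0.0 = [0.0, 0.3, 0.6, 0.9]. imitate_quant lets response() preview
# this rounding loss before the MNN conversion applies the real quantization.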
def unload_param ( self ) :
self . unloaded_ops = { }
def build_faker ( real , name ) :
faker = FakeLinear ( real . in_features , real . out_features , real . bias is not None , name )
self . unloaded_ops [ name ] = real
return faker
# replace linear with fakelinear to save export memory and time
with torch . no_grad ( ) :
for i in range ( self . num_hidden_layers ) :
# different kv cache shape in different layers
if isinstance ( self . num_attention_heads , list ) :
self . blocks [ i ] . self_attn . export_fused_attn = True
for name , child in self . blocks [ i ] . self_attn . named_children ( ) :
if isinstance ( child , torch . nn . Linear ) :
setattr ( self . blocks [ i ] . self_attn , name , build_faker ( child , f ' /layers. { i } /self_attn/ { name } /Linear ' ) )
for name , child in self . blocks [ i ] . mlp . named_children ( ) :
if isinstance ( child , torch . nn . Linear ) :
setattr ( self . blocks [ i ] . mlp , name , build_faker ( child , f ' /layers. { i } /mlp/ { name } /Linear ' ) )
self . lm . lm = build_faker ( self . lm . lm , f ' /lm/lm_head/Linear ' )
@spinner_run ( f ' export model weight to ' )
def onnx_load_param ( self , onnx_path ) :
return OnnxRebuilder ( onnx_path , self . unloaded_ops ) . rebuild ( )
@spinner_run ( f ' slim the graph of ' )
def apply_onnx_slim ( self , onnx_model ) : # renamed so the self.onnx_slim flag does not shadow this method
import onnxslim
model = onnxslim . slim ( onnx_model )
onnx . save ( model , onnx_model )
return onnx_model
@spinner_run ( f ' export onnx model to ' )
def export_onnx ( self ) :
# unload linear weight to save export memory
self . unload_param ( )
model = self
self . seq_len = 3
self . token_len = 0
input_ids = torch . arange ( 3 , dtype = torch . long )
attention_mask = self . get_attention_mask ( )
position_ids = self . get_position_ids ( )
onnx_model = f ' { self . onnx_path } / { self . dst_name } .onnx '
input_ids = self . embedding ( input_ids )
past_key_values = torch . zeros ( self . past_kv_shape )
# export to onnx
torch . onnx . export (
model , ( input_ids , attention_mask , position_ids , past_key_values ) ,
onnx_model ,
input_names = [
' input_ids ' , ' attention_mask ' , ' position_ids ' , ' past_key_values '
] ,
output_names = [ ' logits ' , ' presents ' ] ,
dynamic_axes = self . model_dynamic_axes ,
do_constant_folding = True ,
verbose = False ,
opset_version = 15 )
return onnx_model
def awq_quant ( self ) :
self . awq_quantizer = AwqQuantizer ( self )
self . awq_quantizer . quantize ( )
self . is_awq_quantized = True
def export ( self , export_type ) :
if self . awq :
self . awq_quant ( )
export_mnn = export_type == ' mnn '
# export tokenizer
self . export_tokenizer ( )
if export_mnn and self . tie_word_embeddings :
pass # with tie_word_embeddings, the mnn export needn't dump a separate embedding file
else :
self . export_embed ( )
if self . visual :
visual_onnx = self . export_visual ( )
#if self.onnx_slim:
#visual_onnx = self.apply_onnx_slim(visual_onnx)
if export_mnn :
MNNConveter ( visual_onnx , None , self ) . export ( quant_bit = self . visual . quant_bit )
# export graph to llm.onnx
onnx_model = self . export_onnx ( )
if self . onnx_slim :
self . apply_onnx_slim ( onnx_model )
if export_mnn :
# convert onnx to mnn and quant weight
MNNConveter ( onnx_model , self . unloaded_ops , self ) . export ( )
# delete onnx file
if os . path . exists ( onnx_model ) :
try :
os . remove ( onnx_model )
os . rmdir ( self . onnx_path )
except Exception as e :
print ( f " remove onnx error: { e } " )
else :
# export weight to llm.onnx.data
self . onnx_load_param ( onnx_model )
# export llm_config.json and config.json
self . export_config ( export_mnn )
@spinner_run ( f ' export tokenizer to ' )
def export_tokenizer ( self ) :
# load tokenizer file
tokenizer_model = os . path . join ( self . tokenizer_path , ' tokenizer.model ' )
ice_text_model = os . path . join ( self . tokenizer_path , ' ice_text.model ' )
try :
import sentencepiece as spm
if os . path . exists ( tokenizer_model ) :
self . sp_model = spm . SentencePieceProcessor ( tokenizer_model )
elif os . path . exists ( ice_text_model ) :
self . sp_model = spm . SentencePieceProcessor ( ice_text_model )
else :
self . sp_model = None
except :
self . sp_model = None
merge_file = os . path . join ( self . path , ' merges.txt ' )
if os . path . exists ( merge_file ) :
self . merge_txt = merge_file
else :
self . merge_txt = None
# TOKENIZER MAGIC NUMBER
MAGIC_NUMBER = 430
# TOKENIZER TYPE
SENTENCEPIECE = 0 ; TIKTOKEN = 1 ; BERT = 2 ; HUGGINGFACE = 3
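# Hedged sketch of the tokenizer.txt layout written below:
#   line 1: "<MAGIC_NUMBER> <type>"                 e.g. "430 1" for a tiktoken-style vocab
#   line 2: "<#special> <#stop_ids> <#prefix>"
#   line 3: the special, stop and prefix token ids, space separated
#   line 4: vocab size (plus merge count in the huggingface/BPE branch)
#   then:   one token per line, base64-encoded (with score and type appended in
#           the sentencepiece branch; plain text in the merges branch).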
def write_line ( fp , * args ) :
for arg in args :
for token in arg :
fp . write ( str ( token ) + ' ' )
fp . write ( ' \n ' )
def write_header ( fp , type , specials , prefix = [ ] ) :
fp . write ( f ' { MAGIC_NUMBER } { type } \n ' )
fp . write ( f ' { len ( specials ) } { len ( self . stop_ids ) } { len ( prefix ) } \n ' )
write_line ( fp , specials , self . stop_ids , prefix )
file_path = os . path . join ( self . dst_path , " tokenizer.txt " )
special_list = list ( self . tokenizer . added_tokens_decoder . keys ( ) )
if hasattr ( self . tokenizer , ' special_tokens ' ) :
for k , v in self . tokenizer . special_tokens . items ( ) :
special_list . append ( v )
if hasattr ( self . tokenizer , ' gmask_token_id ' ) :
special_list . append ( self . tokenizer . gmask_token_id )
vocab_list = [ ]
prefix_list = [ ]
if hasattr ( self . tokenizer , ' get_prefix_tokens ' ) :
prefix_list = self . tokenizer . get_prefix_tokens ( )
if len ( prefix_list ) == 0 :
test_txt = ' A '
ids = self . tokenizer . encode ( test_txt )
get_txt = self . tokenizer . decode ( ids [ - 1 ] )
if len ( ids ) > 1 and get_txt == test_txt :
prefix_list + = ids [ : - 1 ]
if self . sp_model is not None :
# sentencepiece
NORMAL = 1 ; UNKNOWN = 2 ; CONTROL = 3
USER_DEFINED = 4 ; UNUSED = 5 ; BYTE = 6
for i in range ( self . sp_model . GetPieceSize ( ) ) :
token = self . sp_model . IdToPiece ( i )
score = self . sp_model . GetScore ( i )
token_type = NORMAL
if self . sp_model . IsUnknown ( i ) :
token_type = UNKNOWN
elif self . sp_model . IsControl ( i ) :
token_type = CONTROL
elif self . sp_model . IsUnused ( i ) :
token_type = UNUSED
elif self . sp_model . IsByte ( i ) :
token_type = BYTE
if self . path == ' Chatglm_6b ' :
if ' <n> ' in token : token = ' \n '
if ' <|tab|> ' in token : token = ' \t '
if ' <|blank_ ' in token : token = ' ' * int ( token [ 8 : token . find ( ' |> ' ) ] )
if ' ▁ ' in token : token = token . replace ( ' ▁ ' , ' ' )
token_encode = base64 . b64encode ( token . encode ( " utf-8 " ) ) . decode ( " utf8 " )
vocab_list . append ( f ' { token_encode } { score } { token_type } \n ' )
with open ( file_path , " w " , encoding = " utf8 " ) as fp :
write_header ( fp , SENTENCEPIECE , special_list , prefix_list )
fp . write ( f ' { len ( vocab_list ) } \n ' )
for vocab in vocab_list :
fp . write ( vocab )
elif hasattr ( self . tokenizer , ' mergeable_ranks ' ) :
# tiktoken
vocab_list = [ ]
for k , v in self . tokenizer . mergeable_ranks . items ( ) :
line = base64 . b64encode ( k ) . decode ( " utf8 " ) + " \n "
vocab_list . append ( line )
if hasattr ( self . tokenizer , ' special_tokens ' ) :
for k , v in self . tokenizer . special_tokens . items ( ) :
line = base64 . b64encode ( k . encode ( " utf-8 " ) ) . decode ( " utf8 " ) + " \n "
vocab_list . append ( line )
if hasattr ( self . tokenizer , ' added_tokens_decoder ' ) :
for k , v in self . tokenizer . added_tokens_decoder . items ( ) :
line = base64 . b64encode ( v . __str__ ( ) . encode ( " utf-8 " ) ) . decode ( " utf8 " ) + " \n "
vocab_list . append ( line )
with open ( file_path , " w " , encoding = " utf8 " ) as fp :
write_header ( fp , TIKTOKEN , special_list , prefix_list )
fp . write ( f ' { len ( vocab_list ) } \n ' )
for vocab in vocab_list :
fp . write ( vocab )
elif self . merge_txt is not None :
# huggingface tokenizer
merge_list = [ ]
vocab = self . tokenizer . get_vocab ( )
special_list = list ( self . tokenizer . added_tokens_decoder . keys ( ) )
vocab_list = [ ' <unk> ' for i in range ( len ( vocab ) ) ]
# load vocab
for k , v in vocab . items ( ) :
vocab_list [ int ( v ) ] = k
# load merge
with open ( self . merge_txt , ' rt ' ) as merge :
for line in merge . readlines ( ) :
merge_list . append ( line )
# write to tokenizer.txt
with open ( file_path , " w " , encoding = " utf8 " ) as fp :
write_header ( fp , HUGGINGFACE , special_list )
fp . write ( f ' { len ( vocab_list ) } { len ( merge_list ) } \n ' )
for v in vocab_list :
fp . write ( v + ' \n ' )
for m in merge_list :
fp . write ( m )
else :
# tiktoken or bert
if ' bert ' in type ( self . tokenizer ) . __name__ . lower ( ) :
tokenizer_type = BERT
else :
tokenizer_type = TIKTOKEN
# bert tokenizer
def unicode_to_byte ( u : int ) :
if u > = 256 and u < = 288 :
return u - 256
if u > = 289 and u < = 322 :
return u - 162
if u == 323 :
return 173
if u == 65372 : # |
return 124
if u == 9601 : # _
return 95
return u
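# Hedged note: unicode_to_byte reverses the GPT-2 style byte-to-unicode vocab
# encoding, mapping remapped codepoints back to raw bytes, e.g. 'Ġ' (U+0120 =
# 288) -> 288 - 256 = 32, the ASCII space; printable codepoints pass through.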
vocab = self . tokenizer . get_vocab ( )
vocab_list = [ ' <unk> ' for i in range ( len ( vocab ) ) ]
for k , v in vocab . items ( ) :
try :
vocab_list [ int ( v ) ] = bytes ( [ unicode_to_byte ( ord ( c ) ) for c in k ] ) . decode ( ' utf-8 ' , errors = ' ignore ' )
except :
vocab_list [ int ( v ) ] = k
special_list = list ( self . tokenizer . added_tokens_decoder . keys ( ) )
with open ( file_path , " w " , encoding = " utf8 " ) as fp :
write_header ( fp , tokenizer_type , special_list )
fp . write ( f ' { len ( vocab_list ) } \n ' )
for v in vocab_list :
line = base64 . b64encode ( v . encode ( ' utf-8 ' ) ) . decode ( " utf8 " ) + " \n "
fp . write ( line )
return file_path
class EmbeddingExporter ( LlmExporter ) :
def __init__ ( self , args ) :
super ( ) . __init__ ( args )
self . dst_name = ' embedding '
def word_embed ( self , input_ids ) :
return self . word_embeddings ( input_ids . view ( 1 , - 1 ) )
def bge_forward ( self , inputs_embeds , position_ids , attention_mask ) :
# bert absolute position
inputs_embeds = inputs_embeds . reshape ( 1 , - 1 , self . hidden_size )
position_embeddings = self . position_embeddings ( position_ids )
embeddings = inputs_embeds + position_embeddings + self . token_type_embeddings
hidden_states = self . embedding_layernorm ( embeddings )
for i in range ( self . num_hidden_layers ) :
hidden_states = self . blocks [ i ] ( hidden_states , attention_mask ) [ 0 ]
sentence_embeddings = hidden_states [ : , 0 ]
sentence_embeddings = torch . nn . functional . normalize ( sentence_embeddings , p = 2 , dim = 1 )
return sentence_embeddings
def gte_forward ( self , inputs_embeds , position_ids , attention_mask ) :
# rope position
inputs_embeds = inputs_embeds . reshape ( 1 , - 1 , self . hidden_size )
freqs = position_ids . float ( ) . reshape ( - 1 , 1 ) * self . inv_freq
emb = torch . cat ( ( freqs , freqs ) , dim = - 1 )
rope_embeds = torch . stack ( [ emb . cos ( ) , emb . sin ( ) ] ) . unsqueeze ( - 2 ) . unsqueeze ( 1 )
attention_bias = 1 - attention_mask . float ( )
hidden_states = self . embedding_layernorm ( inputs_embeds + self . token_type_embeddings )
for i in range ( self . num_hidden_layers ) :
hidden_states = self . blocks [ i ] ( hidden_states , attention_bias , rope_embeds ) [ 0 ]
sentence_embeddings = hidden_states [ : , 0 ]
sentence_embeddings = torch . nn . functional . normalize ( sentence_embeddings , p = 2 , dim = 1 )
return sentence_embeddings
def forward ( self , inputs_embeds , position_ids , attention_mask ) :
if self . model_type == ' bert ' :
return self . bge_forward ( inputs_embeds , position_ids , attention_mask )
if self . model_type == ' new ' :
return self . gte_forward ( inputs_embeds , position_ids , attention_mask )
raise RuntimeError ( f ' Not support embedding model: { self . model_type } ! ' )
def response ( self , query ) :
self . eval ( )
input_ids = self . tokenizer ( query ) [ ' input_ids ' ]
self . seq_len = len ( input_ids )
input_ids = torch . tensor ( input_ids )
position_ids = self . get_position_ids ( )
attention_mask = self . get_attention_mask ( )
inputs_embeds = self . word_embed ( input_ids )
res = self . forward ( inputs_embeds , position_ids , attention_mask )
# print(res)
return res
@spinner_run ( f ' load pretrained model ' )
def load_model ( self , model_path ) :
self . tokenizer = AutoTokenizer . from_pretrained ( model_path , trust_remote_code = True )
self . config = AutoConfig . from_pretrained ( model_path )
self . config . _attn_implementation = ' eager '
self . model = AutoModel . from_config ( self . config )
transformer = self . model . encoder
self . model_type = self . config . model_type
self . lm_ = self . model . pooler
self . embed_ = self . model . embeddings
self . word_embeddings = self . embed_ . word_embeddings
self . token_type_embeddings = self . embed_ . token_type_embeddings . weight . data [ 0 ]
self . embedding_layernorm = self . embed_ . LayerNorm
if hasattr ( self . embed_ , ' position_embeddings ' ) :
self . position_embeddings = self . embed_ . position_embeddings
self . hidden_size = self . word_embeddings . weight . shape [ - 1 ]
self . blocks = transformer . layer
if self . model_type == ' new ' :
self . inv_freq = self . embed_ . rotary_emb . inv_freq
# some wrapper
self . stop_ids = [ ]
self . num_hidden_layers = len ( self . blocks )
self . embed = self . embed_
self . lm = self . lm_
# some config for export
self . model_dynamic_axes = {
" input_ids " : { 1 : " seq_len " } ,
" position_ids " : { 1 : " seq_len " } ,
" attention_mask " : { 3 : " seq_len " }
}
self . attention_mask_type = ' int '
self . llm_config = {
' hidden_size ' : self . hidden_size ,
' layer_nums ' : self . num_hidden_layers ,
' attention_mask ' : self . attention_mask_type ,
' key_value_shape ' : [ ] ,
" prompt_template " : self . build_prompt ( ' %s ' ) ,
' is_visual ' : False
}
return model_path
@spinner_run ( f ' export onnx model to ' )
def export_onnx ( self ) :
model = self . eval ( )
self . seq_len = 3
input_ids = torch . arange ( 3 , dtype = torch . long )
position_ids = self . get_position_ids ( )
attention_mask = self . get_attention_mask ( )
inputs_embeds = self . word_embed ( input_ids )
onnx_model = f ' { self . onnx_path } / { self . dst_name } .onnx '
torch . onnx . export (
model , ( inputs_embeds , position_ids , attention_mask ) ,
onnx_model ,
input_names = [
' input_ids ' ,
' position_ids ' ,
' attention_mask '
] ,
output_names = [ ' sentence_embeddings ' ] ,
dynamic_axes = self . model_dynamic_axes ,
do_constant_folding = True ,
opset_version = 15 )
return onnx_model
def export ( self , export_type ) :
export_mnn = ' mnn ' in export_type
self . export_tokenizer ( )
self . export_config ( export_mnn )
self . export_embed ( )
onnx_model = self . export_onnx ( )
if self . onnx_slim :
self . apply_onnx_slim ( onnx_model )
if export_mnn :
MNNConveter ( onnx_model , None , self ) . export ( )
def build_prompt ( self , query ) :
if self . model_type == ' bert ' :
return f ' [CLS] { query } [SEP] '
if self . model_type == ' new ' :
return f ' <s> { query } </s> '
def get_position_ids ( self ) - > torch . Tensor :
return torch . arange ( self . seq_len , dtype = torch . long ) . unsqueeze ( 0 )
def get_attention_mask ( self ) - > torch . Tensor :
return torch . ones ( [ 1 , 1 , 1 , self . seq_len ] , dtype = torch . long )
def export ( path ,
type = None ,
lora_path = None ,
dst_path = ' ./model ' ,
export = ' onnx ' ,
onnx_slim = False ,
quant_bit = 4 ,
quant_block = 128 ,
lm_quant_bit = None ) :
args = argparse . Namespace ( )
for k , v in {
' path ' : path ,
' type ' : type ,
' lora_path ' : lora_path ,
' dst_path ' : dst_path ,
' export ' : export ,
' onnx_slim ' : onnx_slim ,
' quant_bit ' : quant_bit ,
' quant_block ' : quant_block ,
' lm_quant_bit ' : lm_quant_bit ,
# assumed defaults for options that init_from_args reads but this helper does
# not expose; they mirror the argparse defaults in main() below
' tokenizer_path ' : None ,
' ppl ' : False ,
' awq ' : False ,
' sym ' : False ,
' mnnconvert ' : ' ../../../build/MNNConvert '
} . items ( ) :
setattr ( args , k , v )
if ' bge ' in path :
llm_exporter = EmbeddingExporter ( args )
else :
llm_exporter = LlmExporter ( args )
# export
llm_exporter . export ( export )
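# Hedged usage sketch for the programmatic entry point above; the checkpoint
# path is a placeholder and this assumes the file is importable as `llmexport`:
#
#   from llmexport import export
#   export('/path/to/Qwen2-1.5B-Instruct', dst_path='./model',
#          export='mnn', quant_bit=4, quant_block=128)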
def main ( ) :
parser = argparse . ArgumentParser ( description = ' llm_exporter ' , formatter_class = argparse . RawTextHelpFormatter )
parser . add_argument ( ' --path ' , type = str , required = True ,
help = ' path(`str` or `os.PathLike`): \n Can be either: '
' \n \t - A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO] '
' \n \t - A path to a *directory* clone from repo like `../chatglm-6b`. ' )
parser . add_argument ( ' --type ' , type = str , default = None ,
help = ' type(`str`, *optional*): '
' \n \t The pretrain llm model type. '
)
parser . add_argument ( ' --tokenizer_path ' , type = str , default = None , help = ' tokenizer path, default is `None`, meaning the `--path` value is used. ' )
parser . add_argument ( ' --lora_path ' , type = str , default = None , help = ' lora path, default is `None`, meaning no lora is applied. ' )
parser . add_argument ( ' --dst_path ' , type = str , default = ' ./model ' , help = ' export the onnx/mnn model to this path, default is `./model`. ' )
parser . add_argument ( ' --verbose ' , action = ' store_true ' , help = ' Whether or not to print verbose logs. ' )
parser . add_argument ( ' --test ' , type = str , help = ' test model inference with query `TEST`. ' )
parser . add_argument ( ' --export ' , type = str , default = None , help = ' export model to an onnx/mnn model. ' )
parser . add_argument ( ' --onnx_slim ' , action = ' store_true ' , help = ' Whether or not to use onnx-slim. ' )
parser . add_argument ( ' --quant_bit ' , type = int , default = 4 , help = ' mnn quant bit, 4 or 8, default is 4. ' )
parser . add_argument ( ' --quant_block ' , type = int , default = 128 , help = ' mnn quant block size, default is 128; 0 means channel-wise. ' )
parser . add_argument ( ' --lm_quant_bit ' , type = int , default = None , help = ' mnn lm_head quant bit, 4 or 8, default is `quant_bit`. ' )
parser . add_argument ( ' --mnnconvert ' , type = str , default = ' ../../../build/MNNConvert ' , help = ' local mnnconvert path; if invalid, pymnn is used. ' )
parser . add_argument ( ' --ppl ' , action = ' store_true ' , help = ' Whether or not to get all logits of input tokens. ' )
parser . add_argument ( ' --awq ' , action = ' store_true ' , help = ' Whether or not to use awq quant. ' )
parser . add_argument ( ' --sym ' , action = ' store_true ' , help = ' Whether or not to use symmetric quant (without zero point), default is False. ' )
args = parser . parse_args ( )
model_path = args . path
model_type = args . type
if ' gte ' in model_path or ' bge ' in model_path :
llm_exporter = EmbeddingExporter ( args )
else :
llm_exporter = LlmExporter ( args )
# some actions
if args . test is not None :
llm_exporter . response ( args . test )
if args . export is not None :
llm_exporter . export ( args . export )
if __name__ == ' __main__ ' :
main ( )