# MNN/transformers/llm/export/llmexport.py

import os
import sys
import math
import copy
import json
import time
import base64
import logging
import warnings
import argparse
import functools
from typing import Optional, Tuple
from yaspin import yaspin
import onnx
import torch
import numpy as np
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
RESET = "\033[0m"
GREEN = "\033[32;1m"
YELLOW = "\033[33;4m"
EXPORT_LOG = '.export.log'
# ignore warning info
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
def spinner_run(text='Processing...'):
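    # The wrapped function should return a short string (here, usually a file path);
    # it is appended to the spinner line together with the elapsed time, e.g.:
    #
    #     @spinner_run('export config to ')
    #     def export_config(self):
    #         ...
    #         return config_json
    #
    # (illustrative sketch; the real usages appear in the exporter classes below)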
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with yaspin(text=text, color="cyan") as spinner:
start = time.time()
try:
result = func(*args, **kwargs)
except Exception as e:
spinner.fail("💥 Failed")
print(e)
exit(1)
end = time.time()
during = f'[{end-start:05.2f} s]'.replace('[0', '[ ')
padding = ' ' * (64 - len(spinner.text) - len(result))
spinner.text = f'{spinner.text}{YELLOW}{result}{RESET}{padding}{GREEN}{during}{RESET}'
spinner.ok("✅ Done")
return result
return wrapper
return decorator
class ModelMapper:
def __init__(self):
self.attrs = []
self.mapper = dict()
self.regist_models()
def get_map(self, config):
model_type = config.model_type
if model_type == 'chatglm':
if hasattr(config, 'vocab_size') and config.vocab_size == 130528:
model_type = 'chatglm'
else:
model_type = 'chatglm2'
if model_type in self.mapper:
return model_type, self.mapper[model_type]
return model_type, self.default_map
def regist(self, model_type, model_map):
assert('config' in model_map and
'decoder' in model_map and
'attention' in model_map)
self.mapper[model_type] = model_map
def regist_models(self):
self.defualt_map()
# regist models
self.regist_llama()
self.regist_qwen()
self.regist_glm()
self.regist_glm2()
self.regist_phi()
self.regist_gemma2()
def regist_llama(self):
llama_map = self.default_map
self.regist('llama', llama_map)
self.regist('qwen2', llama_map)
self.regist('internlm', llama_map)
baichuan_map = copy.deepcopy(self.default_map)
baichuan_map[self.attention_key] = {
'qkv_proj': 'W_pack',
'o_proj': 'o_proj'
}
self.regist('baichuan', baichuan_map)
def regist_qwen(self):
qwen_map = {
'config': {
'hidden_size': 'hidden_size',
'num_attention_heads': 'num_attention_heads',
'num_hidden_layers': 'num_hidden_layers',
'rope_theta': 'rotary_emb_base',
},
'model': {
'lm_': 'lm_head',
'embed_': 'transformer.wte',
'blocks_': 'transformer.h',
'final_layernorm_': 'transformer.ln_f',
'visual': 'transformer.visual'
},
'decoder': {
'self_attn': 'attn',
'mlp': 'mlp',
'input_layernorm': 'ln_1',
'post_attention_layernorm': 'ln_2'
},
'attention': {
'qkv_proj': 'c_attn',
'o_proj': 'c_proj'
}
}
self.regist('qwen', qwen_map)
def regist_glm(self):
glm_map = {
'config': {
'hidden_size': 'hidden_size',
'num_attention_heads': 'num_attention_heads',
'num_hidden_layers': 'num_layers'
},
'model': {
'lm_': 'lm_head',
'embed_': 'transformer.word_embeddings',
'blocks_': 'transformer.layers',
'final_layernorm_': 'transformer.final_layernorm',
},
'decoder': {
'self_attn': 'attention',
'mlp': 'mlp',
'input_layernorm': 'input_layernorm',
'post_attention_layernorm': 'post_attention_layernorm'
},
'attention': {
'qkv_proj': 'query_key_value',
'o_proj': 'dense'
}
}
self.regist('chatglm', glm_map)
def regist_glm2(self):
glm2_map = {
'config': {
'hidden_size': 'hidden_size',
'num_attention_heads': 'num_attention_heads',
'num_key_value_heads': 'multi_query_group_num',
'num_hidden_layers': 'num_layers',
},
'model': {
'lm_': 'transformer.output_layer',
'embed_': 'transformer.embedding.word_embeddings',
'blocks_': 'transformer.encoder.layers',
'final_layernorm_': 'transformer.encoder.final_layernorm',
},
'decoder': {
'self_attn': 'self_attention',
'mlp': 'mlp',
'input_layernorm': 'input_layernorm',
'post_attention_layernorm': 'post_attention_layernorm'
},
'attention': {
'qkv_proj': 'query_key_value',
'o_proj': 'dense'
}
}
self.regist('chatglm2', glm2_map)
def regist_phi(self):
phi_map = {
'config': {
'hidden_size': 'n_embd',
'num_attention_heads': 'n_head',
'num_hidden_layers': 'n_layer',
'rotary_dim': 'rotary_dim'
},
'model': {
'lm_': 'lm_head.linear',
'embed_': 'transformer.embd.wte',
'blocks_': 'transformer.h',
'final_layernorm_': 'lm_head.ln',
},
'decoder': {
'self_attn': 'mixer',
'mlp': 'mlp',
'input_layernorm': 'ln',
},
'attention': {
'qkv_proj': 'Wqkv',
'o_proj': 'out_proj'
}
}
self.regist('phi-msft', phi_map)
def regist_gemma2(self):
gemma2_config = copy.deepcopy(self.default_config)
gemma2_config['head_dim'] = 'head_dim'
gemma2_decoder = copy.deepcopy(self.default_decoder)
gemma2_decoder['pre_feedforward_layernorm'] = 'pre_feedforward_layernorm'
gemma2_decoder['post_feedforward_layernorm'] = 'post_feedforward_layernorm'
gemma2_map = {
'config': gemma2_config,
            'model': self.default_model,
'decoder': gemma2_decoder,
'attention': self.default_attention
}
self.regist('gemma2', gemma2_map)
def defualt_map(self):
# default map is `LlamaForCausalLM`
self.config_key = 'config'
self.model_key = 'model'
self.decoder_key = 'decoder'
self.attention_key = 'attention'
self.default_config = {
'hidden_size': 'hidden_size',
'num_attention_heads': 'num_attention_heads',
'num_hidden_layers': 'num_hidden_layers',
'num_key_value_heads': 'num_key_value_heads',
'rope_theta': 'rope_theta'
}
        self.default_model = {
'lm_': 'lm_head',
'embed_': 'model.embed_tokens',
'blocks_': 'model.layers',
'final_layernorm_': 'model.norm',
}
self.default_decoder = {
'self_attn': 'self_attn',
'mlp': 'mlp',
'input_layernorm': 'input_layernorm',
'post_attention_layernorm': 'post_attention_layernorm'
}
self.default_attention = {
'q_proj': 'q_proj',
'k_proj': 'k_proj',
'v_proj': 'v_proj',
'o_proj': 'o_proj'
}
self.default_map = {
'config': self.default_config,
            'model': self.default_model,
'decoder': self.default_decoder,
'attention': self.default_attention
}
@staticmethod
def do_map(dst, src, map):
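        # Resolve a dotted attribute path on `src` and copy it onto `dst`; e.g. with
        # map = {'embed_': 'transformer.wte'} this performs
        # dst.embed_ = src.transformer.wte (or None if any attribute is missing).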
for dst_attr, src_attr in map.items():
attributes = src_attr.split('.')
obj = src
for attr in attributes:
if hasattr(obj, attr):
obj = getattr(obj, attr)
else:
obj = None
break
setattr(dst, dst_attr, obj)
# Export class
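# LlmExporterOp/FakeLinear replace every nn.Linear with a parameter-free custom op
# during ONNX export, so the traced graph stays small and export is fast; the real
# weights are written back afterwards by OnnxRebuilder (as MatMul/Add with external
# initializers) or by MNNConveter (as quantized Convolution weights).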
class LlmExporterOp(torch.autograd.Function):
@staticmethod
def symbolic(g, input, in_features, out_features, has_bias, name):
args = [input]
# These become the operator attributes.
kwargs = {
"in_features_i": in_features,
"out_features_i": out_features,
"has_bias_i": has_bias,
"name_s": name
}
from torch.onnx.symbolic_helper import _get_tensor_sizes
out_sizes = _get_tensor_sizes(input)[:-1] + [out_features]
output_type = input.type().with_sizes(out_sizes)
return g.op("LlmExporter::FakeLinear", input, **kwargs).setType(output_type)
@staticmethod
def forward(ctx, input, in_features, out_features, has_bias, name):
out_shape = list(input.shape)[:-1] + [out_features]
return input.new_zeros(out_shape)
class FakeLinear(torch.nn.Module):
def __init__(self, in_features, out_features, has_bias, name):
super(FakeLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.has_bias = has_bias
self.name = name
def forward(self, x):
return LlmExporterOp.apply(x, self.in_features, self.out_features, self.has_bias, self.name)
class OnnxRebuilder:
def __init__(self, onnx_path, weight_ops):
self.weight_ops = weight_ops
self.onnx_model = onnx.load(onnx_path)
self.dst_path = onnx_path
self.onnx_weight_path = f'{onnx_path}.data'
self.onnx_weight_offset = 0
def make_external(self, name, data, shape):
# write to external weight
length = self.onnx_weight.write(data.tobytes())
location = os.path.basename(self.onnx_weight_path)
offset = self.onnx_weight_offset
self.onnx_weight_offset += length
tensor = onnx.TensorProto()
tensor.name = name
tensor.data_type = onnx.TensorProto.FLOAT
tensor.dims.extend(shape)
# external info
tensor.data_location = onnx.TensorProto.EXTERNAL
for k, v in { "location": location, "offset": offset, "length": length }.items():
entry = tensor.external_data.add()
entry.key = k
entry.value = str(v)
self.onnx_model.graph.initializer.append(tensor)
def build_weight(self, name, has_bias, ic, oc):
assert(name in self.weight_ops)
linear = self.weight_ops[name]
assert(linear.in_features == ic and
linear.out_features == oc and
(linear.bias is not None) == has_bias)
weight_name, bias_name = f'{name}_weight', f'{name}_bias'
weight = linear.weight.data.transpose(1, 0).flatten().numpy()
self.make_external(weight_name, weight, [ic, oc])
if has_bias:
bias = linear.bias.data.flatten().numpy()
self.make_external(bias_name, bias, [oc])
return weight_name, bias_name
def rebuild(self):
from onnx import helper
new_nodes = []
self.onnx_weight = open(self.onnx_weight_path, 'wb')
for node in self.onnx_model.graph.node:
if node.op_type == 'FakeLinear':
attributes = {a.name: a for a in node.attribute}
name = attributes.get('name').s.decode('utf-8')
has_bias = attributes.get('has_bias').i
ic = attributes.get('in_features').i
oc = attributes.get('out_features').i
weight, bias = self.build_weight(name, has_bias, ic, oc)
if has_bias:
# fakelinear -> matmul + add
middle_tensor = f'{name}_matmul'
new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], [middle_tensor], name))
new_nodes.append(helper.make_node('Add', [middle_tensor, bias], node.output, name))
else:
# fakelinear -> matmul
new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], node.output, name))
else:
new_nodes.append(node)
self.onnx_weight.close()
del self.onnx_model.graph.node[:]
self.onnx_model.graph.node.extend(new_nodes)
onnx.save(self.onnx_model, self.dst_path)
return self.onnx_weight_path
class MNNConveter:
def __init__(self, onnx_path, weight_ops, config):
self.weight_ops = weight_ops
self.quant_block = config.quant_block
self.quant_bit = config.quant_bit
self.lm_quant_bit = config.lm_quant_bit
self.mnn_weight_offset = 0
self.onnx_model_path = onnx_path
self.mnn_model_path = onnx_path.replace('.onnx', '.mnn')
self.mnn_weight_path = f'{self.mnn_model_path}.weight'
if os.path.exists(config.mnnconvert):
self.mnnconvert = config.mnnconvert
else:
self.mnnconvert = None
def convert(self, convert_args):
sfd = os.dup(1)
log_fp = open(EXPORT_LOG, "a")
log_fd = log_fp.fileno()
        # mnnconvert ... > .export.log
os.dup2(log_fd, 1)
try:
sys.argv = convert_args
sys.argc = len(convert_args)
if self.mnnconvert is None:
from MNN.tools import mnnconvert
mnnconvert.main()
else:
convert_args[0] = self.mnnconvert
cmd = ' '.join(convert_args)
message = os.popen(cmd).read()
print(message)
sys.argv = []
finally:
os.dup2(sfd, 1)
            os.close(sfd)
            log_fp.close()
@spinner_run(f'convert onnx model to ')
def onnx2mnn(self, onnx_path, mnn_path, args = []):
convert_args = [
'',
'-f',
'ONNX',
'--modelFile',
str(onnx_path),
'--MNNModel',
str(mnn_path),
'--transformerFuse',
'--allowCustomOp'
]
convert_args += args
self.convert(convert_args)
return mnn_path
def mnn2json(self, mnn_path, json_path):
convert_args = [
'',
'-f',
'MNN',
'--modelFile',
str(mnn_path),
'--JsonFile',
str(json_path)
]
self.convert(convert_args)
return json_path
def json2mnn(self, json_path, mnn_path):
convert_args = [
'',
'-f',
'JSON',
'--modelFile',
str(json_path),
'--MNNModel',
str(mnn_path)
]
self.convert(convert_args)
return mnn_path
def export(self):
if self.weight_ops is None:
quant_args = [
'--weightQuantBits',
str(self.quant_bit),
'--weightQuantBlock',
str(self.quant_block)
]
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args)
else:
mnn_json = f'{self.mnn_model_path}.json'
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path)
self.mnn2json(self.mnn_model_path, mnn_json)
self.rebuild(mnn_json)
self.json2mnn(mnn_json, self.mnn_model_path)
@spinner_run(f'quant model weight to ')
def rebuild(self, json_path):
self.mnn_weight = open(self.mnn_weight_path, 'wb')
mnn_graph = json.load(open(json_path, 'rt'))
new_ops = []
for op in mnn_graph['oplists']:
if op['type'] == 'Extra':
new_ops += self.rebuild_op(op, mnn_graph)
else:
new_ops.append(op)
mnn_graph['oplists'] = new_ops
with open(json_path, 'w', encoding='utf-8') as file:
json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
return self.mnn_weight_path
def quant(self, weight, quant_bit, quant_block):
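        # Asymmetric per-block quantization (sketch of the math implemented below):
        # for each block of `quant_block` input channels,
        #   scale = (max - min) / (2**quant_bit - 1)
        #   q     = clip(round((w - min) / scale) + clip_min, clip_min, clip_max)
        # e.g. with quant_bit=4 the stored values cover [-8, 7] and two 4-bit values
        # are packed into one byte; `alpha` stores a (min, scale) pair per block.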
weight = weight.numpy()
oc, ic = weight.shape
if quant_block == 0:
block_size = ic
else:
block_size = quant_block
block_num = ic // block_size
weight = weight.reshape(oc, block_num, block_size)
max_val = np.max(weight, axis=-1, keepdims=True)
min_val = np.min(weight, axis=-1, keepdims=True)
offset = 1 << (quant_bit - 1)
clip_max = offset - 1
clip_min = -offset
scale = (max_val - min_val) / (clip_max - clip_min)
q_weight = np.round((weight - min_val) / scale) + clip_min
q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
q_weight = q_weight.reshape(-1, 2)
if quant_bit == 4:
q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
return q_weight, alpha, clip_min
def write_npy(self, data):
return self.mnn_weight.write(data.tobytes())
def write_header(self, ic, oc, quant_bit):
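        # Weight-blob header layout as written below: one byte holding the number of
        # dims (2), the [oc, ic] shape as int16 (int32 when either dim exceeds 65535),
        # and an int8 table (its first entry is the table length, or 0 when it already
        # has 256 entries) mapping stored indices back to signed quantized values.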
dim_num = self.mnn_weight.write(b'\x02')
shape_dtype = np.int16
if oc > 65535 or ic > 65535:
shape_dtype = np.int32
dim_length = self.write_npy(np.array([oc, ic]).astype(shape_dtype))
offset = 1 << (quant_bit - 1)
weight_map = [i for i in range(-offset, offset)]
if len(weight_map) == 256:
weight_map.insert(0, 0)
else:
weight_map.insert(0, len(weight_map))
map_length = self.write_npy(np.array(weight_map, dtype=np.int8))
header_length = dim_num + dim_length + map_length
return header_length, shape_dtype == np.int32
def build_weight(self, linear, quant_bit, quant_block):
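        # Serializes one linear layer into the .weight file: header plus packed
        # quantized weight, then the per-block (min, scale) pairs, then the float bias
        # if present. The returned `external` list ([offset, weight_bytes, alpha_bytes,
        # bias_bytes, 0]) is what the rebuilt MNN Convolution op uses to locate its
        # external data.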
ic, oc = linear.in_features, linear.out_features
q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block)
header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
weight_len = self.write_npy(q_weight) + header_len
alpha_len = self.write_npy(alpha)
if linear.bias is not None:
bias = linear.bias.data.flatten().numpy()
bias_length = self.write_npy(bias)
else:
bias_length = 0
# bias = np.zeros([oc], dtype=np.float32)
# bias_length = self.write_npy(bias)
external = [self.mnn_weight_offset, weight_len, alpha_len, bias_length, 0]
self.mnn_weight_offset += (weight_len + alpha_len + bias_length)
return external, q_min, shape_int32
def build_tensor(self, graph, tensor_name):
tensor_idx = [len(graph['tensorName'])]
graph['tensorName'].append(tensor_name)
return tensor_idx
def rebuild_op(self, op, graph):
attrs = op['main']['attr']
for attr in attrs:
if attr['key'] == 'name':
name = attr['s']
elif attr['key'] == "in_features":
ic = attr["i"]
elif attr['key'] == "out_features":
oc = attr["i"]
elif attr['key'] == "has_bias":
has_bias = attr["i"]
linear = self.weight_ops[name]
assert(linear.in_features == ic and
linear.out_features == oc and
(linear.bias is not None) == has_bias)
quant_bit = self.lm_quant_bit if 'lm_head' in name else self.quant_bit
external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block)
origin_input = op['inputIndexes']
origin_output = op['outputIndexes']
# build new tensor
pre_reshape_name = f'{name}/pre_reshape'
pre_convert_name = f'{name}/pre_convert'
conv_name = name
post_convert_name = f'{name}/post_convert'
post_reshape_name = f'{name}/post_reshape'
pre_reshape_output = self.build_tensor(graph, pre_reshape_name)
pre_convert_output = self.build_tensor(graph, pre_convert_name)
conv_output = self.build_tensor(graph, conv_name)
post_convert_output = self.build_tensor(graph, post_convert_name)
# [batch, seq, hidden_size_i] -[Linear] -> [batch, seq, hidden_size_o]
# [1, seq, hidden_size_i] ->[Reshape]-> [seq, hidden_size_i, 1, 1]
# -[Convert]-[Convolution]-[Convert]-> [Reshape] -> [1, seq, hidden_size_o]
pre_reshape = {
"name": pre_reshape_name,
"type": "Reshape",
"inputIndexes": origin_input,
"outputIndexes": pre_reshape_output,
"main_type": "Reshape",
"main": {
"dims": [-1, ic, 1, 1],
"dimType": "NCHW"
},
"defaultDimentionFormat": "NHWC"
}
pre_convert = {
"name": pre_convert_name,
"inputIndexes": pre_reshape_output,
"outputIndexes": pre_convert_output,
"type": "ConvertTensor",
"main_type": "TensorConvertInfo",
"main": {
"source": "NCHW",
"dest": "NC4HW4"
},
"defaultDimentionFormat": "NHWC"
}
conv_op = {
"name": conv_name,
"inputIndexes": pre_convert_output,
"outputIndexes": conv_output,
"type": "Convolution",
"main_type": "Convolution2D",
"main": {
'common': {
'dilateX': 1, 'dilateY': 1, 'strideX': 1, 'strideY': 1,
'kernelX': 1, 'kernelY': 1, 'padX': 0, 'padY': 0, 'group': 1,
'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
'relu6': False, 'inputCount': ic, 'hasOutputShape': False
},
"quanParameter": {
"quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
"useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
"type": 1, "aMax": 0, "aMin": q_min, "readType": oc * (ic // self.quant_block), "weightSize": 0
},
"external": external
},
"defaultDimentionFormat": "NHWC"
}
post_convert = {
"name": post_convert_name,
"inputIndexes": conv_output,
"outputIndexes": post_convert_output,
"type": "ConvertTensor",
"main_type": "TensorConvertInfo",
"main": {
"source": "NC4HW4",
"dest": "NCHW"
},
"defaultDimentionFormat": "NHWC"
}
post_reshape = {
"name": post_reshape_name,
"type": "Reshape",
"inputIndexes": post_convert_output,
"outputIndexes": origin_output,
"main_type": "Reshape",
"main": {
"dims": [1, -1, oc],
"dimType": "NCHW"
},
"defaultDimentionFormat": "NHWC"
}
return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape]
# some wrapper classes for export
class Embedding(torch.nn.Module):
def __init__(self, embed, config):
super().__init__()
self.hidden_size = config.hidden_size
self.embed = embed
if config.model_type == 'gemma2':
normalizer = torch.tensor(self.hidden_size**0.5)
self.embed.weight.data *= normalizer
def forward(self, input_ids):
inputs_embeds = self.embed(input_ids).view(-1, 1, self.hidden_size)
return inputs_embeds
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
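    # Grouped-query attention helper: each KV head is shared by n_rep query heads,
    # so the KV tensor is expanded and reshaped from num_key_value_heads to
    # num_key_value_heads * n_rep heads before the attention matmuls.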
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Attention(torch.nn.Module):
def __init__(self, attn, config):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.rotary = config.rotary
ModelMapper.do_map(self, attn, config.model_map['attention'])
if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
# split qkv linear to q, k, v
split_sizes = [self.hidden_size] * 3
if self.qkv_proj.weight.shape[0] != self.hidden_size * 3:
# M/GQA
qkv_hidden_size = self.qkv_proj.weight.shape[0]
kv_hidden_size = (qkv_hidden_size - self.hidden_size) // 2
split_sizes = [self.hidden_size, kv_hidden_size, kv_hidden_size]
self.q_proj = torch.nn.Linear(self.hidden_size, split_sizes[0])
self.k_proj = torch.nn.Linear(self.hidden_size, split_sizes[1])
self.v_proj = torch.nn.Linear(self.hidden_size, split_sizes[2])
if config.model_type == 'chatglm':
# chatglm-6b
qkv_weight = self.qkv_proj.weight.data.view(self.num_heads, 3, self.head_dim, self.hidden_size)
self.q_proj.weight.data = qkv_weight[:, 0, :, :].reshape(self.hidden_size, self.hidden_size)
self.k_proj.weight.data = qkv_weight[:, 1, :, :].reshape(self.hidden_size, self.hidden_size)
self.v_proj.weight.data = qkv_weight[:, 2, :, :].reshape(self.hidden_size, self.hidden_size)
qkv_bias = self.qkv_proj.bias.data.view(self.num_heads, 3, self.head_dim)
self.q_proj.bias.data = qkv_bias[:, 0, :].reshape(self.hidden_size)
self.k_proj.bias.data = qkv_bias[:, 1, :].reshape(self.hidden_size)
self.v_proj.bias.data = qkv_bias[:, 2, :].reshape(self.hidden_size)
else:
# other
qw, kw, vw = torch.split(self.qkv_proj.weight, split_sizes)
self.q_proj.weight.data = qw
self.k_proj.weight.data = kw
self.v_proj.weight.data = vw
if self.qkv_proj.bias is not None:
qb, kb, vb = torch.split(self.qkv_proj.bias, split_sizes)
self.q_proj.bias.data = qb
self.k_proj.bias.data = kb
self.v_proj.bias.data = vb
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
kv_seq_len = key_states.shape[1]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[1]
# rope
cos, sin = rotary_pos_emb[0], rotary_pos_emb[1]
query_states = self.rotary.apply_rotary_pos(query_states, cos, sin)
key_states = self.rotary.apply_rotary_pos(key_states, cos, sin)
# kv cache
if past_key_value is not None:
past_key, past_value = past_key_value[0], past_key_value[1]
key_states = torch.cat((past_key, key_states), dim=1)
value_states = torch.cat((past_value, value_states), dim=1)
past_key_value = torch.stack((key_states, value_states))
query_states = query_states.transpose(1, 2)
key_states = key_states.permute([0, 2, 3, 1])
value_states = value_states.transpose(1, 2)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
#------- attention ----------
# query_states @ key_states
attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
# attention_mask
if attention_mask.dtype in (torch.bool, torch.int32):
# chatglm
attn_weights.masked_fill_(attention_mask, -10000.0)
else:
attn_weights = attn_weights + attention_mask
# upcast softmax to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# attn_weights @ value_states
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, past_key_value
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
class Rotary(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.rope_theta = config.rope_theta
self.rotary_dim = config.head_dim
self.model_type = config.model_type
if hasattr(config, 'rotary_dim'):
self.rotary_dim = config.rotary_dim
if self.model_type == 'chatglm':
self.rotary_dim = config.head_dim // 2
def forward(self, position_ids):
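        # Standard RoPE frequencies: theta_i = rope_theta ** (-2i / rotary_dim); the
        # outer product with the positions gives the rotation angles, and the stacked
        # cos/sin pair is what apply_rotary_pos consumes (chatglm2 keeps the half-size
        # layout, the other model types duplicate it so rotate_half can be applied).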
theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim))
position_ids = position_ids.float().reshape(-1, 1)
idx_theta = position_ids * theta
rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
if self.model_type != 'chatglm2':
rotary_pos_emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
rotary_pos_emb = rotary_pos_emb.unsqueeze(2).unsqueeze(1)
return rotary_pos_emb
def apply_rotary_pos(self, x, cos, sin):
if self.model_type == 'chatglm':
return self.chatglm_rotary_pos(x, cos, sin)
if self.model_type == 'chatglm2':
return self.chatglm2_rotary_pos(x, cos, sin)
if self.model_type == 'phi-msft':
return self.phi_rotary_pos(x, cos, sin)
return self.llama_rotary_pos(x, cos, sin)
def llama_rotary_pos(self, x, cos, sin):
x = (x * cos) + (rotate_half(x) * sin)
return x
def phi_rotary_pos(self, x, cos, sin):
x, x_pass = x[..., :self.rotary_dim], x[..., self.rotary_dim:]
x = (x * cos) + (rotate_half(x) * sin)
return torch.cat((x, x_pass), dim=-1)
def chatglm2_rotary_pos(self, x, cos, sin):
x, x_pass = x[..., :self.rotary_dim], x[..., self.rotary_dim:]
b, s, n, h = x.shape
xshaped = x.view(b, s, n, h//2, 2)
x = torch.concat(
[
xshaped[..., 0] * cos - xshaped[..., 1] * sin,
xshaped[..., 1] * cos + xshaped[..., 0] * sin,
],
-1,
)
return torch.cat((x, x_pass), dim=-1)
def chatglm_rotary_pos(self, x, cos, sin):
seq = x.shape[1]
x1, x2 = x[..., :self.rotary_dim], x[..., self.rotary_dim:]
cos1, sin1 = cos[:, :seq, ...], sin[:, :seq, ...]
cos2, sin2 = cos[:, seq:, ...], sin[:, seq:, ...]
x1 = (x1 * cos1) + (rotate_half(x1) * sin1)
x2 = (x2 * cos2) + (rotate_half(x2) * sin2)
return torch.cat((x1, x2), dim=-1)
class Decoder(torch.nn.Module):
def __init__(self, decoder, config):
super().__init__()
ModelMapper.do_map(self, decoder, config.model_map['decoder'])
self.hidden_size = config.hidden_size
self.self_attn = Attention(self.self_attn, config)
# chatglm
self.alpha = (2 * config.num_hidden_layers) ** 0.5 if config.model_type == 'chatglm' else 1.0
def forward(
self,
hidden_states: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = hidden_states.view(1, -1, self.hidden_size)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
norm_hidden_states = hidden_states
# Self Attention
hidden_states, present_key_value = self.self_attn(
hidden_states=hidden_states,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
past_key_value=past_key_value,
)
# Fully Connected
if not hasattr(self, 'post_attention_layernorm'):
# phi
feed_forward_hidden_states = self.mlp(norm_hidden_states)
hidden_states = hidden_states + feed_forward_hidden_states + residual
elif self.alpha != 1.0:
# chatglm-6b
hidden_states = norm_hidden_states * self.alpha + hidden_states
mlp_input = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_input * self.alpha + mlp_output
elif hasattr(self, 'pre_feedforward_layernorm'):
# gemma2
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
else:
# general
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, present_key_value
class Lm(torch.nn.Module):
def __init__(self, lm_, final_layernorm_, config):
super().__init__()
self.final_layernorm = final_layernorm_
self.lm = lm_
self.hidden_size = config.hidden_size
def forward(self, hidden_states):
hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
hidden_states = self.final_layernorm(hidden_states)
m_logits = self.lm(hidden_states)
return m_logits
class LlmExporter(torch.nn.Module):
'''
Base class for all llm model export. Inherits from [`torch.nn.Module`].
'''
def __init__(self, args):
super().__init__()
self.init_from_args(args)
self.load_model(args.path)
def init_from_args(self, args):
self.max_length = 1024
self.stop_ids = []
self.visual = None
self.dst_name = 'llm'
# load config from args
self.path = args.path
self.dst_path = args.dst_path
self.lora_path = args.lora_path
self.skip_slim = args.skip_slim
self.quant_bit = args.quant_bit
self.quant_block = args.quant_block
self.mnnconvert = args.mnnconvert
if args.lm_quant_bit is not None:
self.lm_quant_bit = args.lm_quant_bit
else:
self.lm_quant_bit = self.quant_bit
# init export dst dir
if not os.path.exists(self.dst_path):
os.makedirs(self.dst_path)
def load_pretrained(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
try:
self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
except:
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().eval()
self.config = self.model.config
if self.lora_path is not None:
from peft import PeftModel
adapter = PeftModel.from_pretrained(self.model, model_id=self.lora_path)
self.model = adapter.merge_and_unload(progressbar=True)
@spinner_run(f'load pretrained model ')
def load_model(self, model_path):
self.load_pretrained(model_path)
self.attention_mask_type = 'float'
# load tokenizer info
self.stop_ids.append(self.tokenizer.eos_token_id)
if hasattr(self.tokenizer, 'im_end_id'):
self.stop_ids.append(self.tokenizer.im_end_id)
eot_id = self.tokenizer.encode('<|eot_id|>')
if len(eot_id) == 1:
self.stop_ids.append(eot_id[0])
if hasattr(self.model, 'generation_config'):
eos_token_id = self.model.generation_config.eos_token_id
from collections.abc import Iterable
if isinstance(eos_token_id, int):
self.stop_ids.append(eos_token_id)
elif isinstance(eos_token_id, Iterable):
for id in eos_token_id:
self.stop_ids.append(id)
self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None]
self.stop_ids = list(set(self.stop_ids))
model_mapper = ModelMapper()
self.model_type, self.model_map = model_mapper.get_map(self.config)
# print(self.model)
# print(self.model_type, self.model_map)
# load config info
ModelMapper.do_map(self, self.config, self.model_map['config'])
if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if not hasattr(self, 'rope_theta') or self.rope_theta is None:
self.rope_theta = 10000.0
if not hasattr(self, 'head_dim') or self.head_dim is None:
self.head_dim = self.hidden_size // self.num_attention_heads
# some export info
self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim]
self.block_dynamic_axes = {
"inputs_embeds" : { 0: "seq_len" },
"attention_mask" : { 2: "seq_len", 3: "seq_len" },
"position_ids" : { 0: "seq_len" },
"past_key_values" : { 1: "history_len" }
}
self.model_dynamic_axes = {
"input_ids" : { 0: "seq_len" },
"attention_mask" : { 2: "seq_len", 3: "seq_len" },
"position_ids" : { 0: "seq_len" },
"past_key_values" : { 2: "history_len" }
}
self.llm_config = {
'hidden_size' : self.hidden_size,
'layer_nums' : self.num_hidden_layers,
'attention_mask': self.attention_mask_type,
'key_value_shape': self.past_kv_shape[1:],
"prompt_template": self.build_prompt('%s'),
'is_visual': False
}
# load modules
ModelMapper.do_map(self, self.model, self.model_map['model'])
# rebuild modules
if self.embed_.weight is self.lm_.weight:
embed_copy = copy.deepcopy(self.embed_)
self.embed = Embedding(embed_copy, self)
else:
self.embed = Embedding(self.embed_, self)
# Rotary
self.rotary = Rotary(self)
self.blocks = []
for block in self.blocks_.children():
self.blocks.append(Decoder(block, self))
self.lm = Lm(self.lm_, self.final_layernorm_, self)
# visual model
if self.visual is not None:
self.image_start_id = self.config.visual['image_start_id']
self.image_size = self.config.visual['image_size']
self.llm_config['is_visual'] = True
self.llm_config['img_size'] = self.image_size
self.llm_config['imgpad_len'] = 256
self.llm_config['img_start'] = self.tokenizer.img_start_id
self.llm_config['img_end'] = self.tokenizer.img_end_id
self.llm_config['img_pad'] = self.tokenizer.img_pad_id
return model_path
def get_attention_mask(self) -> torch.Tensor:
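        # Prefill (token_len == 0) uses a causal lower-triangular additive mask with
        # large negative values above the diagonal; during decoding the single query
        # token attends to the whole cache, so the mask is all zeros of shape
        # [1, 1, 1, seq_len].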
if self.model_type == 'chatglm':
return self.chatglm_attention_mask()
if self.token_len:
return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
def get_position_ids(self) -> torch.Tensor:
if self.model_type == 'chatglm':
return self.chatglm_position_ids()
if self.token_len:
return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
def chatglm_attention_mask(self):
if self.token_len:
return torch.zeros([1]).bool().reshape([1, 1, 1, 1])
attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool)
for i in range(self.seq_len - 1):
attention_mask[i][-1] = True
attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])
return attention_mask
def chatglm_position_ids(self):
if self.token_len:
return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1])
position_ids_0 = torch.arange(self.seq_len, dtype=torch.long)
position_ids_1 = torch.zeros(self.seq_len, dtype=torch.long)
position_ids_0[-1] = position_ids_0[-2]
position_ids_1[-1] = 1
position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
return position_ids
def visual_embed(self, input_ids):
if not torch.any(input_ids == self.image_start_id):
return self.embed(input_ids)
bos_pos = torch.where(input_ids == self.image_start_id)
eos_pos = torch.where(input_ids == self.image_start_id + 1)
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1 : b - 1].tolist()
image = image[ : image.index(self.image_start_id + 2)]
images.append(bytes(image).decode('utf-8'))
images = self.visual.encode(images)
hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size)
for idx, (i, a, b) in enumerate(img_pos):
hidden_states[i][a + 1 : b] = images[idx]
return hidden_states.view(-1, 1, self.hidden_size)
def embedding(self, input_ids):
if self.visual is not None and self.token_len == 0:
input_embeds = self.visual_embed(input_ids)
else:
input_embeds = self.embed(input_ids)
return input_embeds
def forward(self, input_ids, attention_mask, position_ids, past_key_values):
hidden_states = input_ids # llm forward without embedding
presents = []
rotary_pos_emb = self.rotary(position_ids)
for i in range(self.num_hidden_layers):
hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
presents.append(kv)
logits = self.lm(hidden_states).reshape(-1)
presents = torch.stack(presents)
self.seq_len += 1
self.token_len += 1
return logits, presents
# some test functions
def build_prompt(self, query):
# just for test
if 'Qwen2' in self.path:
return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
if 'Qwen' in self.path:
return f'\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
if 'Baichuan2' in self.path:
return f'<reserved_106>{query}<reserved_107>'
if 'internlm' in self.path:
return f'<|User|>:{query}<eoh>\n<|Bot|>:'
if 'TinyLlama' in self.path:
return f'<s><|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{query}</s>\n<|assistant|>\n'
if 'Yi' in self.path:
return f'<|im_start|> user\n{query}<|im_end|>\n<|im_start|> assistant\n'
if 'deepseek' in self.path:
return f'<|begin_of_sentence|>User: {query}\n\nAssistant:'
if 'Llama-3.1' in self.path:
return f'<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
if 'Llama-3' in self.path:
return f'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
if 'Llama-2' in self.path:
return f'[INST]{query}[/INST]'
if 'chatglm2' in self.path:
return f'[Round 1]\n\n问:{query}\n\n答:'
if 'chatglm3' in self.path or 'glm-4' in self.path:
return f'<|user|>\n{query}\n<|assistant|>\n'
if 'chatglm' in self.path:
return f'{query}[gMASK]<sop>'
if 'phi-2' in self.path:
return f'Instruct: {query}\nOutput:'
if 'gemma-2' in self.path:
return f'<bos><start_of_turn>user\n{query}<end_of_turn>\n<start_of_turn>model\n'
return query
def str_to_ids(self, prompt):
input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
return input_ids
def id_to_str(self, token_id):
word = self.tokenizer._convert_id_to_token(int(token_id))
word = self.tokenizer.convert_tokens_to_string([word])
return word
def response(self, query):
self.imitate_quant()
prompt = self.build_prompt(query)
input_ids = self.str_to_ids(prompt)
# print(f'prompt = {prompt}, ids = {input_ids}')
self.seq_len = input_ids.numel()
self.context_len = self.seq_len - 2
self.token_len = 0
past_key_values = [None for i in range(self.num_hidden_layers)]
token_id = input_ids
while self.token_len < self.max_length:
attention_mask = self.get_attention_mask()
position_ids = self.get_position_ids()
input_ids = self.embed(token_id)
logits, past_key_values = self.forward(input_ids, attention_mask, position_ids, past_key_values)
token_id = torch.argmax(logits)
if token_id in self.stop_ids:
print("", end='\n')
break
word = self.id_to_str(token_id)
print(word, end="", flush=True)
def export_visual(self):
if self.visual is None:
return
input_images = torch.randn((1, 3, self.image_size, self.image_size))
model = self.visual
onnx_model = f'{self.dst_path}/visual.onnx'
torch.onnx.export(model, (input_images),
onnx_model,
input_names=['input_images'],
output_names=['image_embeds'],
dynamic_axes={"input_images": {
0: "size"
}},
do_constant_folding=True,
opset_version=15)
        if not self.skip_slim:
            self.onnx_slim(onnx_model)
        return onnx_model
@spinner_run(f'export embedding to ')
def export_embed(self):
import ctypes
if hasattr(self, 'word_embeddings'):
# embedding model's embed
tensor_data = self.word_embeddings.weight.data.bfloat16()
else:
tensor_data = self.embed.embed.weight.data.bfloat16()
data_ptr = tensor_data.untyped_storage().data_ptr()
buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
embedding_file = f'{self.dst_path}/embeddings_bf16.bin'
with open(embedding_file, 'wb') as f:
f.write(buffer)
return embedding_file
@spinner_run(f'export config to ')
def export_config(self, mnn_config = False):
config_json = f'{self.dst_path}/llm_config.json'
with open(config_json, 'w', encoding='utf-8') as f:
json.dump(self.llm_config, f, ensure_ascii=False, indent=4)
if not mnn_config:
return config_json
with open(f'{self.dst_path}/config.json', 'w', encoding='utf-8') as f:
config = {
"llm_model": f"{self.dst_name}.mnn",
"llm_weight": f"{self.dst_name}.mnn.weight",
"backend_type": "cpu",
"thread_num": 4,
"precision": "low",
"memory": "low"
}
json.dump(config, f, ensure_ascii=False, indent=4)
return config_json
def quant(self, weight, quant_bit, quant_block):
weight = weight.numpy()
oc, ic = weight.shape
if quant_block == 0:
block_size = ic
else:
block_size = quant_block
block_num = ic // block_size
weight = weight.reshape(oc, block_num, block_size)
max_val = np.max(weight, axis=-1, keepdims=True)
min_val = np.min(weight, axis=-1, keepdims=True)
offset = 1 << (quant_bit - 1)
clip_max = offset - 1
clip_min = -offset
scale = (max_val - min_val) / (clip_max - clip_min)
q_weight = np.round((weight - min_val) / scale) + clip_min
q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
q_weight = q_weight.reshape(-1, 2)
if quant_bit == 4:
q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
return q_weight, alpha, clip_min
def imitate_quant(self):
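        # Fake-quantization for the --test path: weights are quantized and immediately
        # dequantized in place, so the float inference below roughly reflects the
        # accuracy of the exported quantized model.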
def quant_dequant(linear, quant_bit = self.quant_bit, quant_block = self.quant_block):
weight = linear.weight.data
oc, ic = weight.shape
if quant_block == 0:
block_size = ic
else:
block_size = quant_block
block_num = ic // block_size
weight = weight.reshape(oc, block_num, block_size)
max_val = torch.max(weight, axis=-1, keepdims=True).values
min_val = torch.min(weight, axis=-1, keepdims=True).values
offset = 1 << (quant_bit - 1)
clip_max = offset - 1
clip_min = -offset
scale = (max_val - min_val) / (clip_max - clip_min)
q_weight = torch.round((weight - min_val) / scale) + clip_min
q_weight = torch.clip(q_weight, clip_min, clip_max)
dq_weight = (q_weight - clip_min) * scale + min_val
dq_weight = dq_weight.reshape(oc, ic).float()
linear.weight.data = dq_weight
return linear
with torch.no_grad():
for i in range(self.num_hidden_layers):
for name, child in self.blocks[i].self_attn.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].self_attn, name, quant_dequant(child))
for name, child in self.blocks[i].mlp.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].mlp, name, quant_dequant(child))
self.lm.lm = quant_dequant(self.lm.lm)
def unload_param(self):
self.unloaded_ops = {}
def build_faker(real, name):
faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name)
self.unloaded_ops[name] = real
return faker
# replace linear with fakelinear to save export memory and time
with torch.no_grad():
for i in range(self.num_hidden_layers):
for name, child in self.blocks[i].self_attn.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear'))
for name, child in self.blocks[i].mlp.named_children():
if isinstance(child, torch.nn.Linear):
setattr(self.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear'))
self.lm.lm = build_faker(self.lm.lm, f'/lm/lm_head/Linear')
@spinner_run(f'export model weight to ')
def onnx_load_param(self, onnx_path):
return OnnxRebuilder(onnx_path, self.unloaded_ops).rebuild()
@spinner_run(f'slim the graph of ')
def onnx_slim(self, onnx_model):
import onnxslim
model = onnxslim.slim(onnx_model)
onnx.save(model, onnx_model)
return onnx_model
@spinner_run(f'export onnx model to ')
def export_onnx(self):
# unload linear weight to save export memory
self.unload_param()
model = self
self.seq_len = 3
self.token_len = 0
input_ids = torch.arange(3, dtype=torch.long)
attention_mask = self.get_attention_mask()
position_ids = self.get_position_ids()
past_key_values = torch.zeros(self.past_kv_shape)
onnx_model = f'{self.dst_path}/{self.dst_name}.onnx'
input_ids = self.embedding(input_ids)
# export to onnx
torch.onnx.export(
model, (input_ids, attention_mask, position_ids, past_key_values),
onnx_model,
input_names=[
'input_ids', 'attention_mask', 'position_ids', 'past_key_values'
],
output_names=['logits', 'presents'],
dynamic_axes=self.model_dynamic_axes,
do_constant_folding=True,
opset_version=15)
return onnx_model
def export(self, export_type):
export_mnn = export_type == 'mnn'
# export tokenizer
self.export_tokenizer()
self.export_config(export_mnn)
self.export_embed()
if self.visual:
self.export_visual()
# export graph to llm.onnx
onnx_model = self.export_onnx()
if not self.skip_slim:
self.onnx_slim(onnx_model)
if export_mnn:
# convert onnx to mnn and quant weight
MNNConveter(onnx_model, self.unloaded_ops, self).export()
else:
# export weight to llm.onnx.data
self.onnx_load_param(onnx_model)
@spinner_run(f'export tokenizer to ')
def export_tokenizer(self):
# load tokenizer file
tokenizer_model = os.path.join(self.path, 'tokenizer.model')
ice_text_model = os.path.join(self.path, 'ice_text.model')
try:
import sentencepiece as spm
if os.path.exists(tokenizer_model):
self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
elif os.path.exists(ice_text_model):
self.sp_model = spm.SentencePieceProcessor(ice_text_model)
else:
self.sp_model = None
except:
self.sp_model = None
merge_file = os.path.join(self.path, 'merges.txt')
if os.path.exists(merge_file):
self.merge_txt = merge_file
else:
self.merge_txt = None
# TOKENIZER MAGIC NUMBER
MAGIC_NUMBER = 430
# TOKENIZER TYPE
        SENTENCEPIECE = 0; TIKTOKEN = 1; BERT = 2; HUGGINGFACE = 3
def write_line(fp, *args):
for arg in args:
for token in arg:
fp.write(str(token) + ' ')
fp.write('\n')
        def write_header(fp, type, specials, prefix = []):
            fp.write(f'{MAGIC_NUMBER} {type}\n')
            fp.write(f'{len(specials)} {len(self.stop_ids)} {len(prefix)}\n')
            write_line(fp, specials, self.stop_ids, prefix)
file_path = os.path.join(self.dst_path, "tokenizer.txt")
special_list = list(self.tokenizer.added_tokens_decoder.keys())
if hasattr(self.tokenizer, 'special_tokens'):
for k, v in self.tokenizer.special_tokens.items():
special_list.append(v)
if hasattr(self.tokenizer, 'gmask_token_id'):
special_list.append(self.tokenizer.gmask_token_id)
vocab_list = []
prefix_list = []
if hasattr(self.tokenizer, 'get_prefix_tokens'):
prefix_list = self.tokenizer.get_prefix_tokens()
if self.sp_model is not None:
            # sentencepiece
NORMAL = 1; UNKNOWN = 2; CONTROL = 3
USER_DEFINED = 4; UNUSED = 5; BYTE = 6
for i in range(self.sp_model.GetPieceSize()):
token = self.sp_model.IdToPiece(i)
score = self.sp_model.GetScore(i)
token_type = NORMAL
if self.sp_model.IsUnknown(i):
token_type = UNKNOWN
elif self.sp_model.IsControl(i):
token_type = CONTROL
elif self.sp_model.IsUnused(i):
token_type = UNUSED
elif self.sp_model.IsByte(i):
token_type = BYTE
if self.path == 'Chatglm_6b':
if '<n>' in token: token = '\n'
if '<|tab|>' in token: token = '\t'
if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')])
                    if '▁' in token: token = token.replace('▁', ' ')
token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
vocab_list.append(f'{token_encode} {score} {token_type}\n')
with open(file_path, "w", encoding="utf8") as fp:
write_header(fp, SENTENCEPIECE, special_list, prefix_list)
fp.write(f'{len(vocab_list)}\n')
for vocab in vocab_list:
fp.write(vocab)
elif hasattr(self.tokenizer, 'mergeable_ranks'):
            # tiktoken
vocab_list = []
for k, v in self.tokenizer.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + "\n"
vocab_list.append(line)
if hasattr(self.tokenizer, 'special_tokens'):
for k, v in self.tokenizer.special_tokens.items():
line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n"
vocab_list.append(line)
if hasattr(self.tokenizer, 'added_tokens_decoder'):
for k, v in self.tokenizer.added_tokens_decoder.items():
line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n"
vocab_list.append(line)
with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, TIKTOKEN, special_list, prefix_list)
fp.write(f'{len(vocab_list)}\n')
for vocab in vocab_list:
fp.write(vocab)
elif self.merge_txt is not None:
# huggingface tokenizer
merge_list = []
vocab = self.tokenizer.get_vocab()
special_list = list(self.tokenizer.added_tokens_decoder.keys())
vocab_list = ['<unk>' for i in range(len(vocab))]
# load vocab
for k, v in vocab.items():
vocab_list[int(v)] = k
# load merge
with open(self.merge_txt, 'rt') as merge:
for line in merge.readlines():
merge_list.append(line)
# write to tokenizer.txt
with open(file_path, "w", encoding="utf8") as fp:
write_header(fp, HUGGINGFACE, special_list)
fp.write(f'{len(vocab_list)} {len(merge_list)}\n')
for v in vocab_list:
fp.write(v + '\n')
for m in merge_list:
fp.write(m)
else:
# tiktoken or bert
if 'bert' in type(self.tokenizer).__name__.lower():
tokenizer_type = BERT
else:
                tokenizer_type = TIKTOKEN
# bert tokenizer
def unicode_to_byte(u: int):
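                # Appears to undo a GPT-2 style byte-to-unicode vocab mapping: code
                # points 256..288 and 289..322 are shifted back to the raw byte values
                # of the non-printable ranges, and a few special cases (fullwidth bar,
                # lower-one-eighth block) map back to '|' and '_'.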
if u >= 256 and u <= 288:
return u - 256
if u >= 289 and u <= 322:
return u - 162
if u == 323:
return 173
if u == 65372: # |
return 124
if u == 9601: # _
return 95
return u
vocab = self.tokenizer.get_vocab()
vocab_list = ['<unk>' for i in range(len(vocab))]
for k, v in vocab.items():
try:
vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]).decode('utf-8', errors='ignore')
except:
vocab_list[int(v)] = k
special_list = list(self.tokenizer.added_tokens_decoder.keys())
with open(file_path, "w", encoding="utf8") as fp:
write_header(fp, tokenizer_type, special_list)
fp.write(f'{len(vocab_list)}\n')
for v in vocab_list:
line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n"
fp.write(line)
return file_path
class EmbeddingExporter(LlmExporter):
def __init__(self, args):
super().__init__(args)
self.dst_name = 'embedding'
def word_embed(self, input_ids):
return self.word_embeddings(input_ids.view(1, -1))
def bge_forward(self, inputs_embeds, position_ids, attention_mask):
# bert absolute position
inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
position_embeddings = self.position_embeddings(position_ids)
embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings
hidden_states = self.embedding_layernorm(embeddings)
for i in range(self.num_hidden_layers):
hidden_states = self.blocks[i](hidden_states, attention_mask)[0]
sentence_embeddings = hidden_states[:, 0]
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings
def gte_forward(self, inputs_embeds, position_ids, attention_mask):
# rope position
inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
freqs = position_ids.float().reshape(-1, 1) * self.inv_freq
emb = torch.cat((freqs, freqs), dim=-1)
rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1)
attention_bias = 1 - attention_mask.float()
hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings)
for i in range(self.num_hidden_layers):
hidden_states = self.blocks[i](hidden_states, attention_bias, rope_embeds)[0]
sentence_embeddings = hidden_states[:, 0]
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings
def forward(self, inputs_embeds, position_ids, attention_mask):
if self.model_type == 'bert':
return self.bge_forward(inputs_embeds, position_ids, attention_mask)
if self.model_type == 'new':
return self.gte_forward(inputs_embeds, position_ids, attention_mask)
raise RuntimeError(f'Not support embedding model: {self.model_type}!')
def response(self, query):
self.eval()
input_ids = self.tokenizer(query)['input_ids']
self.seq_len = len(input_ids)
input_ids = torch.tensor(input_ids)
position_ids = self.get_position_ids()
attention_mask = self.get_attention_mask()
inputs_embeds = self.word_embed(input_ids)
res = self.forward(inputs_embeds, position_ids, attention_mask)
print(res)
return res
@spinner_run(f'load pretrained model ')
def load_model(self, model_path):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().eval()
self.config = self.model.config
transformer = self.model.encoder
self.model_type = self.config.model_type
self.lm_ = self.model.pooler
self.embed_ = self.model.embeddings
self.word_embeddings = self.embed_.word_embeddings
self.token_type_embeddings = self.embed_.token_type_embeddings.weight.data[0]
self.embedding_layernorm = self.embed_.LayerNorm
if hasattr(self.embed_, 'position_embeddings'):
self.position_embeddings = self.embed_.position_embeddings
self.hidden_size = self.word_embeddings.weight.shape[-1]
self.blocks = transformer.layer
if self.model_type == 'new':
self.inv_freq = self.embed_.rotary_emb.inv_freq
# some wrapper
self.stop_ids = []
self.num_hidden_layers = len(self.blocks)
self.embed = self.embed_
self.lm = self.lm_
# some config for export
self.model_dynamic_axes = {
"input_ids" : { 1: "seq_len" },
"position_ids" : { 1: "seq_len" },
"attention_mask" : { 3: "seq_len" }
}
self.attention_mask_type = 'int'
self.llm_config = {
'hidden_size' : self.hidden_size,
'layer_nums' : self.num_hidden_layers,
'attention_mask': self.attention_mask_type,
'key_value_shape': [],
"prompt_template": self.build_prompt('%s'),
'is_visual': False
}
return model_path
@spinner_run(f'export onnx model to ')
def export_onnx(self):
model = self.eval()
self.seq_len = 3
input_ids = torch.arange(3, dtype=torch.long)
position_ids = self.get_position_ids()
attention_mask = self.get_attention_mask()
inputs_embeds = self.word_embed(input_ids)
onnx_model = f'{self.dst_path}/{self.dst_name}.onnx'
torch.onnx.export(
model, (inputs_embeds, position_ids, attention_mask),
onnx_model,
input_names=[
'input_ids',
'position_ids',
'attention_mask'
],
output_names=['sentence_embeddings'],
dynamic_axes=self.model_dynamic_axes,
do_constant_folding=True,
opset_version=15)
return onnx_model
def export(self, export_type):
export_mnn = 'mnn' in export_type
self.export_tokenizer()
self.export_config(export_mnn)
self.export_embed()
onnx_model = self.export_onnx()
if not self.skip_slim:
self.onnx_slim(onnx_model)
if export_mnn:
MNNConveter(onnx_model, None, self).export()
def build_prompt(self, query):
if self.model_type == 'bert':
return f'[CLS]{query}[SEP]'
if self.model_type == 'new':
return f'<s> {query}</s>'
def get_position_ids(self) -> torch.Tensor:
return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
def get_attention_mask(self) -> torch.Tensor:
return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)
def export(path,
type = None,
lora_path = None,
dst_path = './model',
export = 'onnx',
skip_slim = False,
quant_bit = 4,
quant_block = 128,
lm_quant_bit = None):
args = argparse.Namespace()
for k, v in {
'path': path,
'type': type,
'lora_path': lora_path,
'dst_path': dst_path,
'export': export,
'skip_slim': skip_slim,
'quant_bit': quant_bit,
'quant_block': quant_block,
'lm_quant_bit': lm_quant_bit
}.items():
setattr(args, k, v)
    if 'gte' in path or 'bge' in path:
llm_exporter = EmbeddingExporter(args)
else:
llm_exporter = LlmExporter(args)
# export
llm_exporter.export(export)
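# Illustrative programmatic use (the model directory below is a placeholder, not a
# path shipped with this repo):
#
#     from llmexport import export
#     export(path='/path/to/Qwen2-1.5B-Instruct', export='mnn',
#            quant_bit=4, quant_block=128)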
def main():
parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('--path', type=str, required=True,
help='path(`str` or `os.PathLike`):\nCan be either:'
'\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]'
'\n\t- A path to a *directory* clone from repo like `../chatglm-6b`.')
parser.add_argument('--type', type=str, default=None,
help='type(`str`, *optional*):'
'\n\tThe pretrain llm model type.'
)
    parser.add_argument('--lora_path', type=str, default=None, help='lora path, default is `None`, meaning lora is not applied.')
    parser.add_argument('--dst_path', type=str, default='./model', help='export onnx/mnn model to path, default is `./model`.')
parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
    parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, default is 128; 0 means channel-wise.')
parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
    parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path; if invalid, fall back to pymnn.')
args = parser.parse_args()
model_path = args.path
model_type = args.type
if 'gte' in model_path or 'bge' in model_path:
llm_exporter = EmbeddingExporter(args)
else:
llm_exporter = LlmExporter(args)
# some actions
if args.test is not None:
llm_exporter.response(args.test)
if args.export is not None:
llm_exporter.export(args.export)
if __name__ == '__main__':
main()
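# Example invocations (the model directory is a placeholder):
#   python llmexport.py --path /path/to/Qwen2-1.5B-Instruct --export mnn --quant_bit 4 --quant_block 128
#   python llmexport.py --path /path/to/Qwen2-1.5B-Instruct --test "Hello"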