import os
import gc
import sys
import math
import copy
import json
import time
import base64
import logging
import inspect
import warnings
import argparse
import functools
import traceback
from collections import defaultdict
from typing import Optional, Tuple, List, Union, Dict

from tqdm import tqdm
from yaspin import yaspin

import onnx
import torch
import numpy as np
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer

RESET = "\033[0m"
GREEN = "\033[32;1m"
YELLOW = "\033[33;4m"
EXPORT_LOG = '.export.log'

# ignore warning info
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

def spinner_run(text='Processing...', hide=False):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with yaspin(text=text, color="cyan") as spinner:
                start = time.time()
                try:
                    if hide: spinner.hide()
                    result = func(*args, **kwargs)
                    if hide: spinner.show()
                except Exception as e:
                    spinner.fail("💥 Failed")
                    traceback.print_exc()
                    exit(1)
                end = time.time()
                during = f'[{end-start:05.2f} s]'.replace('[0', '[ ')
                padding = ' ' * (64 - len(spinner.text) - len(result))
                spinner.text = f'{spinner.text}{YELLOW}{result}{RESET}{padding}{GREEN}{during}{RESET}'
                spinner.ok("✅ Done")
            return result
        return wrapper
    return decorator

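# Illustrative usage sketch of spinner_run: the decorated function should return a
# short string (e.g. an output path), which is appended to the spinner text, as in
#
#   @spinner_run('export model to ')
#   def export_model(dst_path):        # hypothetical example function
#       ...
#       return dst_path
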
class ModelMapper:
    def __init__(self):
        self.attrs = []
        self.mapper = dict()
        self.regist_models()

    def get_map(self, config):
        model_type = config.model_type
        if model_type == 'chatglm':
            if hasattr(config, 'vocab_size') and config.vocab_size == 130528:
                model_type = 'chatglm'
            else:
                model_type = 'chatglm2'
        if model_type in self.mapper:
            return model_type, self.mapper[model_type]
        return model_type, self.default_map

    def regist(self, model_type, model_map):
        assert('config' in model_map and
               'decoder' in model_map and
               'attention' in model_map)
        self.mapper[model_type] = model_map

    def regist_models(self):
        self.defualt_map()
        # regist models
        self.regist_llama()
        self.regist_mllama()
        self.regist_qwen()
        self.regist_glm()
        self.regist_glm2()
        self.regist_phi()
        self.regist_gemma2()
        self.register_openelm()

    def regist_llama(self):
        llama_map = self.default_map
        self.regist('llama', llama_map)
        self.regist('qwen2', llama_map)
        self.regist('internlm', llama_map)
        self.regist('mobilellm', llama_map)
        # baichuan
        baichuan_map = copy.deepcopy(self.default_map)
        baichuan_map[self.attention_key] = {
            'qkv_proj': 'W_pack',
            'o_proj': 'o_proj'
        }
        self.regist('baichuan', baichuan_map)

    def regist_mllama(self):
        mllama_map = {
            'config': {
                'hidden_size': 'text_config.hidden_size',
                'num_attention_heads': 'text_config.num_attention_heads',
                'num_hidden_layers': 'text_config.num_hidden_layers',
                'num_key_value_heads': 'text_config.num_key_value_heads',
                'rope_theta': 'text_config.rope_theta'
            },
            'model': {
                'lm_': 'language_model.lm_head',
                'embed_': 'language_model.model.embed_tokens',
                'blocks_': 'language_model.model.layers',
                'final_layernorm_': 'language_model.model.norm',
                'visual': 'vision_model'
            },
            'decoder': {
                'self_attn': 'self_attn',
                'cross_attn': 'cross_attn',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'q_proj': 'q_proj',
                'k_proj': 'k_proj',
                'v_proj': 'v_proj',
                'o_proj': 'o_proj',
                'q_norm': 'q_norm',
                'k_norm': 'k_norm',
                'cross_attn_attn_gate': 'cross_attn_attn_gate',
                'cross_attn_mlp_gate': 'cross_attn_mlp_gate'
            }
        }
        self.regist('mllama', mllama_map)

    def regist_qwen(self):
        qwen_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'num_attention_heads': 'num_attention_heads',
                'num_hidden_layers': 'num_hidden_layers',
                'rope_theta': 'rotary_emb_base',
            },
            'model': {
                'lm_': 'lm_head',
                'embed_': 'transformer.wte',
                'blocks_': 'transformer.h',
                'final_layernorm_': 'transformer.ln_f',
                'visual': 'transformer.visual'
            },
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'ln_1',
                'post_attention_layernorm': 'ln_2'
            },
            'attention': {
                'qkv_proj': 'c_attn',
                'o_proj': 'c_proj'
            }
        }
        self.regist('qwen', qwen_map)

    def regist_glm(self):
        glm_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'num_attention_heads': 'num_attention_heads',
                'num_hidden_layers': 'num_layers'
            },
            'model': {
                'lm_': 'lm_head',
                'embed_': 'transformer.word_embeddings',
                'blocks_': 'transformer.layers',
                'final_layernorm_': 'transformer.final_layernorm',
            },
            'decoder': {
                'self_attn': 'attention',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'qkv_proj': 'query_key_value',
                'o_proj': 'dense'
            }
        }
        self.regist('chatglm', glm_map)

    def regist_glm2(self):
        glm2_map = {
            'config': {
                'hidden_size': 'hidden_size',
                'num_attention_heads': 'num_attention_heads',
                'num_key_value_heads': 'multi_query_group_num',
                'num_hidden_layers': 'num_layers',
            },
            'model': {
                'lm_': 'transformer.output_layer',
                'embed_': 'transformer.embedding.word_embeddings',
                'blocks_': 'transformer.encoder.layers',
                'final_layernorm_': 'transformer.encoder.final_layernorm',
            },
            'decoder': {
                'self_attn': 'self_attention',
                'mlp': 'mlp',
                'input_layernorm': 'input_layernorm',
                'post_attention_layernorm': 'post_attention_layernorm'
            },
            'attention': {
                'qkv_proj': 'query_key_value',
                'o_proj': 'dense'
            }
        }
        self.regist('chatglm2', glm2_map)

    def regist_phi(self):
        phi_map = {
            'config': {
                'hidden_size': 'n_embd',
                'num_attention_heads': 'n_head',
                'num_hidden_layers': 'n_layer',
                'rotary_dim': 'rotary_dim'
            },
            'model': {
                'lm_': 'lm_head.linear',
                'embed_': 'transformer.embd.wte',
                'blocks_': 'transformer.h',
                'final_layernorm_': 'lm_head.ln',
            },
            'decoder': {
                'self_attn': 'mixer',
                'mlp': 'mlp',
                'input_layernorm': 'ln',
            },
            'attention': {
                'qkv_proj': 'Wqkv',
                'o_proj': 'out_proj'
            }
        }
        self.regist('phi-msft', phi_map)

    def regist_gemma2(self):
        gemma2_config = copy.deepcopy(self.default_config)
        gemma2_config['head_dim'] = 'head_dim'
        gemma2_decoder = copy.deepcopy(self.default_decoder)
        gemma2_decoder['pre_feedforward_layernorm'] = 'pre_feedforward_layernorm'
        gemma2_decoder['post_feedforward_layernorm'] = 'post_feedforward_layernorm'
        gemma2_map = {
            'config': gemma2_config,
            'model': self.defualt_model,
            'decoder': gemma2_decoder,
            'attention': self.default_attention
        }
        self.regist('gemma2', gemma2_map)

    def register_openelm(self):
        openelm_config = {
            'hidden_size': 'model_dim',
            'head_dim': 'head_dim',
            'num_attention_heads': 'num_query_heads',
            'num_hidden_layers': 'num_transformer_layers',
            'num_key_value_heads': 'num_kv_heads',
            'rope_theta': 'rope_freq_constant'
        }
        openelm_model = {
            'lm_': 'lm_head',
            'embed_': 'transformer.token_embeddings',
            'blocks_': 'transformer.layers',
            'final_layernorm_': 'transformer.norm'
        }
        openelm_decoder = {
            'self_attn': 'attn',
            'mlp': 'ffn',
            'input_layernorm': 'attn_norm',
            'post_attention_layernorm': 'ffn_norm'
        }
        openelm_attention = {
            'qkv_proj': 'qkv_proj',
            'o_proj': 'out_proj',
            'q_norm': 'q_norm',
            'k_norm': 'k_norm'
        }
        openelm_map = {
            'config': openelm_config,
            'model': openelm_model,
            'decoder': openelm_decoder,
            'attention': openelm_attention
        }
        self.regist('openelm', openelm_map)

    def defualt_map(self):
        # default map is `LlamaForCausalLM`
        self.config_key = 'config'
        self.model_key = 'model'
        self.decoder_key = 'decoder'
        self.attention_key = 'attention'
        self.default_config = {
            'hidden_size': 'hidden_size',
            'num_attention_heads': 'num_attention_heads',
            'num_hidden_layers': 'num_hidden_layers',
            'num_key_value_heads': 'num_key_value_heads',
            'rope_theta': 'rope_theta'
        }
        self.defualt_model = {
            'lm_': 'lm_head',
            'embed_': 'model.embed_tokens',
            'blocks_': 'model.layers',
            'final_layernorm_': 'model.norm',
            'visual': 'visual'
        }
        self.default_decoder = {
            'self_attn': 'self_attn',
            'mlp': 'mlp',
            'input_layernorm': 'input_layernorm',
            'post_attention_layernorm': 'post_attention_layernorm'
        }
        self.default_attention = {
            'q_proj': 'q_proj',
            'k_proj': 'k_proj',
            'v_proj': 'v_proj',
            'o_proj': 'o_proj'
        }
        self.default_map = {
            'config': self.default_config,
            'model': self.defualt_model,
            'decoder': self.default_decoder,
            'attention': self.default_attention
        }

    @staticmethod
    def do_map(dst, src, map):
        for dst_attr, src_attr in map.items():
            attributes = src_attr.split('.')
            obj = src
            for attr in attributes:
                if hasattr(obj, attr):
                    obj = getattr(obj, attr)
                else:
                    obj = None
                    break
            setattr(dst, dst_attr, obj)

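# Illustrative usage sketch of ModelMapper: look up the attribute map for a loaded
# HuggingFace config, then copy the mapped attributes onto a destination object
# (`hf_config` and `dst` below are hypothetical placeholders):
#
#   mapper = ModelMapper()
#   model_type, model_map = mapper.get_map(hf_config)
#   ModelMapper.do_map(dst, hf_config, model_map['config'])
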
# Quant class

# awq quantizer start
class AwqQuantizer:
    def __init__(
        self,
        model,
        modules_to_not_convert=None,
        apply_clip=True,
        n_parallel_calib_samples=None,
        max_calib_samples=128,
        max_calib_seq_len=512,
        max_chunk_memory=1024 * 1024 * 1024,
    ) -> None:
        self.awq_model = model
        self.model = model
        self.tokenizer = model.tokenizer
        self.w_bit = model.quant_bit
        self.group_size = model.quant_block
        self.zeropoint = not model.symmetric
        self.calib_data = 'ag_news'
        self.split = 'test'
        self.duo_scaling = True
        self.apply_clip = apply_clip
        self.n_parallel_calib_samples = n_parallel_calib_samples
        self.max_calib_samples = max_calib_samples
        self.max_calib_seq_len = max_calib_seq_len
        self.max_chunk_memory = max_chunk_memory
        self.modules_to_not_convert = (
            modules_to_not_convert if modules_to_not_convert is not None else []
        )
        self.modules, self.module_kwargs, self.inps = self.init_quant(
            n_samples=self.max_calib_samples, max_seq_len=self.max_calib_seq_len
        )

    def pseudo_quantize_tensor(self, w: torch.Tensor):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2
        assert torch.isnan(w).sum() == 0
        # zero point quantization
        if self.zeropoint:
            max_val = w.amax(dim=1, keepdim=True)
            min_val = w.amin(dim=1, keepdim=True)
            offset = 1 << (self.w_bit - 1)
            clip_max = offset - 1
            clip_min = -offset
            scales = (max_val - min_val) / (clip_max - clip_min)
            zeros = - torch.round(min_val / scales) + clip_min
            qw = torch.round(w / scales) + zeros
            qw = torch.clamp(qw, clip_min, clip_max)
            w = (qw - zeros) * scales
            zeros = min_val.view(org_w_shape[0], -1)
        else:
            abs_max = w.abs().amax(dim=1, keepdim=True)
            offset = 1 << (self.w_bit - 1)
            clip_max = offset - 1
            clip_min = -clip_max
            scales = abs_max / clip_max
            w = torch.clamp(torch.round(w / scales), clip_min, clip_max) * scales
            zeros = None

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        scales = scales.view(org_w_shape[0], -1)
        w = w.reshape(org_w_shape)

        return w, scales, zeros

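    # Worked example for pseudo_quantize_tensor (illustrative numbers only): with
    # w_bit=4 and zeropoint=True the clip range is [-8, 7]. For a group with
    # min_val=-1.0 and max_val=2.0:
    #   scales = (2.0 - (-1.0)) / (7 - (-8)) = 0.2
    #   zeros  = -round(-1.0 / 0.2) + (-8)  = -3
    # so a weight of 0.55 quantizes to round(0.55 / 0.2) + (-3) = 0 and dequantizes
    # back to (0 - (-3)) * 0.2 = 0.6, an error bounded by scales / 2.
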
    def quantize(self):
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # if i > 0: break
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                best_device = AwqQuantizer.get_best_device()

                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)

            self.inps = self.inps.to(common_device)
            # print(f'# {i} inps shape: {self.inps.shape}, inps.max: {self.inps.max()}')

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = AwqQuantizer.get_named_linears(self.modules[i])

            # Filter out the linear layers we don't want to exclude
            named_linears = AwqQuantizer.exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            input_feat = self._get_input_feat(self.modules[i], named_linears)
            AwqQuantizer.clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config = []
            # q, k, v proj
            module_config.append(
                dict(
                    prev_op=self.modules[i].input_layernorm,
                    layers=[
                        self.modules[i].self_attn.q_proj,
                        self.modules[i].self_attn.k_proj,
                        self.modules[i].self_attn.v_proj,
                    ],
                    inp=input_feat["self_attn.q_proj"],
                    module2inspect=self.modules[i].self_attn,
                    kwargs=self.module_kwargs,
                )
            )
            # o_proj
            if self.modules[i].self_attn.v_proj.weight.shape == self.modules[i].self_attn.o_proj.weight.shape:
                module_config.append(
                    dict(
                        prev_op=self.modules[i].self_attn.v_proj,
                        layers=[self.modules[i].self_attn.o_proj],
                        inp=input_feat["self_attn.o_proj"],
                    )
                )
            # mlp gate
            module_config.append(
                dict(
                    prev_op=self.modules[i].post_attention_layernorm,
                    layers=[self.modules[i].mlp.gate_proj, self.modules[i].mlp.up_proj],
                    inp=input_feat["mlp.gate_proj"],
                    module2inspect=self.modules[i].mlp,
                )
            )
            # mlp down
            module_config.append(
                dict(
                    prev_op=self.modules[i].mlp.up_proj,
                    layers=[self.modules[i].mlp.down_proj],
                    inp=input_feat["mlp.down_proj"],
                )
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            # print(scales_list); exit(0)
            AwqQuantizer.apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            # [STEP 3]: Compute and apply clipping list
            if self.apply_clip:
                clip_list = self._search_best_clip(
                    self.modules[i], named_linears, input_feat
                )
                AwqQuantizer.apply_clip(self.modules[i], clip_list)

            AwqQuantizer.clear_memory()

    @torch.no_grad()
    def _module_forward(
        self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict
    ) -> torch.Tensor:
        if self.n_parallel_calib_samples is None:
            # runs through all samples at once
            # print(module, x, module_kwargs); exit(0)
            module_output = module(x, **module_kwargs)
            if isinstance(module_output, tuple):
                module_output = module_output[0]
        else:
            # memory efficiently runs through all calibration samples
            # but only n_parallel_calib_samples at a time
            module_output = []
            partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
            for x_partial in partitioned_inputs:
                partial_output = module(x_partial, **module_kwargs)

                if isinstance(partial_output, tuple):
                    partial_output = partial_output[0]

                module_output.append(partial_output.cpu())

            module_output = torch.cat(module_output, dim=0)

        return module_output

    @torch.no_grad()
    def _search_best_scale(
        self,
        module,
        prev_op,
        layers: List[torch.nn.Linear],
        inp: torch.Tensor,
        module2inspect=None,
        kwargs={},
    ):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute per-channel mean of normalised weights
        # All layer weights are concatted together
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        # The weights are reshaped to be organised by quantization group
        weight = weight.view(-1, self.group_size)
        # Calculates the relative magnitude of the weights within each of the quantization groups,
        # and rescales each group individually so that each group has weights on a 0-1 scale.
        w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
        # Resizes the rescaled weight matrix back up to its original dimensions
        w_scale = w_scale.view(org_shape)
        # Gets the average rescaled magnitude for each output channel
        w_mean = w_scale.mean(0)
        AwqQuantizer.clear_memory(weight)

        # [STEP 2]: Compute per-channel mean of the input activation with chunking
        # move inp to cpu to avoid memory leak
        inp_flat = inp.cpu().abs().view(-1, inp.shape[-1])
        num_elements = inp_flat.size(0)
        num_channels = inp_flat.size(1)
        element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32

        # Calculate chunk size dynamically based on max_chunk_memory
        chunk_size = int(self.max_chunk_memory // (element_size_bytes * num_channels))
        chunk_size = min(chunk_size, num_elements)

        # Use float32 for sum calculation
        x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device)

        for i in range(0, num_elements, chunk_size):
            end = min(i + chunk_size, num_elements)
            chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0)
            x_sum += chunk_sum.to(inp.device)

        x_mean = (x_sum / num_elements).to(inp.dtype)
        AwqQuantizer.clear_memory(x_sum)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
            fp16_output = self._module_forward(inp, module2inspect, module_kwargs)

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
        )

        return (
            AwqQuantizer.get_op_name(module, prev_op),
            tuple([AwqQuantizer.get_op_name(module, m) for m in layers]),
            best_scales,
        )

    def _compute_best_scale(
        self,
        x: torch.Tensor,
        w_mean: torch.Tensor,
        x_mean: torch.Tensor,
        module2inspect: torch.nn.Module,
        linears2scale: List[torch.nn.Linear],
        fp16_output: torch.Tensor,
        kwargs: Dict={},
    ):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float("inf")

        device = x.device
        x_mean = x_mean.view(-1).to(device)
        w_mean = w_mean.view(-1).to(device)

        ord_weights = []
        for fc in linears2scale:
            ord_weights.append(fc.weight.data.clone())

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            if self.duo_scaling:
                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # avoid scaling values that overflow
            scales[torch.isinf(scales)] = 1
            scales[torch.isnan(scales)] = 1

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = (
                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
                )

            # W * X
            int_w_output = self._module_forward(x, module2inspect, kwargs)

            # compute mean squared error (L2 norm)
            loss = self._compute_loss(fp16_output, int_w_output, device)

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()

            for fc, ord_weight in zip(linears2scale, ord_weights):
                fc.weight.data = ord_weight.clone()

        del ord_weights

        if best_ratio == -1:
            logging.debug(history)
            raise Exception

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _compute_loss(
        self,
        fp16_output: torch.Tensor,
        int_w_output: torch.Tensor,
        device: torch.device,
    ):
        loss = 0.0
        fp16_output_flat = fp16_output.view(-1)
        int_w_output_flat = int_w_output.view(-1)
        num_elements = fp16_output_flat.size(0)
        element_size_bytes = fp16_output.element_size()

        # Calculate chunk size dynamically based on max_chunk_memory
        # Divide the max_chunk_memory by twice the element size
        chunk_size = self.max_chunk_memory // (element_size_bytes * 2)
        chunk_size = min(chunk_size, num_elements)

        # Split the computation into chunks
        fp16_chunks = torch.split(fp16_output_flat, chunk_size)
        int_w_chunks = torch.split(int_w_output_flat, chunk_size)

        # Compute the loss for each chunk
        for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks):
            chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item()
            loss += chunk_loss

        # Normalize the loss by the total number of elements
        loss /= num_elements

        return loss

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].to(AwqQuantizer.get_best_device())
            max_val = self._compute_best_clip(
                named_linears[name].weight, input_feat[name]
            )
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(
        self,
        w: torch.Tensor,
        input_feat: torch.Tensor,
        n_grid=20,
        max_shrink=0.5,
        n_sample_token=512,
    ):
        assert w.dim() == 2
        org_w_shape = w.shape
        # w [co, ci] -> [co, 1, n_group, group size]
        # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)

        # Compute input feature step size (minimum 1)
        step_size = max(1, input_feat.shape[1] // n_sample_token)
        input_feat = input_feat[:, ::step_size]

        w = w.reshape(org_w_shape[0], 1, -1, group_size)

        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM
        assert org_w_shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(org_w_shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)[0]
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        AwqQuantizer.clear_memory(input_feat)
        AwqQuantizer.clear_memory(org_out)

        return best_max_val.squeeze(1)

    @staticmethod
    @torch.no_grad()
    def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
        for name, max_val in clip_list:
            layer: torch.nn.Linear = AwqQuantizer.get_op_by_name(module, name)
            layer.to(AwqQuantizer.get_best_device())
            max_val = max_val.to(layer.weight.device)
            org_shape = layer.weight.shape
            layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
            layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
            layer.weight.data = layer.weight.data.reshape(org_shape)
            layer.cpu()

    @staticmethod
    @torch.no_grad()
    def scale_fc_fcs(fc1: torch.nn.Linear, fcs: List[torch.nn.Linear], scales: torch.Tensor):
        if not isinstance(fcs, list):
            fcs = [fcs]

        scales = scales.to(fc1.weight.device)

        fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1))
        if fc1.bias is not None:
            fc1.bias.div_(scales.view(-1))

        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))

        for p in fc1.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0

    @staticmethod
    def is_allowed_act_fns(op):
        from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
        allowed_act_fns = [
            torch.nn.GELU,
            NewGELUActivation,
            PytorchGELUTanh,
            GELUActivation,
        ]
        return (op in allowed_act_fns)

    @staticmethod
    def is_allowed_norms(op):
        if isinstance(op, torch.nn.LayerNorm):
            return True
        if any(t in str(type(op)) for t in ['LlamaRMSNorm', 'GemmaRMSNorm', 'CohereLayerNorm']):
            return True
        return False

    @staticmethod
    @torch.no_grad()
    def scale_fc_fc(fc1: torch.nn.Linear, fc2: torch.nn.Linear, scales: torch.Tensor):
        assert isinstance(fc1, torch.nn.Linear)
        assert isinstance(fc2, torch.nn.Linear)

        scales = scales.to(fc1.weight.device)
        fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1))
        if fc1.bias is not None:
            fc1.bias.div_(scales.view(-1))

        fc2.weight.mul_(scales.view(1, -1))

        for p in fc1.parameters():
            assert torch.isnan(p).sum() == 0
        for p in fc2.parameters():
            assert torch.isnan(p).sum() == 0

    @staticmethod
    @torch.no_grad()
    def scale_ln_fcs(ln: torch.nn.Linear, fcs: List[torch.nn.Linear], scales: torch.Tensor):
        if not isinstance(fcs, list):
            fcs = [fcs]

        scales = scales.to(ln.weight.device)

        # GemmaRMSNorm is different from Llama's in that it multiplies
        # (1 + weight) to the output, instead of just weight.
        if 'GemmaRMSNorm' in str(type(ln)):
            ln.weight += 1
            ln.weight.div_(scales)
            ln.weight -= 1
        else:
            ln.weight.div_(scales)

        if hasattr(ln, "bias") and ln.bias is not None:
            ln.bias.div_(scales)

        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))

        for p in ln.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0

    @staticmethod
    @torch.no_grad()
    def scale_gelu_fc(gelu, fc: torch.nn.Linear, scales: torch.Tensor):
        assert AwqQuantizer.is_allowed_act_fns(gelu)
        assert isinstance(fc, torch.nn.Linear)

        fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0

    @staticmethod
    def apply_scale(module, scales_list, input_feat_dict=None):
        for prev_op_name, layer_names, scales in scales_list:
            prev_op = AwqQuantizer.get_op_by_name(module, prev_op_name)
            layers = [AwqQuantizer.get_op_by_name(module, name) for name in layer_names]

            best_device = AwqQuantizer.get_best_device()
            prev_op.to(best_device)
            for layer in layers:
                layer.to(best_device)
            scales.to(best_device)
            if (
                isinstance(prev_op, torch.nn.Linear)
                and type(layers) == list
                and isinstance(layers[0], torch.nn.Linear)
            ):
                if len(layers) == 1:
                    AwqQuantizer.scale_fc_fc(prev_op, layers[0], scales)
                else:
                    AwqQuantizer.scale_fc_fcs(prev_op, layers, scales)
            elif (
                AwqQuantizer.is_allowed_norms(prev_op)
                or "rmsnorm" in str(prev_op.__class__).lower()
            ):
                AwqQuantizer.scale_ln_fcs(prev_op, layers, scales)

            elif AwqQuantizer.is_allowed_act_fns(prev_op):
                #new_module = ScaledActivation(prev_op, scales)
                #set_op_by_name(module, prev_op_name, new_module)
                AwqQuantizer.scale_gelu_fc(prev_op, layers[0], scales)
            else:
                raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

            # apply the scaling to input feat if given; prepare it for clipping
            if input_feat_dict is not None:
                for layer_name in layer_names:
                    # Skip the modules that are not quantized
                    if layer_name in input_feat_dict:
                        inp = input_feat_dict[layer_name]
                        inp.div_(scales.view(1, -1).to(inp.device))

            prev_op.cpu()
            for layer in layers:
                layer.cpu()
            scales.cpu()

    @staticmethod
    def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
        if modules_to_not_convert is None:
            return linear_layers

        filtered_layers = {}
        for name, linear_layer in linear_layers.items():
            if not any(key in name for key in modules_to_not_convert):
                filtered_layers[name] = linear_layer
        return filtered_layers

    @staticmethod
    def get_named_linears(module):
        return {name: m for name, m in module.named_modules() if isinstance(m, torch.nn.Linear)}

    @staticmethod
    def get_op_by_name(module, op_name):
        # get the op by its name relative to the module
        for name, m in module.named_modules():
            if name == op_name:
                return m
        raise ValueError(f"Cannot find op {op_name} in module {module}")

    @staticmethod
    def get_calib_dataset(
        data: Union[str, List[str], List[List[int]]] = "pileval",
        tokenizer=None,
        n_samples=128,
        max_seq_len=512,
        split="train",
        text_column="text",
    ):
        if isinstance(data, str):
            from datasets import load_dataset
            if data == "pileval":
                dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
            else:
                dataset = load_dataset(data, split=split)
            # dataset = dataset.shuffle(seed=42)
        elif isinstance(data, list):
            if isinstance(data[0], str):
                dataset = [{text_column: text} for text in data]
            elif isinstance(data[0][0], int):
                dataset = data
            else:
                raise NotImplementedError(
                    "Either pass a string to a huggingface dataset or a list "
                    "that is preprocessed with one sample of text per element"
                    " or a list of list of int for tokenized words."
                )
        else:
            raise NotImplementedError(
                "Either pass a string to a huggingface dataset or a list "
                "that is preprocessed with one sample of text per element"
                " or a list of list of int for tokenized words."
            )

        samples = []
        n_run = 0
        for data in dataset:
            if isinstance(data, list):
                line_encoded = data
            else:
                line = data[text_column]
                line = line.strip()
                line_encoded = tokenizer.encode(line)
            if len(line_encoded) > max_seq_len:
                continue
            sample = torch.tensor([line_encoded])
            if sample.numel() == 0:
                continue
            samples.append(sample)
            n_run += 1
            if n_run == n_samples:
                break
        # now concatenate all samples and split according to max sequence length
        cat_samples = torch.cat(samples, dim=1)
        n_split = cat_samples.shape[1] // max_seq_len
        logging.debug(f" * Split into {n_split} blocks")
        return [
            cat_samples[:, i * max_seq_len : (i + 1) * max_seq_len] for i in range(n_split)
        ]

    @staticmethod
    def get_best_device():
        if torch.backends.mps.is_available():
            return "mps"
        elif torch.cuda.is_available():
            return "cuda:0"
        else:
            return "cpu"

    @staticmethod
    def clear_memory(weight=None):
        if weight is not None:
            del weight
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def get_op_name(module, op):
        # get the name of the op relative to the module
        for name, m in module.named_modules():
            if m is op:
                return name
        raise ValueError(f"Cannot find op {op} in module {module}")

    @staticmethod
    def append_str_prefix(x, prefix):
        if isinstance(x, str):
            return prefix + x
        elif isinstance(x, tuple):
            return tuple([AwqQuantizer.append_str_prefix(y, prefix) for y in x])
        elif isinstance(x, list):
            return [AwqQuantizer.append_str_prefix(y, prefix) for y in x]
        else:
            return x

    def init_quant(self, n_samples=128, max_seq_len=512):
        modules = self.awq_model.blocks
        samples = AwqQuantizer.get_calib_dataset(
            data=self.calib_data,
            tokenizer=self.tokenizer,
            n_samples=n_samples,
            max_seq_len=max_seq_len,
            split=self.split
        )
        # samples = torch.cat(samples, dim=0)
        samples = torch.cat(samples[:1], dim=0) # just using 1 batch
        inps = []
        layer_kwargs = {}
        # build inps
        self.model.seq_len = samples.numel()
        self.model.context_len = samples.numel() - 2
        self.model.token_len = 0
        best_device = AwqQuantizer.get_best_device()
        inps = self.model.embedding(samples).to(best_device)
        position_ids = self.model.get_position_ids()
        rotary_pos_emb = self.model.rotary(position_ids)
        attention_mask = self.model.get_attention_mask()
        layer_kwargs["rotary_pos_emb"] = rotary_pos_emb.to(best_device)
        layer_kwargs["attention_mask"] = attention_mask.to(best_device)
        del samples
        AwqQuantizer.clear_memory()
        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)
        input_feat = defaultdict(list)
        handles = []
        for name in named_linears:
            handles.append(
                named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
                )
            )
        self.inps = self.inps.to(next(layer.parameters()).device) # in case multi-gpu
        # get output as next layer's input

        # Sanitize the kwargs in case we use a transformers version that contains
        # kwargs that are not handled by the module.
        # Useful for trust_remote_code models.
        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)

        self.inps = self._module_forward(self.inps, layer, module_kwargs)
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat

    def _sanitize_kwargs(self, inputs_kwargs, module):
        """
        Remove the arguments that are not supported in the module's
        forward pass to avoid breaking behaviour between different versions
        of transformers.

        Args:
            inputs_kwargs (`dict`):
                The input dictionary to pass to the model layer
            module (`torch.nn.Module`):
                Target module to quantize.
        """
        module_signature = inspect.signature(module.forward).parameters
        sanitized_kwargs = {}
        for k, v in inputs_kwargs.items():
            if k in module_signature:
                sanitized_kwargs[k] = v
        return sanitized_kwargs
# awq quantizer end

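# Illustrative usage sketch of AwqQuantizer (`llm` is a hypothetical model wrapper that
# provides tokenizer, quant_bit, quant_block, symmetric, blocks, plus the embedding(),
# get_position_ids(), rotary() and get_attention_mask() hooks used by init_quant):
#
#   quantizer = AwqQuantizer(llm)
#   quantizer.quantize()   # rescales/clips the decoder weights of `llm` in place
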
# Export class

# custom op start
class FakeLinearOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, input, in_features, out_features, has_bias, name):
        # These become the operator attributes.
        kwargs = {
            "in_features_i": in_features,
            "out_features_i": out_features,
            "has_bias_i": has_bias,
            "name_s": name
        }
        from torch.onnx.symbolic_helper import _get_tensor_sizes
        out_sizes = _get_tensor_sizes(input)[:-1] + [out_features]
        output_type = input.type().with_sizes(out_sizes)
        return g.op("LlmExporter::FakeLinear", input, **kwargs).setType(output_type)

    @staticmethod
    def forward(ctx, input, in_features, out_features, has_bias, name):
        out_shape = list(input.shape)[:-1] + [out_features]
        return input.new_zeros(out_shape)

class FakeLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, has_bias, name):
        super(FakeLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = has_bias
        self.name = name

    def forward(self, x):
        return FakeLinearOp.apply(x, self.in_features, self.out_features, self.has_bias, self.name)

class FusedAttentionOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, query, key, value, attention_mask, hidden_size, name):
        # These become the operator attributes.
        kwargs = {
            "hidden_size_i": hidden_size,
            "name_s": name
        }
        from torch.onnx.symbolic_helper import _get_tensor_sizes
        out_sizes = _get_tensor_sizes(query)
        output_type = query.type().with_sizes(out_sizes)
        return g.op("LlmExporter::FusedAttention", query, key, value, attention_mask, **kwargs).setType(output_type)

    @staticmethod
    def forward(ctx, query, key, value, attention_mask, hidden_size, name):
        out_shape = list(query.shape)[:2] + [hidden_size]
        return query.new_zeros(out_shape)

class FusedAttention(torch.nn.Module):
    def __init__(self, hidden_size, name):
        super(FusedAttention, self).__init__()
        self.hidden_size = hidden_size
        self.name = name

    def forward(self, query, key, value, attention_mask):
        return FusedAttentionOp.apply(query, key, value, attention_mask, self.hidden_size, self.name)

# custom op end

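# Note on the custom ops above: FakeLinear and FusedAttention carry no weights. When the
# wrapped model is traced with torch.onnx.export, they are emitted as custom
# `LlmExporter::FakeLinear` / `LlmExporter::FusedAttention` nodes whose attributes
# (name, feature sizes, hidden size) let OnnxRebuilder and MNNConveter below splice the
# real weights back in after conversion.
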
class OnnxRebuilder:
    def __init__(self, onnx_path, weight_ops):
        self.weight_ops = weight_ops
        self.onnx_model = onnx.load(onnx_path)
        self.dst_path = onnx_path
        self.onnx_weight_path = f'{onnx_path}.data'
        self.onnx_weight_offset = 0

    def make_external(self, name, data, shape):
        # write to external weight
        length = self.onnx_weight.write(data.tobytes())
        location = os.path.basename(self.onnx_weight_path)
        offset = self.onnx_weight_offset
        self.onnx_weight_offset += length
        tensor = onnx.TensorProto()
        tensor.name = name
        tensor.data_type = onnx.TensorProto.FLOAT
        tensor.dims.extend(shape)
        # external info
        tensor.data_location = onnx.TensorProto.EXTERNAL
        for k, v in { "location": location, "offset": offset, "length": length }.items():
            entry = tensor.external_data.add()
            entry.key = k
            entry.value = str(v)
        self.onnx_model.graph.initializer.append(tensor)

    def build_weight(self, name, has_bias, ic, oc):
        assert(name in self.weight_ops)
        linear = self.weight_ops[name]
        assert(linear.in_features == ic and
               linear.out_features == oc and
               (linear.bias is not None) == has_bias)
        weight_name, bias_name = f'{name}_weight', f'{name}_bias'
        weight = linear.weight.data.transpose(1, 0).flatten().numpy()
        self.make_external(weight_name, weight, [ic, oc])
        if has_bias:
            bias = linear.bias.data.flatten().numpy()
            self.make_external(bias_name, bias, [oc])
        return weight_name, bias_name

    def rebuild(self):
        from onnx import helper
        new_nodes = []
        self.onnx_weight = open(self.onnx_weight_path, 'wb')
        for node in self.onnx_model.graph.node:
            if node.op_type == 'FakeLinear':
                attributes = {a.name: a for a in node.attribute}
                name = attributes.get('name').s.decode('utf-8')
                has_bias = attributes.get('has_bias').i
                ic = attributes.get('in_features').i
                oc = attributes.get('out_features').i
                weight, bias = self.build_weight(name, has_bias, ic, oc)
                if has_bias:
                    # fakelinear -> matmul + add
                    middle_tensor = f'{name}_matmul'
                    new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], [middle_tensor], name))
                    new_nodes.append(helper.make_node('Add', [middle_tensor, bias], node.output, f'{name}/Add'))
                else:
                    # fakelinear -> matmul
                    new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], node.output, name))
            else:
                new_nodes.append(node)
        self.onnx_weight.close()
        del self.onnx_model.graph.node[:]
        self.onnx_model.graph.node.extend(new_nodes)
        onnx.save(self.onnx_model, self.dst_path)
        return self.onnx_weight_path

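# Illustrative usage sketch of OnnxRebuilder (`weight_ops` is a dict mapping each
# FakeLinear node's `name` attribute to the original torch.nn.Linear, as assumed by
# build_weight above; the path is hypothetical):
#
#   rebuilder = OnnxRebuilder('llm.onnx', weight_ops)
#   external_weight_file = rebuilder.rebuild()
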
class MNNConveter:
    def __init__(self, onnx_path, weight_ops, config):
        self.weight_ops = weight_ops
        self.config = config
        self.quant_block = config.quant_block
        self.quant_bit = config.quant_bit
        self.lm_quant_bit = config.lm_quant_bit
        self.symmetric = config.symmetric
        self.mnn_weight_offset = 0
        self.onnx_model_path = onnx_path
        self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
        self.mnn_model_path = os.path.join(config.dst_path, self.mnn_name)
        self.mnn_weight_path = f'{self.mnn_model_path}.weight'
        if os.path.exists(config.mnnconvert):
            self.mnnconvert = config.mnnconvert
        else:
            self.mnnconvert = None

    def convert(self, convert_args):
        sfd = os.dup(1)
        log_fp = open(EXPORT_LOG, "a")
        log_fd = log_fp.fileno()
        # mnnconvert ... > .export.log
        os.dup2(log_fd, 1)
        try:
            sys.argv = convert_args
            sys.argc = len(convert_args)
            if self.mnnconvert is None:
                from MNN.tools import mnnconvert
                mnnconvert.main()
            else:
                convert_args[0] = self.mnnconvert
                cmd = ' '.join(convert_args)
                message = os.popen(cmd).read()
                print(message)
            sys.argv = []
        finally:
            os.dup2(sfd, 1)
            os.close(log_fd)

    @spinner_run(f'convert onnx model to ')
    def onnx2mnn(self, onnx_path, mnn_path, args = []):
        convert_args = [
            '',
            '-f',
            'ONNX',
            '--modelFile',
            str(onnx_path),
            '--MNNModel',
            str(mnn_path),
            '--transformerFuse',
            '--allowCustomOp'
        ]
        convert_args += args
        self.convert(convert_args)
        return mnn_path

    def mnn2json(self, mnn_path, json_path):
        convert_args = [
            '',
            '-f',
            'MNN',
            '--modelFile',
            str(mnn_path),
            '--JsonFile',
            str(json_path)
        ]
        self.convert(convert_args)
        return json_path

    def json2mnn(self, json_path, mnn_path):
        convert_args = [
            '',
            '-f',
            'JSON',
            '--modelFile',
            str(json_path),
            '--MNNModel',
            str(mnn_path)
        ]
        self.convert(convert_args)
        return mnn_path

    def removeDupOps(self, mnn_path):
        convert_args = [
            '',
            '-f',
            'MNN',
            '--modelFile',
            str(mnn_path),
            '--MNNModel',
            str(mnn_path),
            '--optimizeLevel=1'
        ]
        self.convert(convert_args)
        return mnn_path

    def export(self, quant_bit = None, quant_block = None):
        if self.weight_ops is None:
            if quant_bit is None:
                quant_bit = self.quant_bit
            if quant_block is None:
                quant_block = self.quant_block
            if quant_bit == 16:
                quant_args = ['--fp16']
            else:
                quant_args = [
                    '--weightQuantBits',
                    str(quant_bit),
                    '--weightQuantBlock',
                    str(quant_block)
                ]
            self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args)
        else:
            mnn_json = f'{self.mnn_model_path}.json'
            self.onnx2mnn(self.onnx_model_path, self.mnn_model_path)
            self.mnn2json(self.mnn_model_path, mnn_json)
            self.rebuild(mnn_json)
            self.json2mnn(mnn_json, self.mnn_model_path)
            self.removeDupOps(self.mnn_model_path)

    @spinner_run(f'quant model weight to ', True)
    def rebuild(self, json_path):
        mnn_graph = json.load(open(json_path, 'rt'))
        new_ops = []
        with open(self.mnn_weight_path, 'wb') as self.mnn_weight:
            for op in tqdm(mnn_graph['oplists'], 'Quant weights'):
                if op['type'] == 'Extra':
                    new_ops += self.rebuild_op(op, mnn_graph)
                else:
                    new_ops.append(op)
        mnn_graph['oplists'] = new_ops
        with open(json_path, 'w', encoding='utf-8') as file:
            json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
        return self.mnn_weight_path

    def quant(self, weight, quant_bit, quant_block, symmetric):
        if torch.cuda.is_available():
            weight = weight.cuda()
        elif torch.backends.mps.is_available():
            weight = weight.to('mps')
        oc, ic = weight.shape
        if quant_block == 0:
            block_size = ic
        else:
            block_size = quant_block
        block_num = ic // block_size
        weight = weight.reshape(oc, block_num, block_size)
        offset = 1 << (quant_bit - 1)
        clip_max = offset - 1
        if symmetric:
            clip_min = -clip_max
            abs_max, _ = torch.max(torch.abs(weight), axis=-1, keepdims=True)
            scale = abs_max / clip_max
            q_weight = torch.round(weight / scale)
            q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
            alpha = scale.flatten()
        else:
            clip_min = -offset
            max_val, _ = torch.max(weight, axis=-1, keepdims=True)
            min_val, _ = torch.min(weight, axis=-1, keepdims=True)
            scale = (max_val - min_val) / (clip_max - clip_min)

            if False:
                q_weight = np.round((weight - min_val) / scale) + clip_min
                zeros = min_val - scale * clip_min
            else:
                q_weight = torch.round(weight / scale) - torch.round(min_val / scale) + clip_min
                zeros = (torch.round(min_val / scale) - clip_min) * scale
            q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
            alpha = torch.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()

        q_weight = q_weight.reshape(-1, 2)
        if quant_bit == 4:
            q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]

        clip_min = 1

        if q_weight.device != torch.device('cpu'):
            return q_weight.cpu(), alpha.float().cpu(), clip_min
        return q_weight, alpha.float(), clip_min

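    # Worked example for quant (illustrative numbers only): with quant_bit=4 and
    # symmetric=True each block maps to integers in [-7, 7], which are shifted by
    # offset=8 into [1, 15]; the packing step then stores two 4-bit values per byte,
    # e.g. quantized values (3, -2) become (3 + 8) * 16 + (-2 + 8) = 0xB6.
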
def write_weight(self, data):
|
|
|
|
if isinstance(data, torch.Tensor):
|
|
|
|
data = data.numpy()
|
2024-09-12 12:57:57 +08:00
|
|
|
return self.mnn_weight.write(data.tobytes())
|
|
|
|
|
|
|
|
def write_header(self, ic, oc, quant_bit):
|
|
|
|
dim_num = self.mnn_weight.write(b'\x02')
|
|
|
|
shape_dtype = np.int16
|
|
|
|
if oc > 65535 or ic > 65535:
|
|
|
|
shape_dtype = np.int32
|
2024-12-02 10:12:08 +08:00
|
|
|
dim_length = self.write_weight(np.array([oc, ic]).astype(shape_dtype))
|
2024-09-12 12:57:57 +08:00
|
|
|
offset = 1 << (quant_bit - 1)
|
|
|
|
weight_map = [i for i in range(-offset, offset)]
|
|
|
|
if len(weight_map) == 256:
|
|
|
|
weight_map.insert(0, 0)
|
|
|
|
else:
|
|
|
|
weight_map.insert(0, len(weight_map))
|
2024-12-02 10:12:08 +08:00
|
|
|
map_length = self.write_weight(np.array(weight_map, dtype=np.int8))
|
2024-09-12 12:57:57 +08:00
|
|
|
header_length = dim_num + dim_length + map_length
|
|
|
|
return header_length, shape_dtype == np.int32
|
|
|
|
|
2024-11-18 14:37:45 +08:00
|
|
|
def build_weight(self, linear, quant_bit, quant_block, symmetric):
|
2024-09-12 12:57:57 +08:00
|
|
|
ic, oc = linear.in_features, linear.out_features
|
2024-11-18 14:37:45 +08:00
|
|
|
if quant_bit == 16:
|
2024-12-02 10:12:08 +08:00
|
|
|
half_weight = linear.weight.data.flatten().half()
|
|
|
|
weight_len = self.write_weight(half_weight)
|
2024-11-18 14:37:45 +08:00
|
|
|
alpha_len, q_min, shape_int32 = 0, 0, False
|
|
|
|
else:
|
|
|
|
assert(quant_bit in (4, 8))
|
|
|
|
q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block, symmetric)
|
|
|
|
header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
|
2024-12-02 10:12:08 +08:00
|
|
|
weight_len = self.write_weight(q_weight) + header_len
|
|
|
|
alpha_len = self.write_weight(alpha)
|
2024-09-12 12:57:57 +08:00
|
|
|
if linear.bias is not None:
|
2024-12-02 10:12:08 +08:00
|
|
|
bias = linear.bias.data.flatten().float()
|
|
|
|
bias_length = self.write_weight(bias)
|
2024-09-12 12:57:57 +08:00
|
|
|
else:
|
|
|
|
bias_length = 0
|
|
|
|
# bias = np.zeros([oc], dtype=np.float32)
|
2024-12-02 10:12:08 +08:00
|
|
|
# bias_length = self.write_weight(bias)
|
2024-09-12 12:57:57 +08:00
|
|
|
external = [self.mnn_weight_offset, weight_len, alpha_len, bias_length, 0]
|
|
|
|
self.mnn_weight_offset += (weight_len + alpha_len + bias_length)
|
2024-11-18 14:37:45 +08:00
|
|
|
return external, q_min, shape_int32, header_len
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
def build_tensor(self, graph, tensor_name):
|
|
|
|
tensor_idx = [len(graph['tensorName'])]
|
|
|
|
graph['tensorName'].append(tensor_name)
|
|
|
|
return tensor_idx
|
|
|
|
|
|
|
|
def rebuild_op(self, op, graph):
|
2024-11-18 14:37:45 +08:00
|
|
|
op_type = op['main']['type']
|
|
|
|
if op_type == 'FakeLinear':
|
|
|
|
return self.rebuild_linear(op, graph)
|
|
|
|
if op_type == 'FusedAttention':
|
|
|
|
return self.rebuild_attention(op, graph)
|
|
|
|
|
|
|
|
def rebuild_attention(self, op, graph):
|
|
|
|
attrs = op['main']['attr']
|
|
|
|
for attr in attrs:
|
|
|
|
if attr['key'] == 'name':
|
|
|
|
name = attr['s']
|
|
|
|
origin_input = op['inputIndexes']
|
|
|
|
origin_output = op['outputIndexes']
|
|
|
|
fused_attention = {
|
|
|
|
"inputIndexes": origin_input,
|
|
|
|
"main_type": "AttentionParam",
|
|
|
|
"main": { "kv_cache": True },
|
|
|
|
"name": name,
|
|
|
|
"outputIndexes": origin_output,
|
|
|
|
"type": "Attention",
|
|
|
|
"defaultDimentionFormat": "NHWC"
|
|
|
|
}
|
|
|
|
return [fused_attention]
|
|
|
|
|
|
|
|
def rebuild_linear(self, op, graph):
|
2024-09-12 12:57:57 +08:00
|
|
|
attrs = op['main']['attr']
|
|
|
|
for attr in attrs:
|
|
|
|
if attr['key'] == 'name':
|
|
|
|
name = attr['s']
|
|
|
|
elif attr['key'] == "in_features":
|
|
|
|
ic = attr["i"]
|
|
|
|
elif attr['key'] == "out_features":
|
|
|
|
oc = attr["i"]
|
|
|
|
elif attr['key'] == "has_bias":
|
|
|
|
has_bias = attr["i"]
|
|
|
|
linear = self.weight_ops[name]
|
|
|
|
assert(linear.in_features == ic and
|
|
|
|
linear.out_features == oc and
|
|
|
|
(linear.bias is not None) == has_bias)
|
|
|
|
|
2024-11-18 14:37:45 +08:00
|
|
|
is_lm = 'lm_head' in name
|
|
|
|
quant_bit = self.lm_quant_bit if is_lm else self.quant_bit
|
2024-10-29 19:32:47 +08:00
|
|
|
block_size = ic if self.quant_block == 0 else self.quant_block
|
2024-11-18 14:37:45 +08:00
|
|
|
external, q_min, shape_int32, header_len = self.build_weight(linear, quant_bit, self.quant_block, self.symmetric)
|
|
|
|
if is_lm and self.config.tie_word_embeddings:
|
|
|
|
weight_offset = external[0] + header_len
|
|
|
|
alpha_offset = external[0] + external[1]
|
|
|
|
alpha_size = external[2]
|
|
|
|
self.config.llm_config['tie_embeddings'] = [weight_offset, alpha_offset, alpha_size, quant_bit, self.quant_block]
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
origin_input = op['inputIndexes']
|
|
|
|
origin_output = op['outputIndexes']
|
|
|
|
# build new tensor
|
|
|
|
pre_reshape_name = f'{name}/pre_reshape'
|
|
|
|
pre_convert_name = f'{name}/pre_convert'
|
|
|
|
conv_name = name
|
|
|
|
post_convert_name = f'{name}/post_convert'
|
|
|
|
post_reshape_name = f'{name}/post_reshape'
|
|
|
|
pre_reshape_output = self.build_tensor(graph, pre_reshape_name)
|
|
|
|
pre_convert_output = self.build_tensor(graph, pre_convert_name)
|
|
|
|
conv_output = self.build_tensor(graph, conv_name)
|
|
|
|
post_convert_output = self.build_tensor(graph, post_convert_name)
|
|
|
|
# [batch, seq, hidden_size_i] -[Linear] -> [batch, seq, hidden_size_o]
|
|
|
|
# [1, seq, hidden_size_i] ->[Reshape]-> [seq, hidden_size_i, 1, 1]
|
|
|
|
# -[Convert]-[Convolution]-[Convert]-> [Reshape] -> [1, seq, hidden_size_o]
|
|
|
|
pre_reshape = {
|
|
|
|
"name": pre_reshape_name,
|
|
|
|
"type": "Reshape",
|
|
|
|
"inputIndexes": origin_input,
|
|
|
|
"outputIndexes": pre_reshape_output,
|
|
|
|
"main_type": "Reshape",
|
|
|
|
"main": {
|
|
|
|
"dims": [-1, ic, 1, 1],
|
|
|
|
"dimType": "NCHW"
|
|
|
|
},
|
|
|
|
"defaultDimentionFormat": "NHWC"
|
|
|
|
}
|
|
|
|
pre_convert = {
|
|
|
|
"name": pre_convert_name,
|
|
|
|
"inputIndexes": pre_reshape_output,
|
|
|
|
"outputIndexes": pre_convert_output,
|
|
|
|
"type": "ConvertTensor",
|
|
|
|
"main_type": "TensorConvertInfo",
|
|
|
|
"main": {
|
|
|
|
"source": "NCHW",
|
|
|
|
"dest": "NC4HW4"
|
|
|
|
},
|
|
|
|
"defaultDimentionFormat": "NHWC"
|
|
|
|
}
|
2024-11-18 14:37:45 +08:00
|
|
|
|
|
|
|
if quant_bit == 16:
|
|
|
|
quanParameter = { "type": 3 }
|
|
|
|
else:
|
|
|
|
if self.symmetric:
|
|
|
|
aMin = 0
|
|
|
|
readType = 0
|
|
|
|
else:
|
|
|
|
aMin = q_min
|
|
|
|
readType = oc * (ic // block_size)
|
|
|
|
|
|
|
|
quanParameter = {
|
|
|
|
"quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0,
|
|
|
|
"useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32,
|
|
|
|
"type": 1, "aMax": 0, "aMin": aMin, "readType": readType, "weightSize": 0
|
|
|
|
}
|
2024-09-12 12:57:57 +08:00
|
|
|
conv_op = {
|
|
|
|
"name": conv_name,
|
|
|
|
"inputIndexes": pre_convert_output,
|
|
|
|
"outputIndexes": conv_output,
|
|
|
|
"type": "Convolution",
|
|
|
|
"main_type": "Convolution2D",
|
|
|
|
"main": {
|
|
|
|
'common': {
|
|
|
|
'dilateX': 1, 'dilateY': 1, 'strideX': 1, 'strideY': 1,
|
|
|
|
'kernelX': 1, 'kernelY': 1, 'padX': 0, 'padY': 0, 'group': 1,
|
|
|
|
'outputCount': oc, 'relu': False, 'padMode': 'CAFFE',
|
|
|
|
'relu6': False, 'inputCount': ic, 'hasOutputShape': False
|
|
|
|
},
|
2024-11-18 14:37:45 +08:00
|
|
|
"quanParameter": quanParameter,
|
2024-09-12 12:57:57 +08:00
|
|
|
"external": external
|
|
|
|
},
|
|
|
|
"defaultDimentionFormat": "NHWC"
|
|
|
|
}
|
|
|
|
post_convert = {
|
|
|
|
"name": post_convert_name,
|
|
|
|
"inputIndexes": conv_output,
|
|
|
|
"outputIndexes": post_convert_output,
|
|
|
|
"type": "ConvertTensor",
|
|
|
|
"main_type": "TensorConvertInfo",
|
|
|
|
"main": {
|
|
|
|
"source": "NC4HW4",
|
|
|
|
"dest": "NCHW"
|
|
|
|
},
|
|
|
|
"defaultDimentionFormat": "NHWC"
|
|
|
|
}
|
|
|
|
post_reshape = {
|
|
|
|
"name": post_reshape_name,
|
|
|
|
"type": "Reshape",
|
|
|
|
"inputIndexes": post_convert_output,
|
|
|
|
"outputIndexes": origin_output,
|
|
|
|
"main_type": "Reshape",
|
|
|
|
"main": {
|
|
|
|
"dims": [1, -1, oc],
|
|
|
|
"dimType": "NCHW"
|
|
|
|
},
|
|
|
|
"defaultDimentionFormat": "NHWC"
|
|
|
|
}
|
|
|
|
return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape]
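# Example of the expansion performed above (names follow the f-strings in this method;
# 'gate_proj' is only an illustrative layer name, not guaranteed by the exporter):
#   a FakeLinear named '/layers.0/mlp/gate_proj/Linear' becomes the chain
#   .../pre_reshape -> .../pre_convert -> '/layers.0/mlp/gate_proj/Linear' (1x1 Convolution)
#   -> .../post_convert -> .../post_reshape
#   i.e. [1, seq, ic] -> [seq, ic, 1, 1] -> NC4HW4 conv -> NCHW -> [1, seq, oc].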
|
|
|
|
|
|
|
|
# some wrapper class for export
|
|
|
|
class Embedding(torch.nn.Module):
|
|
|
|
def __init__(self, embed, config):
|
|
|
|
super().__init__()
|
|
|
|
self.hidden_size = config.hidden_size
|
|
|
|
self.embed = embed
|
|
|
|
if config.model_type == 'gemma2':
|
|
|
|
normalizer = torch.tensor(self.hidden_size**0.5)
|
|
|
|
self.embed.weight.data *= normalizer
|
|
|
|
|
|
|
|
def forward(self, input_ids):
|
|
|
|
inputs_embeds = self.embed(input_ids).view(-1, 1, self.hidden_size)
|
|
|
|
return inputs_embeds
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
|
|
if n_rep == 1:
|
|
|
|
return hidden_states
|
|
|
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
|
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
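# Minimal sketch (not used by the exporter): demonstrates the grouped-query head
# expansion done by repeat_kv above. The helper name and shapes are illustrative
# assumptions, not part of the original tool.
def _demo_repeat_kv():
    kv = torch.zeros(1, 2, 5, 64)           # [batch, num_kv_heads, seq_len, head_dim]
    expanded = repeat_kv(kv, n_rep=4)        # each kv head is replicated 4 times
    assert expanded.shape == (1, 8, 5, 64)   # now matches 8 query heads
    return expanded.shape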
|
|
|
|
|
|
|
|
class Attention(torch.nn.Module):
|
2024-11-18 14:37:45 +08:00
|
|
|
def __init__(self, attn, layer_id, config):
|
2024-09-12 12:57:57 +08:00
|
|
|
super().__init__()
|
2024-11-18 14:37:45 +08:00
|
|
|
self.export_fused_attn = False
|
|
|
|
self.fused_attn = FusedAttention(config.hidden_size, f'/layers.{layer_id}/self_attn/FusedAttention')
|
|
|
|
self.layer_id = layer_id
|
2024-09-12 12:57:57 +08:00
|
|
|
self.hidden_size = config.hidden_size
|
|
|
|
self.head_dim = config.head_dim
|
2024-11-18 14:37:45 +08:00
|
|
|
if isinstance(config.num_attention_heads, list):
|
|
|
|
self.num_heads = config.num_attention_heads[layer_id]
|
|
|
|
self.num_key_value_heads = config.num_key_value_heads[layer_id]
|
|
|
|
else:
|
|
|
|
self.head_dim = config.head_dim
|
|
|
|
self.num_heads = config.num_attention_heads
|
|
|
|
self.num_key_value_heads = config.num_key_value_heads
|
2024-09-12 12:57:57 +08:00
|
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
|
|
|
self.rotary = config.rotary
|
2024-11-18 14:37:45 +08:00
|
|
|
|
2024-09-12 12:57:57 +08:00
|
|
|
ModelMapper.do_map(self, attn, config.model_map['attention'])
|
2024-11-18 14:37:45 +08:00
|
|
|
|
2024-09-12 12:57:57 +08:00
|
|
|
if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
|
|
|
|
# split qkv linear to q, k, v
|
|
|
|
split_sizes = [self.hidden_size] * 3
|
|
|
|
if self.qkv_proj.weight.shape[0] != self.hidden_size * 3:
|
|
|
|
# MQA/GQA: the k/v projection outputs are smaller than the q projection
|
2024-11-18 14:37:45 +08:00
|
|
|
split_sizes = [
|
|
|
|
self.num_heads * self.head_dim, # q_size
|
|
|
|
self.num_key_value_heads * self.head_dim, # k_size
|
|
|
|
self.num_key_value_heads * self.head_dim # v_size
|
|
|
|
]
|
|
|
|
|
2024-09-12 12:57:57 +08:00
|
|
|
self.q_proj = torch.nn.Linear(self.hidden_size, split_sizes[0])
|
|
|
|
self.k_proj = torch.nn.Linear(self.hidden_size, split_sizes[1])
|
|
|
|
self.v_proj = torch.nn.Linear(self.hidden_size, split_sizes[2])
|
|
|
|
if config.model_type == 'chatglm':
|
|
|
|
# chatglm-6b
|
|
|
|
qkv_weight = self.qkv_proj.weight.data.view(self.num_heads, 3, self.head_dim, self.hidden_size)
|
|
|
|
self.q_proj.weight.data = qkv_weight[:, 0, :, :].reshape(self.hidden_size, self.hidden_size)
|
|
|
|
self.k_proj.weight.data = qkv_weight[:, 1, :, :].reshape(self.hidden_size, self.hidden_size)
|
|
|
|
self.v_proj.weight.data = qkv_weight[:, 2, :, :].reshape(self.hidden_size, self.hidden_size)
|
|
|
|
qkv_bias = self.qkv_proj.bias.data.view(self.num_heads, 3, self.head_dim)
|
|
|
|
self.q_proj.bias.data = qkv_bias[:, 0, :].reshape(self.hidden_size)
|
|
|
|
self.k_proj.bias.data = qkv_bias[:, 1, :].reshape(self.hidden_size)
|
|
|
|
self.v_proj.bias.data = qkv_bias[:, 2, :].reshape(self.hidden_size)
|
|
|
|
else:
|
|
|
|
# other
|
|
|
|
qw, kw, vw = torch.split(self.qkv_proj.weight, split_sizes)
|
|
|
|
self.q_proj.weight.data = qw
|
|
|
|
self.k_proj.weight.data = kw
|
|
|
|
self.v_proj.weight.data = vw
|
|
|
|
if self.qkv_proj.bias is not None:
|
|
|
|
qb, kb, vb = torch.split(self.qkv_proj.bias, split_sizes)
|
|
|
|
self.q_proj.bias.data = qb
|
|
|
|
self.k_proj.bias.data = kb
|
|
|
|
self.v_proj.bias.data = vb
|
2024-11-18 14:37:45 +08:00
|
|
|
else:
|
|
|
|
self.q_proj.bias.data = torch.zeros(split_sizes[0])
|
|
|
|
self.k_proj.bias.data = torch.zeros(split_sizes[1])
|
|
|
|
self.v_proj.bias.data = torch.zeros(split_sizes[2])
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
def forward(
|
|
|
|
self,
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
|
|
rotary_pos_emb: Optional[torch.Tensor] = None,
|
2024-11-18 14:37:45 +08:00
|
|
|
cross_attention_states: Optional[torch.Tensor] = None,
|
2024-09-12 12:57:57 +08:00
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
bsz, q_len, _ = hidden_states.size()
|
|
|
|
query_states = self.q_proj(hidden_states)
|
2024-11-18 14:37:45 +08:00
|
|
|
if cross_attention_states is not None:
|
|
|
|
hidden_states = cross_attention_states
|
2024-09-12 12:57:57 +08:00
|
|
|
key_states = self.k_proj(hidden_states)
|
|
|
|
value_states = self.v_proj(hidden_states)
|
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
|
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
2024-11-18 14:37:45 +08:00
|
|
|
# openelm model has qk_norm
|
|
|
|
if hasattr(self, 'q_norm') and self.q_norm is not None and \
|
|
|
|
hasattr(self, 'k_norm') and self.k_norm is not None :
|
|
|
|
query_states = self.q_norm(query_states)
|
|
|
|
key_states = self.k_norm(key_states)
|
|
|
|
|
2024-09-12 12:57:57 +08:00
|
|
|
kv_seq_len = key_states.shape[1]
|
|
|
|
if past_key_value is not None:
|
|
|
|
kv_seq_len += past_key_value[0].shape[1]
|
|
|
|
|
|
|
|
# rope
|
|
|
|
cos, sin = rotary_pos_emb[0], rotary_pos_emb[1]
|
|
|
|
query_states = self.rotary.apply_rotary_pos(query_states, cos, sin)
|
|
|
|
key_states = self.rotary.apply_rotary_pos(key_states, cos, sin)
|
2024-11-18 14:37:45 +08:00
|
|
|
|
|
|
|
if self.export_fused_attn:
|
|
|
|
attn_output = self.fused_attn(query_states, key_states, value_states, attention_mask)
|
|
|
|
attn_output = self.o_proj(attn_output)
|
|
|
|
return attn_output, past_key_value
|
|
|
|
|
2024-09-12 12:57:57 +08:00
|
|
|
# kv cache
|
|
|
|
if past_key_value is not None:
|
|
|
|
past_key, past_value = past_key_value[0], past_key_value[1]
|
|
|
|
key_states = torch.cat((past_key, key_states), dim=1)
|
|
|
|
value_states = torch.cat((past_value, value_states), dim=1)
|
|
|
|
|
|
|
|
past_key_value = torch.stack((key_states, value_states))
|
|
|
|
query_states = query_states.transpose(1, 2)
|
|
|
|
key_states = key_states.permute([0, 2, 3, 1])
|
|
|
|
value_states = value_states.transpose(1, 2)
|
|
|
|
# repeat k/v heads if n_kv_heads < n_heads
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
|
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
|
|
#------- attention ----------
|
|
|
|
# query_states @ key_states
|
|
|
|
attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
|
|
|
|
# attention_mask
|
|
|
|
if attention_mask.dtype in (torch.bool, torch.int32):
|
|
|
|
# chatglm
|
|
|
|
attn_weights.masked_fill_(attention_mask, -10000.0)
|
|
|
|
else:
|
|
|
|
attn_weights = attn_weights + attention_mask
|
|
|
|
# upcast softmax to fp32
|
|
|
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
|
|
# attn_weights @ value_states
|
|
|
|
attn_output = torch.matmul(attn_weights, value_states)
|
|
|
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
|
|
|
attn_output = self.o_proj(attn_output)
|
|
|
|
return attn_output, past_key_value
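# Note on the cache layout produced above: key/value states are kept as
# [batch, kv_seq_len, num_kv_heads, head_dim] and stacked into past_key_value of
# shape [2, batch, kv_seq_len, num_kv_heads, head_dim], which matches the
# `key_value_shape` advertised in llm_config later on.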
|
|
|
|
|
|
|
|
def rotate_half(x):
|
|
|
|
x1 = x[..., : x.shape[-1] // 2]
|
|
|
|
x2 = x[..., x.shape[-1] // 2 :]
|
|
|
|
return torch.cat((-x2, x1), dim=-1)
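# Minimal sketch (illustrative only): rotate_half swaps the two halves of the last
# dimension and negates the second half, e.g. [a, b, c, d] -> [-c, -d, a, b].
def _demo_rotate_half():
    x = torch.tensor([1., 2., 3., 4.])
    return rotate_half(x)                    # tensor([-3., -4., 1., 2.])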
|
|
|
|
|
|
|
|
class Rotary(torch.nn.Module):
|
|
|
|
def __init__(self, config):
|
|
|
|
super().__init__()
|
|
|
|
self.rope_theta = config.rope_theta
|
|
|
|
self.rotary_dim = config.head_dim
|
|
|
|
self.model_type = config.model_type
|
|
|
|
if hasattr(config, 'rotary_dim'):
|
|
|
|
self.rotary_dim = config.rotary_dim
|
|
|
|
if self.model_type == 'chatglm':
|
|
|
|
self.rotary_dim = config.head_dim // 2
|
|
|
|
|
|
|
|
def forward(self, position_ids):
|
|
|
|
theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim))
|
|
|
|
position_ids = position_ids.float().reshape(-1, 1)
|
|
|
|
idx_theta = position_ids * theta
|
|
|
|
rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)])
|
|
|
|
if self.model_type != 'chatglm2':
|
|
|
|
rotary_pos_emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
|
|
|
rotary_pos_emb = rotary_pos_emb.unsqueeze(2).unsqueeze(1)
|
|
|
|
return rotary_pos_emb
|
|
|
|
|
|
|
|
def apply_rotary_pos(self, x, cos, sin):
|
|
|
|
if self.model_type == 'chatglm':
|
|
|
|
return self.chatglm_rotary_pos(x, cos, sin)
|
|
|
|
if self.model_type == 'chatglm2':
|
|
|
|
return self.chatglm2_rotary_pos(x, cos, sin)
|
|
|
|
if self.model_type == 'phi-msft':
|
|
|
|
return self.phi_rotary_pos(x, cos, sin)
|
|
|
|
return self.llama_rotary_pos(x, cos, sin)
|
|
|
|
|
|
|
|
def llama_rotary_pos(self, x, cos, sin):
|
|
|
|
x = (x * cos) + (rotate_half(x) * sin)
|
|
|
|
return x
|
|
|
|
|
|
|
|
def phi_rotary_pos(self, x, cos, sin):
|
|
|
|
x, x_pass = x[..., :self.rotary_dim], x[..., self.rotary_dim:]
|
|
|
|
x = (x * cos) + (rotate_half(x) * sin)
|
|
|
|
return torch.cat((x, x_pass), dim=-1)
|
|
|
|
|
|
|
|
def chatglm2_rotary_pos(self, x, cos, sin):
|
|
|
|
x, x_pass = x[..., :self.rotary_dim], x[..., self.rotary_dim:]
|
|
|
|
b, s, n, h = x.shape
|
|
|
|
xshaped = x.view(b, s, n, h//2, 2)
|
|
|
|
x = torch.concat(
|
|
|
|
[
|
|
|
|
xshaped[..., 0] * cos - xshaped[..., 1] * sin,
|
|
|
|
xshaped[..., 1] * cos + xshaped[..., 0] * sin,
|
|
|
|
],
|
|
|
|
-1,
|
|
|
|
)
|
|
|
|
return torch.cat((x, x_pass), dim=-1)
|
|
|
|
|
|
|
|
def chatglm_rotary_pos(self, x, cos, sin):
|
|
|
|
seq = x.shape[1]
|
|
|
|
x1, x2 = x[..., :self.rotary_dim], x[..., self.rotary_dim:]
|
|
|
|
cos1, sin1 = cos[:, :seq, ...], sin[:, :seq, ...]
|
|
|
|
cos2, sin2 = cos[:, seq:, ...], sin[:, seq:, ...]
|
|
|
|
x1 = (x1 * cos1) + (rotate_half(x1) * sin1)
|
|
|
|
x2 = (x2 * cos2) + (rotate_half(x2) * sin2)
|
|
|
|
return torch.cat((x1, x2), dim=-1)
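# Minimal sketch of how Rotary is driven at export time. _DemoRotaryConfig is a
# hypothetical stub holding only the fields Rotary.__init__ reads; it is not one of
# the exporter's real config objects.
class _DemoRotaryConfig:
    rope_theta = 10000.0
    head_dim = 8
    model_type = 'llama'

def _demo_rotary():
    rot = Rotary(_DemoRotaryConfig())
    pos_emb = rot(torch.arange(4))            # shape [2, 1, seq, 1, head_dim]: cos / sin
    cos, sin = pos_emb[0], pos_emb[1]
    q = torch.randn(1, 4, 2, 8)               # [batch, seq, heads, head_dim]
    return rot.apply_rotary_pos(q, cos, sin)  # llama-style rotation, same shape as q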
|
|
|
|
|
|
|
|
class Decoder(torch.nn.Module):
|
2024-11-18 14:37:45 +08:00
|
|
|
def __init__(self, decoder, layer_id, config):
|
2024-09-12 12:57:57 +08:00
|
|
|
super().__init__()
|
2024-11-18 14:37:45 +08:00
|
|
|
self.cross_decoder = False
|
2024-09-12 12:57:57 +08:00
|
|
|
ModelMapper.do_map(self, decoder, config.model_map['decoder'])
|
2024-11-18 14:37:45 +08:00
|
|
|
# mllama has cross_attn
|
|
|
|
if hasattr(self, 'cross_attn') and self.cross_attn is not None:
|
|
|
|
self.cross_decoder = True
|
|
|
|
self.self_attn = Attention(self.cross_attn, layer_id, config)
|
|
|
|
else:
|
|
|
|
self.self_attn = Attention(self.self_attn, layer_id, config)
|
2024-09-12 12:57:57 +08:00
|
|
|
self.hidden_size = config.hidden_size
|
|
|
|
# chatglm
|
|
|
|
self.alpha = (2 * config.num_hidden_layers) ** 0.5 if config.model_type == 'chatglm' else 1.0
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
self,
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
rotary_pos_emb: Optional[torch.Tensor] = None,
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
2024-11-18 14:37:45 +08:00
|
|
|
cross_attention_states: Optional[torch.Tensor] = None,
|
|
|
|
cross_attention_mask: Optional[torch.Tensor] = None,
|
2024-09-12 12:57:57 +08:00
|
|
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
|
|
hidden_states = hidden_states.view(1, -1, self.hidden_size)
|
|
|
|
residual = hidden_states
|
|
|
|
hidden_states = self.input_layernorm(hidden_states)
|
|
|
|
norm_hidden_states = hidden_states
|
|
|
|
# Self Attention
|
|
|
|
hidden_states, present_key_value = self.self_attn(
|
|
|
|
hidden_states=hidden_states,
|
|
|
|
rotary_pos_emb=rotary_pos_emb,
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
past_key_value=past_key_value,
|
2024-11-18 14:37:45 +08:00
|
|
|
cross_attention_states=cross_attention_states,
|
2024-09-12 12:57:57 +08:00
|
|
|
)
|
|
|
|
# Fully Connected
|
|
|
|
if not hasattr(self, 'post_attention_layernorm'):
|
|
|
|
# phi
|
|
|
|
feed_forward_hidden_states = self.mlp(norm_hidden_states)
|
|
|
|
hidden_states = hidden_states + feed_forward_hidden_states + residual
|
|
|
|
elif self.alpha != 1.0:
|
|
|
|
# chatglm-6b
|
|
|
|
hidden_states = norm_hidden_states * self.alpha + hidden_states
|
|
|
|
mlp_input = self.post_attention_layernorm(hidden_states)
|
|
|
|
mlp_output = self.mlp(mlp_input)
|
|
|
|
hidden_states = mlp_input * self.alpha + mlp_output
|
|
|
|
elif hasattr(self, 'pre_feedforward_layernorm'):
|
|
|
|
# gemma2
|
|
|
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
|
hidden_states = residual + hidden_states
|
|
|
|
residual = hidden_states
|
|
|
|
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
|
|
|
hidden_states = self.mlp(hidden_states)
|
|
|
|
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
|
|
|
hidden_states = residual + hidden_states
|
2024-11-18 14:37:45 +08:00
|
|
|
elif cross_attention_mask is not None:
|
|
|
|
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
|
|
|
|
residual = hidden_states
|
|
|
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
|
hidden_states = self.mlp(hidden_states)
|
|
|
|
hidden_states = cross_attention_mask * hidden_states
|
|
|
|
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
|
2024-09-12 12:57:57 +08:00
|
|
|
else:
|
|
|
|
# general
|
|
|
|
hidden_states = residual + hidden_states
|
|
|
|
residual = hidden_states
|
|
|
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
|
hidden_states = self.mlp(hidden_states)
|
|
|
|
hidden_states = residual + hidden_states
|
|
|
|
|
|
|
|
return hidden_states, present_key_value
|
|
|
|
|
|
|
|
class Lm(torch.nn.Module):
|
|
|
|
def __init__(self, lm_, final_layernorm_, config):
|
|
|
|
super().__init__()
|
|
|
|
self.final_layernorm = final_layernorm_
|
|
|
|
self.lm = lm_
|
|
|
|
self.hidden_size = config.hidden_size
|
2024-11-18 14:37:45 +08:00
|
|
|
self.ppl = config.ppl
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
def forward(self, hidden_states):
|
2024-11-18 14:37:45 +08:00
|
|
|
if not self.ppl:
|
|
|
|
# only the last logit is needed to predict the next token
|
|
|
|
hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
|
2024-09-12 12:57:57 +08:00
|
|
|
hidden_states = self.final_layernorm(hidden_states)
|
|
|
|
m_logits = self.lm(hidden_states)
|
|
|
|
return m_logits
|
|
|
|
|
2024-09-12 20:19:02 +08:00
|
|
|
class Visual(torch.nn.Module):
|
|
|
|
def __init__(self, visual, base):
|
|
|
|
super().__init__()
|
|
|
|
self.visual = visual.eval()
|
|
|
|
self.embed_ = base.embed
|
|
|
|
self.tokenizer = base.tokenizer
|
|
|
|
self.config = base.config
|
|
|
|
self.hidden_size = base.hidden_size
|
|
|
|
self.llm_config = base.llm_config
|
2024-11-18 14:37:45 +08:00
|
|
|
# mllama
|
|
|
|
self.cross_attention_states = None
|
|
|
|
self.cross_attention_mask = None
|
2024-09-12 20:19:02 +08:00
|
|
|
self.init_config()
|
|
|
|
self.load()
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def get_visual(model_type):
|
|
|
|
visual_models = {
|
|
|
|
'qwen': QwenVisual,
|
2024-11-18 14:37:45 +08:00
|
|
|
'qwen2_vl': Qwen2Visual,
|
|
|
|
'mllama': MllamaVision
|
2024-09-12 20:19:02 +08:00
|
|
|
}
|
|
|
|
if model_type in visual_models:
|
|
|
|
return visual_models[model_type]
|
|
|
|
return None
|
|
|
|
|
|
|
|
def init_config(self):
|
|
|
|
from transformers.image_utils import (OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)
|
|
|
|
self.llm_config['is_visual'] = True
|
|
|
|
image_mean = np.array(OPENAI_CLIP_MEAN) * 255.0
|
|
|
|
image_norm = 1 / (np.array(OPENAI_CLIP_STD) * 255.0)
|
|
|
|
self.llm_config['image_mean'] = image_mean.tolist()
|
|
|
|
self.llm_config['image_norm'] = image_norm.tolist()
|
|
|
|
|
|
|
|
def load(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def str_to_ids(self, prompt):
|
|
|
|
input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
|
|
|
|
return input_ids
|
|
|
|
|
|
|
|
def forward(self, images):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def embed(self, input_ids, images = None, videos = None):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
class QwenVisual(Visual):
|
|
|
|
def __init__(self, visual, base):
|
|
|
|
self.quant_bit = 16
|
|
|
|
super().__init__(visual, base)
|
|
|
|
|
|
|
|
def load(self):
|
|
|
|
self.image_start_id = self.config.visual['image_start_id']
|
|
|
|
self.image_size = self.config.visual['image_size']
|
|
|
|
self.llm_config['is_visual'] = True
|
|
|
|
self.llm_config['image_size'] = self.image_size
|
|
|
|
self.llm_config['vision_start'] = self.tokenizer.img_start_id
|
|
|
|
self.llm_config['vision_end'] = self.tokenizer.img_end_id
|
|
|
|
self.llm_config['image_pad'] = self.tokenizer.img_pad_id
|
|
|
|
|
|
|
|
def forward(self, images):
|
|
|
|
return self.visual(images).transpose(1, 0)
|
|
|
|
|
|
|
|
def embed(self, input_ids, images = None, videos = None):
|
|
|
|
if not torch.any(input_ids == self.image_start_id):
|
|
|
|
return self.embed_(input_ids)
|
|
|
|
bos_pos = torch.where(input_ids == self.image_start_id)
|
|
|
|
eos_pos = torch.where(input_ids == self.image_start_id + 1)
|
|
|
|
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
|
|
|
|
images = []
|
|
|
|
for i, a, b in img_pos:
|
|
|
|
image = input_ids[i][a + 1 : b - 1].tolist()
|
|
|
|
image = image[ : image.index(self.image_start_id + 2)]
|
|
|
|
images.append(bytes(image).decode('utf-8'))
|
|
|
|
images = self.visual.encode(images).transpose(1, 0)
|
|
|
|
hidden_states = self.embed_(input_ids)
|
|
|
|
for idx, (i, a, b) in enumerate(img_pos):
|
|
|
|
hidden_states[a + 1 : b, i] = images[:, idx]
|
|
|
|
return hidden_states
|
|
|
|
|
|
|
|
class Qwen2Visual(Visual):
|
|
|
|
def __init__(self, visual, base):
|
|
|
|
self.quant_bit = 4
|
|
|
|
self.temporal_patch_size = 2
|
|
|
|
self.patch_size = 14
|
|
|
|
self.merge_size = 2
|
|
|
|
self.image_size = 420
|
|
|
|
self.image_embeds = None
|
|
|
|
super().__init__(visual, base)
|
|
|
|
|
|
|
|
def load(self):
|
|
|
|
self.vision_start_id = self.config.vision_start_token_id
|
|
|
|
self.vision_end_id = self.config.vision_end_token_id
|
|
|
|
self.image_pad_id = self.config.image_token_id
|
|
|
|
self.llm_config['image_size'] = self.image_size
|
|
|
|
self.llm_config['vision_start'] = self.vision_start_id
|
|
|
|
self.llm_config['vision_end'] = self.vision_end_id
|
|
|
|
self.llm_config['image_pad'] = self.image_pad_id
|
|
|
|
|
|
|
|
def str_to_ids(self, prompt):
|
|
|
|
if '<img>' in prompt and '</img>' in prompt:
|
|
|
|
import re
|
|
|
|
import requests
|
|
|
|
from PIL import Image
|
|
|
|
pattern = r'(<img>.*?</img>)'
|
|
|
|
parts = re.split(pattern, prompt)
|
|
|
|
txt_prompt = ''
|
|
|
|
for part in parts:
|
|
|
|
if re.match(pattern, part):
|
|
|
|
img_content = re.search(r'<img>(.*?)</img>', part).group(1)
|
|
|
|
if img_content.startswith('http://') or img_content.startswith('https://'):
|
|
|
|
image_obj = Image.open(requests.get(img_content, stream=True).raw)
|
|
|
|
img_pad_len = self.img_process(image_obj)
|
|
|
|
img_pad_str = '<|image_pad|>' * img_pad_len
|
|
|
|
img_str = f'<|vision_start|>{img_pad_str}<|vision_end|>'
|
|
|
|
txt_prompt += img_str
|
|
|
|
else:
|
|
|
|
txt_prompt += part
|
|
|
|
else:
|
|
|
|
txt_prompt = prompt
|
|
|
|
input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
|
|
|
|
return input_ids
|
|
|
|
|
|
|
|
def forward(self, images):
|
|
|
|
images = [images] * self.temporal_patch_size
|
|
|
|
patches = torch.concat(images, axis=0)
|
|
|
|
_, channel, height, width = patches.shape
|
|
|
|
grid_t = patches.shape[0] // self.temporal_patch_size
|
|
|
|
grid_h, grid_w = height // self.patch_size, width // self.patch_size
|
|
|
|
patches = patches.reshape(
|
|
|
|
grid_t,
|
|
|
|
self.temporal_patch_size,
|
|
|
|
channel,
|
|
|
|
grid_h // self.merge_size,
|
|
|
|
self.merge_size,
|
|
|
|
self.patch_size,
|
|
|
|
grid_w // self.merge_size,
|
|
|
|
self.merge_size,
|
|
|
|
self.patch_size,
|
|
|
|
)
|
|
|
|
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
|
|
|
flatten_patches = patches.reshape(
|
|
|
|
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
|
|
|
)
|
|
|
|
image_grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
|
|
|
|
image_embeds = self.visual(flatten_patches, image_grid_thw)
|
|
|
|
image_embeds = image_embeds.unsqueeze(1)
|
|
|
|
return image_embeds
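# Shape walk-through for the defaults above (image_size=420, patch_size=14,
# merge_size=2, temporal_patch_size=2), illustrative only:
#   grid_t = 1, grid_h = grid_w = 420 / 14 = 30
#   flatten_patches: [1 * 30 * 30, 3 * 2 * 14 * 14] = [900, 1176]
#   the vision tower then merges patches spatially, so image_embeds carries one
#   embedding per merged patch (900 / merge_size**2 = 225 under this assumption).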
|
|
|
|
|
|
|
|
def img_process(self, image):
|
|
|
|
resized_height = self.image_size
|
|
|
|
resized_width = self.image_size
|
|
|
|
from transformers.image_transforms import (
|
|
|
|
convert_to_rgb,
|
|
|
|
resize,
|
|
|
|
rescale,
|
|
|
|
normalize
|
|
|
|
)
|
|
|
|
from transformers.image_utils import (
|
|
|
|
OPENAI_CLIP_MEAN,
|
|
|
|
OPENAI_CLIP_STD,
|
|
|
|
PILImageResampling,
|
|
|
|
infer_channel_dimension_format,
|
|
|
|
to_numpy_array
|
|
|
|
)
|
|
|
|
image = convert_to_rgb(image)
|
|
|
|
image = to_numpy_array(image)
|
|
|
|
format = infer_channel_dimension_format(image)
|
|
|
|
resample = PILImageResampling.BICUBIC
|
|
|
|
image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=format)
|
|
|
|
image = rescale(image, scale=1 / 255.0, input_data_format=format)
|
|
|
|
image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=format)
|
|
|
|
image = np.expand_dims(image, [0])
|
|
|
|
image = image.transpose(0, 3, 1, 2)
|
|
|
|
image = torch.from_numpy(image)
|
|
|
|
self.image_embeds = self.forward(image)
|
|
|
|
return self.image_embeds.shape[0]
|
|
|
|
|
|
|
|
def embed(self, input_ids, images = None, videos = None):
|
|
|
|
input_embeds = self.embed_(input_ids)
|
|
|
|
if self.image_embeds is not None:
|
|
|
|
image_mask = (input_ids == self.image_pad_id).squeeze()
|
|
|
|
input_embeds[image_mask] = self.image_embeds
|
|
|
|
return input_embeds
|
|
|
|
|
2024-11-18 14:37:45 +08:00
|
|
|
class MllamaVision(Visual):
|
|
|
|
def __init__(self, visual, base):
|
|
|
|
super().__init__(visual, base)
|
|
|
|
self.image_objs = []
|
|
|
|
|
|
|
|
def load(self):
|
|
|
|
self.llm_config['is_visual'] = True
|
|
|
|
self.llm_config['image_size'] = self.config.vision_config.image_size
|
|
|
|
self.image_size = self.config.vision_config.image_size
|
|
|
|
|
|
|
|
def str_to_ids(self, prompt):
|
|
|
|
if '<img>' in prompt and '</img>' in prompt:
|
|
|
|
import re
|
|
|
|
import requests
|
|
|
|
from PIL import Image
|
|
|
|
pattern = r'(<img>.*?</img>)'
|
|
|
|
parts = re.split(pattern, prompt)
|
|
|
|
txt_prompt = ''
|
|
|
|
for part in parts:
|
|
|
|
if re.match(pattern, part):
|
|
|
|
img_content = re.search(r'<img>(.*?)</img>', part).group(1)
|
|
|
|
if img_content.startswith('http://') or img_content.startswith('https://'):
|
|
|
|
self.image_objs.append(Image.open(requests.get(img_content, stream=True).raw))
|
|
|
|
txt_prompt += '<|image|>'
|
|
|
|
else:
|
|
|
|
txt_prompt += part
|
|
|
|
else:
|
|
|
|
txt_prompt = prompt
|
|
|
|
input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
|
|
|
|
# image process
|
|
|
|
for img in self.image_objs:
|
|
|
|
image_embeds = self.img_process(img)
|
|
|
|
print(image_embeds.shape)
|
|
|
|
pass
|
|
|
|
return input_ids
|
|
|
|
|
|
|
|
def img_process(self, image):
|
|
|
|
resized_height = self.image_size
|
|
|
|
resized_width = self.image_size
|
|
|
|
from transformers.image_transforms import (
|
|
|
|
convert_to_rgb,
|
|
|
|
resize,
|
|
|
|
rescale,
|
|
|
|
normalize
|
|
|
|
)
|
|
|
|
from transformers.image_utils import (
|
|
|
|
OPENAI_CLIP_MEAN,
|
|
|
|
OPENAI_CLIP_STD,
|
|
|
|
PILImageResampling,
|
|
|
|
infer_channel_dimension_format,
|
|
|
|
to_numpy_array
|
|
|
|
)
|
|
|
|
image = convert_to_rgb(image)
|
|
|
|
image = to_numpy_array(image)
|
|
|
|
format = infer_channel_dimension_format(image)
|
|
|
|
resample = PILImageResampling.BICUBIC
|
|
|
|
image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=format)
|
|
|
|
image = rescale(image, scale=1 / 255.0, input_data_format=format)
|
|
|
|
image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=format)
|
|
|
|
image = image.transpose(2, 0, 1)
|
|
|
|
image = np.expand_dims(image, [0, 1, 2])
|
|
|
|
pad_val = np.zeros_like(image)
|
|
|
|
image = np.concatenate([image, pad_val, pad_val, pad_val], axis=2)
|
|
|
|
print(image.shape)
|
|
|
|
image = torch.from_numpy(image)
|
|
|
|
image_embeds = self.forward(image)
|
|
|
|
print(image_embeds.shape)
|
|
|
|
return image_embeds
|
|
|
|
|
|
|
|
def forward(self, images):
|
|
|
|
aspect_ratio_ids = torch.tensor([[1]])
|
|
|
|
aspect_ratio_mask = torch.tensor([[[1, 0, 0, 0]]])
|
|
|
|
return self.visual(images, aspect_ratio_ids, aspect_ratio_mask)
|
|
|
|
|
|
|
|
def embed(self, input_ids, images = None, videos = None):
|
|
|
|
return self.embed_(input_ids)
|
|
|
|
|
2024-09-12 12:57:57 +08:00
|
|
|
class LlmExporter(torch.nn.Module):
|
|
|
|
'''
|
|
|
|
Base class for all LLM model exporters. Inherits from [`torch.nn.Module`].
|
|
|
|
'''
|
|
|
|
|
|
|
|
def __init__(self, args):
|
|
|
|
super().__init__()
|
|
|
|
self.init_from_args(args)
|
|
|
|
self.load_model(args.path)
|
|
|
|
|
|
|
|
def init_from_args(self, args):
|
2024-11-18 14:37:45 +08:00
|
|
|
self.max_length = 128
|
2024-09-12 12:57:57 +08:00
|
|
|
self.stop_ids = []
|
|
|
|
self.visual = None
|
|
|
|
self.dst_name = 'llm'
|
|
|
|
# load config from args
|
|
|
|
self.path = args.path
|
|
|
|
self.dst_path = args.dst_path
|
2024-09-12 20:19:02 +08:00
|
|
|
self.onnx_path = os.path.join(self.dst_path, 'onnx')
|
2024-11-18 14:37:45 +08:00
|
|
|
self.tokenizer_path = args.tokenizer_path
|
2024-09-12 12:57:57 +08:00
|
|
|
self.lora_path = args.lora_path
|
2024-11-18 14:37:45 +08:00
|
|
|
self.onnx_slim = args.onnx_slim
|
|
|
|
self.ppl = args.ppl
|
|
|
|
self.awq = args.awq
|
2024-09-12 12:57:57 +08:00
|
|
|
self.quant_bit = args.quant_bit
|
|
|
|
self.quant_block = args.quant_block
|
2024-11-18 14:37:45 +08:00
|
|
|
self.symmetric = args.sym
|
2024-09-12 12:57:57 +08:00
|
|
|
self.mnnconvert = args.mnnconvert
|
2024-11-18 14:37:45 +08:00
|
|
|
if self.tokenizer_path is None:
|
|
|
|
self.tokenizer_path = self.path
|
2024-09-12 12:57:57 +08:00
|
|
|
if args.lm_quant_bit is not None:
|
|
|
|
self.lm_quant_bit = args.lm_quant_bit
|
|
|
|
else:
|
|
|
|
self.lm_quant_bit = self.quant_bit
|
|
|
|
# init export dst dir
|
|
|
|
if not os.path.exists(self.dst_path):
|
|
|
|
os.makedirs(self.dst_path)
|
2024-09-12 20:19:02 +08:00
|
|
|
if not os.path.exists(self.onnx_path):
|
|
|
|
os.makedirs(self.onnx_path)
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
def load_pretrained(self, model_path: str):
|
2024-11-18 14:37:45 +08:00
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True, use_fast=False)
|
2024-09-12 20:19:02 +08:00
|
|
|
if 'Qwen2-VL' in model_path:
|
|
|
|
from transformers import Qwen2VLForConditionalGeneration
|
2024-12-02 10:12:08 +08:00
|
|
|
self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
|
2024-11-18 14:37:45 +08:00
|
|
|
elif 'Llama-3.2' in model_path and 'Vision' in model_path:
|
|
|
|
from transformers import MllamaForConditionalGeneration
|
2024-12-02 10:12:08 +08:00
|
|
|
self.model = MllamaForConditionalGeneration.from_pretrained(model_path, torch_dtype='auto').eval()
|
2024-09-12 20:19:02 +08:00
|
|
|
else:
|
|
|
|
try:
|
2024-12-02 10:12:08 +08:00
|
|
|
self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
|
2024-09-12 20:19:02 +08:00
|
|
|
except:
|
2024-12-02 10:12:08 +08:00
|
|
|
self.model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval()
|
2024-09-12 12:57:57 +08:00
|
|
|
self.config = self.model.config
|
|
|
|
if self.lora_path is not None:
|
|
|
|
from peft import PeftModel
|
|
|
|
adapter = PeftModel.from_pretrained(self.model, model_id=self.lora_path)
|
|
|
|
self.model = adapter.merge_and_unload(progressbar=True)
|
|
|
|
|
2024-09-12 20:19:02 +08:00
|
|
|
@staticmethod
|
|
|
|
def has_attr(obj, attr):
|
|
|
|
return hasattr(obj, attr) and getattr(obj, attr) is not None
|
|
|
|
|
2024-12-02 10:12:08 +08:00
|
|
|
@spinner_run(f'load pretrained model ', True)
|
2024-09-12 12:57:57 +08:00
|
|
|
def load_model(self, model_path):
|
|
|
|
self.load_pretrained(model_path)
|
|
|
|
self.attention_mask_type = 'float'
|
|
|
|
# load tokenizer info
|
|
|
|
self.stop_ids.append(self.tokenizer.eos_token_id)
|
|
|
|
if hasattr(self.tokenizer, 'im_end_id'):
|
|
|
|
self.stop_ids.append(self.tokenizer.im_end_id)
|
|
|
|
eot_id = self.tokenizer.encode('<|eot_id|>')
|
|
|
|
if len(eot_id) == 1:
|
|
|
|
self.stop_ids.append(eot_id[0])
|
2024-09-12 20:19:02 +08:00
|
|
|
if hasattr(self.model, 'generation_config') and self.model.generation_config is not None:
|
2024-09-12 12:57:57 +08:00
|
|
|
eos_token_id = self.model.generation_config.eos_token_id
|
|
|
|
from collections.abc import Iterable
|
|
|
|
if isinstance(eos_token_id, int):
|
|
|
|
self.stop_ids.append(eos_token_id)
|
|
|
|
elif isinstance(eos_token_id, Iterable):
|
|
|
|
for id in eos_token_id:
|
|
|
|
self.stop_ids.append(id)
|
|
|
|
self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None]
|
|
|
|
self.stop_ids = list(set(self.stop_ids))
|
|
|
|
model_mapper = ModelMapper()
|
|
|
|
|
2024-11-18 14:37:45 +08:00
|
|
|
self.tie_word_embeddings = (hasattr(self.config, 'tie_word_embeddings') and self.config.tie_word_embeddings)
|
2024-09-12 12:57:57 +08:00
|
|
|
self.model_type, self.model_map = model_mapper.get_map(self.config)
|
2024-12-02 10:12:08 +08:00
|
|
|
if self.awq:
|
|
|
|
self.model.float()
|
|
|
|
else:
|
|
|
|
# set norm's weight as float for export
|
|
|
|
def visit_module(module):
|
|
|
|
if not isinstance(module, torch.nn.Linear) and hasattr(module, 'weight'):
|
|
|
|
module.float()
|
|
|
|
for name, child in module.named_children():
|
|
|
|
visit_module(child)
|
|
|
|
visit_module(self.model)
|
2024-09-12 20:19:02 +08:00
|
|
|
# print(self.config, self.model_type, self.model_map, self.model)
|
2024-09-12 12:57:57 +08:00
|
|
|
# load config info
|
|
|
|
ModelMapper.do_map(self, self.config, self.model_map['config'])
|
|
|
|
if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None:
|
|
|
|
self.num_key_value_heads = self.num_attention_heads
|
|
|
|
if not hasattr(self, 'rope_theta') or self.rope_theta is None:
|
|
|
|
self.rope_theta = 10000.0
|
|
|
|
if not hasattr(self, 'head_dim') or self.head_dim is None:
|
2024-11-18 14:37:45 +08:00
|
|
|
if isinstance(self.num_attention_heads, list):
|
|
|
|
self.head_dim = [self.hidden_size // atten_head for atten_head in self.num_attention_heads]
|
|
|
|
else:
|
|
|
|
self.head_dim = self.hidden_size // self.num_attention_heads
|
2024-09-12 12:57:57 +08:00
|
|
|
# some export info
|
2024-11-18 14:37:45 +08:00
|
|
|
if isinstance(self.num_attention_heads, list):
|
|
|
|
self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads[0], self.head_dim]
|
|
|
|
else:
|
|
|
|
self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim]
|
2024-09-12 12:57:57 +08:00
|
|
|
self.block_dynamic_axes = {
|
|
|
|
"inputs_embeds" : { 0: "seq_len" },
|
|
|
|
"attention_mask" : { 2: "seq_len", 3: "seq_len" },
|
|
|
|
"position_ids" : { 0: "seq_len" },
|
|
|
|
"past_key_values" : { 1: "history_len" }
|
|
|
|
}
|
|
|
|
self.model_dynamic_axes = {
|
|
|
|
"input_ids" : { 0: "seq_len" },
|
|
|
|
"attention_mask" : { 2: "seq_len", 3: "seq_len" },
|
2024-11-18 14:37:45 +08:00
|
|
|
"position_ids" : { 1: "seq_len" },
|
|
|
|
"past_key_values" : { 3: "history_len" }
|
2024-09-12 12:57:57 +08:00
|
|
|
}
|
2024-10-29 19:32:47 +08:00
|
|
|
prompt_template = self.build_prompt_template()
|
2024-09-12 12:57:57 +08:00
|
|
|
self.llm_config = {
|
|
|
|
'hidden_size' : self.hidden_size,
|
|
|
|
'layer_nums' : self.num_hidden_layers,
|
|
|
|
'attention_mask': self.attention_mask_type,
|
|
|
|
'key_value_shape': self.past_kv_shape[1:],
|
2024-10-29 19:32:47 +08:00
|
|
|
"system_prompt_template": prompt_template['system'].format(query='%s'),
|
|
|
|
'user_prompt_template': prompt_template['user'].format(query='%s'),
|
|
|
|
'assistant_prefix': prompt_template['assistant_prefix'],
|
|
|
|
'assistant_suffix': prompt_template['assistant_suffix'],
|
2024-09-12 12:57:57 +08:00
|
|
|
'is_visual': False
|
|
|
|
}
|
|
|
|
# load modules
|
|
|
|
ModelMapper.do_map(self, self.model, self.model_map['model'])
|
|
|
|
# rebuild modules
|
2024-11-18 14:37:45 +08:00
|
|
|
if self.lm_ is None:
|
|
|
|
out_features, in_features = self.embed_.weight.shape
|
|
|
|
self.lm_ = torch.nn.Linear(in_features, out_features)
|
|
|
|
self.lm_.weight = self.embed_.weight
|
|
|
|
|
2024-09-12 12:57:57 +08:00
|
|
|
if self.embed_.weight is self.lm_.weight:
|
|
|
|
import copy
|
|
|
|
embed_copy = copy.deepcopy(self.embed_)
|
|
|
|
self.embed = Embedding(embed_copy, self)
|
|
|
|
else:
|
|
|
|
self.embed = Embedding(self.embed_, self)
|
|
|
|
# Rotary
|
|
|
|
self.rotary = Rotary(self)
|
|
|
|
self.blocks = []
|
|
|
|
for block in self.blocks_.children():
|
2024-11-18 14:37:45 +08:00
|
|
|
layer_id = len(self.blocks)
|
|
|
|
self.blocks.append(Decoder(block, layer_id, self))
|
2024-09-12 12:57:57 +08:00
|
|
|
self.lm = Lm(self.lm_, self.final_layernorm_, self)
|
|
|
|
# visual model
|
|
|
|
if self.visual is not None:
|
2024-12-02 10:12:08 +08:00
|
|
|
self.visual.float()
|
2024-09-12 20:19:02 +08:00
|
|
|
self.visual = Visual.get_visual(self.model_type)(self.visual, self)
|
2024-09-12 12:57:57 +08:00
|
|
|
return model_path
|
|
|
|
|
|
|
|
def get_attention_mask(self) -> torch.Tensor:
|
|
|
|
if self.model_type == 'chatglm':
|
|
|
|
return self.chatglm_attention_mask()
|
|
|
|
if self.token_len:
|
|
|
|
return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
|
|
|
|
return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
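# Example (float mask path): during prefill with seq_len = 4 this returns an additive
# causal mask of shape [1, 1, 4, 4] whose upper triangle is torch.finfo(float32).min;
# during decode (token_len > 0) it returns zeros of shape [1, 1, 1, seq_len] so the
# new token can attend to the whole cache.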
|
|
|
|
|
|
|
|
def get_position_ids(self) -> torch.Tensor:
|
|
|
|
if self.model_type == 'chatglm':
|
|
|
|
return self.chatglm_position_ids()
|
|
|
|
if self.token_len:
|
2024-11-18 14:37:45 +08:00
|
|
|
return torch.tensor([[self.seq_len - 1]], dtype=torch.int)
|
|
|
|
return torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0)
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
def chatglm_attention_mask(self):
|
|
|
|
if self.token_len:
|
|
|
|
return torch.zeros([1]).bool().reshape([1, 1, 1, 1])
|
|
|
|
attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool)
|
|
|
|
for i in range(self.seq_len - 1):
|
|
|
|
attention_mask[i][-1] = True
|
|
|
|
attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])
|
|
|
|
return attention_mask
|
|
|
|
|
|
|
|
def chatglm_position_ids(self):
|
|
|
|
if self.token_len:
|
|
|
|
return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1])
|
2024-11-18 14:37:45 +08:00
|
|
|
position_ids_0 = torch.arange(self.seq_len, dtype=torch.int)
|
|
|
|
position_ids_1 = torch.zeros(self.seq_len, dtype=torch.int)
|
2024-09-12 12:57:57 +08:00
|
|
|
position_ids_0[-1] = position_ids_0[-2]
|
|
|
|
position_ids_1[-1] = 1
|
|
|
|
position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
|
|
|
|
return position_ids
|
|
|
|
|
|
|
|
def visual_embed(self, input_ids):
|
2024-09-12 20:19:02 +08:00
|
|
|
return self.visual.embed(input_ids)
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
def embedding(self, input_ids):
|
|
|
|
if self.visual is not None and self.token_len == 0:
|
|
|
|
input_embeds = self.visual_embed(input_ids)
|
|
|
|
else:
|
|
|
|
input_embeds = self.embed(input_ids)
|
|
|
|
return input_embeds
|
|
|
|
|
2024-11-18 14:37:45 +08:00
|
|
|
def forward(self,
|
|
|
|
input_ids: torch.Tensor,
|
|
|
|
attention_mask: torch.Tensor,
|
|
|
|
position_ids: torch.Tensor,
|
|
|
|
past_key_values: Optional[list[torch.Tensor]] = None,
|
|
|
|
cross_attention_states: Optional[torch.Tensor] = None,
|
|
|
|
cross_attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
):
|
2024-09-12 12:57:57 +08:00
|
|
|
hidden_states = input_ids # llm forward without embedding
|
2024-11-18 14:37:45 +08:00
|
|
|
presents = [None for i in range(self.num_hidden_layers)]
|
2024-09-12 12:57:57 +08:00
|
|
|
rotary_pos_emb = self.rotary(position_ids)
|
|
|
|
for i in range(self.num_hidden_layers):
|
2024-11-18 14:37:45 +08:00
|
|
|
if self.blocks[i].cross_decoder and cross_attention_states is None:
|
|
|
|
continue
|
2024-09-12 12:57:57 +08:00
|
|
|
hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i])
|
2024-11-18 14:37:45 +08:00
|
|
|
presents[i] = kv
|
|
|
|
logits = self.lm(hidden_states)
|
|
|
|
if not self.ppl:
|
|
|
|
logits = logits.reshape(-1)
|
|
|
|
if None not in presents and presents[0].shape == presents[-1].shape: # check for skipped cross-attn layers before touching .shape
|
|
|
|
presents = torch.stack(presents)
|
2024-09-12 12:57:57 +08:00
|
|
|
self.seq_len += 1
|
|
|
|
self.token_len += 1
|
|
|
|
return logits, presents
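# At export/test time this forward receives embeddings rather than raw ids (see the
# `hidden_states = input_ids` line above): logits is flattened to a 1-D vector of
# vocab logits when ppl is off, and presents is stacked into
# [num_layers, 2, batch, kv_seq_len, num_kv_heads, head_dim] once every layer has
# produced a cache entry of the same shape.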
|
|
|
|
|
|
|
|
# some test functions
|
2024-10-29 19:32:47 +08:00
|
|
|
def build_prompt_template(self) -> Dict[str, str]:
|
|
|
|
template = {
|
|
|
|
'system': '',
|
|
|
|
'user': '',
|
|
|
|
'assistant_prefix': '',
|
|
|
|
'assistant_suffix': '',
|
|
|
|
}
|
2024-09-12 12:57:57 +08:00
|
|
|
# just for test
|
2024-12-02 22:35:36 +08:00
|
|
|
if 'Qwen' in self.path or 'Qwen2' in self.path or 'QwQ' in self.path or 'reader' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['system'] = '<|im_start|>system\n{query}<|im_end|>\n'
|
|
|
|
template['user'] = '<|im_start|>user\n{query}<|im_end|>\n'
|
|
|
|
template['assistant_prefix'] = '<|im_start|>assistant\n'
|
|
|
|
template['assistant_suffix'] = '<|im_end|>\n'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'Baichuan2' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '<reserved_106>{query}<reserved_107>'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'internlm' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '<|User|>:{query}<eoh>\n'
|
|
|
|
template['assistant_prefix'] = '<|Bot|>:'
|
|
|
|
template['assistant_suffix'] = '<eoh>\n'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'TinyLlama' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['system'] = '<s><|system|>\n{query}</s>\n'
|
|
|
|
template['user'] = '<|user|>\n{query}</s>\n'
|
|
|
|
template['assistant_prefix'] = '<|assistant|>\n'
|
|
|
|
template['assistant_suffix'] = '</s>\n'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'Yi' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '<|im_start|> user\n{query}<|im_end|>\n'
|
|
|
|
template['assistant_prefix'] = '<|im_start|> assistant\n'
|
|
|
|
template['assistant_suffix'] = '<|im_end|>\n'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'deepseek' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '<|begin_of_sentence|>User: {query}\n'
|
|
|
|
template['assistant_prefix'] = '\nAssistant: '
|
|
|
|
template['assistant_suffix'] = '\n<|end_of_sentence|>'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'Llama-3.1' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{query}<|eot_id|>'
|
|
|
|
template['user'] = '<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|>'
|
|
|
|
template['assistant_prefix'] = '<|start_header_id|>assistant<|end_header_id|>\n\n'
|
|
|
|
template['assistant_suffix'] = '<|eot_id|>'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'Llama-3' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{query}<|eot_id|>'
|
|
|
|
template['user'] = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|>'
|
|
|
|
template['assistant_prefix'] = '<|start_header_id|>assistant<|end_header_id|>\n\n'
|
|
|
|
template['assistant_suffix'] = '<|eot_id|>'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'Llama-2' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '[INST]{query}[/INST]'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'chatglm2' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '[Round 1]\n\n问:{query}\n\n'
|
|
|
|
template['assistant_prefix'] = '答:'
|
|
|
|
template['assistant_suffix'] = '\n\n'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'chatglm3' in self.path or 'glm-4' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '<|user|>\n{query}\n'
|
|
|
|
template['assistant_prefix'] = '<|assistant|>\n'
|
|
|
|
template['assistant_suffix'] = '\n'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'chatglm' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = '{query}[gMASK]<sop>'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'phi-2' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['user'] = 'Instruct: {query}\n'
|
|
|
|
template['assistant_prefix'] = 'Output:'
|
|
|
|
template['assistant_suffix'] = '\n'
|
|
|
|
return template
|
2024-09-12 12:57:57 +08:00
|
|
|
if 'gemma-2' in self.path:
|
2024-10-29 19:32:47 +08:00
|
|
|
template['system'] = '<start_of_turn>system\n{query}<end_of_turn>\n'
|
|
|
|
template['user'] = '<bos><start_of_turn>user\n{query}<end_of_turn>\n'
|
|
|
|
template['assistant_prefix'] = '<start_of_turn>model\n'
|
|
|
|
template['assistant_suffix'] = '<end_of_turn>\n'
|
|
|
|
return template
|
2024-11-18 14:37:45 +08:00
|
|
|
if 'OpenELM' in self.path:
|
2024-11-20 11:43:26 +08:00
|
|
|
template['user'] = '<s>{query}'
|
|
|
|
return template
|
2024-11-18 14:37:45 +08:00
|
|
|
if 'SmolLM2' in self.path:
|
2024-11-20 11:43:26 +08:00
|
|
|
template['system'] = '<|im_start|>system\n{query}<|im_end|>\n'
|
|
|
|
template['user'] = '<|im_start|>user\n{query}<|im_end|>\n'
|
|
|
|
template['assistant_prefix'] = '<|im_start|>assistant\n'
|
|
|
|
template['assistant_suffix'] = '<|im_end|>\n'
|
|
|
|
return template
|
2024-10-29 19:32:47 +08:00
|
|
|
# not matched
|
|
|
|
return template
|
|
|
|
|
|
|
|
def build_prompt(self, queries, roles):
|
|
|
|
template = self.build_prompt_template()
|
|
|
|
prompt = ""
|
|
|
|
for item in zip(queries, roles):
|
|
|
|
query, role = item
|
|
|
|
if '{query}' in template[role]:
|
|
|
|
prompt += template[role].format(query=query)
|
|
|
|
else:
|
|
|
|
prompt += role + '\n' + query +'\n'
|
|
|
|
return prompt + template['assistant_prefix']
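# Example (assuming the Qwen branch of build_prompt_template above):
#   build_prompt(['You are a helpful assistant!', 'hello'], roles=['system', 'user'])
# returns
#   '<|im_start|>system\nYou are a helpful assistant!<|im_end|>\n'
#   '<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n'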
|
2024-09-12 12:57:57 +08:00
|
|
|
|
|
|
|
def str_to_ids(self, prompt):
|
2024-09-12 20:19:02 +08:00
|
|
|
if self.visual is not None:
|
|
|
|
return self.visual.str_to_ids(prompt)
|
2024-09-12 12:57:57 +08:00
|
|
|
input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
|
|
|
|
return input_ids
|
|
|
|
|
|
|
|
def id_to_str(self, token_id):
|
2024-11-18 14:37:45 +08:00
|
|
|
def contains_replacement(text): return '\uFFFD' in text
|
|
|
|
def decode_id(token_id):
|
|
|
|
return self.tokenizer.convert_tokens_to_string(
|
|
|
|
self.tokenizer._convert_id_to_token(int(token_id)))
|
|
|
|
def decode_ids(token_ids):
|
|
|
|
return self.tokenizer.convert_tokens_to_string(
|
|
|
|
self.tokenizer.convert_ids_to_tokens(token_ids))
|
|
|
|
word = decode_id(int(token_id))
|
|
|
|
# The SmolLM tokenizer may emit an incomplete UTF-8 sequence (e.g. half of a Chinese character); buffer token ids until they decode cleanly
|
|
|
|
if contains_replacement(word):
|
|
|
|
self.decode_buffer.append(token_id)
|
|
|
|
buffer_txt = decode_ids(self.decode_buffer)
|
|
|
|
if not contains_replacement(buffer_txt):
|
|
|
|
word = buffer_txt
|
|
|
|
self.decode_buffer.clear()
|
|
|
|
else:
|
|
|
|
word = ''
|
2024-09-12 12:57:57 +08:00
|
|
|
return word
|
|
|
|
|
|
|
|
def response(self, query):
|
2024-11-18 14:37:45 +08:00
|
|
|
# self.imitate_quant()
|
|
|
|
self.decode_buffer = []
|
2024-11-20 11:43:26 +08:00
|
|
|
prompt = self.build_prompt(['You are a helpful assistant!', query], roles=['system', 'user'])
|
2024-09-12 12:57:57 +08:00
|
|
|
input_ids = self.str_to_ids(prompt)
|
2024-11-18 14:37:45 +08:00
|
|
|
if self.visual is not None:
|
|
|
|
cross_attention_states = self.visual.cross_attention_states
|
|
|
|
cross_attention_mask = self.visual.cross_attention_mask
|
|
|
|
else:
|
|
|
|
cross_attention_states = None
|
|
|
|
cross_attention_mask = None
|
2024-09-12 12:57:57 +08:00
|
|
|
self.seq_len = input_ids.numel()
|
|
|
|
self.context_len = self.seq_len - 2
|
|
|
|
self.token_len = 0
|
|
|
|
past_key_values = [None for i in range(self.num_hidden_layers)]
|
|
|
|
token_id = input_ids
|
|
|
|
while self.token_len < self.max_length:
|
|
|
|
attention_mask = self.get_attention_mask()
|
|
|
|
position_ids = self.get_position_ids()
|
2024-09-12 20:19:02 +08:00
|
|
|
input_ids = self.embedding(token_id)
|
2024-11-18 14:37:45 +08:00
|
|
|
logits, past_key_values = self.forward(input_ids,
|
|
|
|
attention_mask,
|
|
|
|
position_ids,
|
|
|
|
past_key_values,
|
|
|
|
cross_attention_states,
|
|
|
|
cross_attention_mask)
|
2024-09-12 12:57:57 +08:00
|
|
|
token_id = torch.argmax(logits)
|
|
|
|
if token_id in self.stop_ids:
|
|
|
|
print("", end='\n')
|
|
|
|
break
|
|
|
|
word = self.id_to_str(token_id)
|
|
|
|
print(word, end="", flush=True)
|
|
|
|
|
2024-09-12 20:19:02 +08:00
|
|
|
@spinner_run(f'export visual to ')
|
2024-09-12 12:57:57 +08:00
|
|
|
def export_visual(self):
|
|
|
|
if self.visual is None:
|
|
|
|
return
|
2024-09-12 20:19:02 +08:00
|
|
|
input_images = torch.randn((1, 3, self.visual.image_size, self.visual.image_size))
|
2024-09-12 12:57:57 +08:00
|
|
|
model = self.visual
|
2024-09-12 20:19:02 +08:00
|
|
|
onnx_model = f'{self.onnx_path}/visual.onnx'
|
2024-09-12 12:57:57 +08:00
|
|
|
torch.onnx.export(model, (input_images),
|
|
|
|
onnx_model,
|
|
|
|
input_names=['input_images'],
|
|
|
|
output_names=['image_embeds'],
|
|
|
|
dynamic_axes={"input_images": {
|
|
|
|
0: "size"
|
|
|
|
}},
|
|
|
|
do_constant_folding=True,
|
2024-09-12 20:19:02 +08:00
|
|
|
verbose=False,
|
2024-09-12 12:57:57 +08:00
|
|
|
opset_version=15)
|
|
|
|
return onnx_model
|
|
|
|
|
|
|
|
@spinner_run(f'export embedding to ')
|
|
|
|
def export_embed(self):
|
|
|
|
import ctypes
|
|
|
|
if hasattr(self, 'word_embeddings'):
|
|
|
|
# embedding model's embed
|
|
|
|
tensor_data = self.word_embeddings.weight.data.bfloat16()
|
|
|
|
else:
|
|
|
|
tensor_data = self.embed.embed.weight.data.bfloat16()
|
|
|
|
data_ptr = tensor_data.untyped_storage().data_ptr()
|
|
|
|
buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
|
|
|
|
embedding_file = f'{self.dst_path}/embeddings_bf16.bin'
|
|
|
|
with open(embedding_file, 'wb') as f:
|
|
|
|
f.write(buffer)
|
|
|
|
return embedding_file
|
|
|
|
|
|
|
|
@spinner_run(f'export config to ')
|
|
|
|
def export_config(self, mnn_config = False):
|
|
|
|
config_json = f'{self.dst_path}/llm_config.json'
|
|
|
|
with open(config_json, 'w', encoding='utf-8') as f:
|
|
|
|
json.dump(self.llm_config, f, ensure_ascii=False, indent=4)
|
|
|
|
if not mnn_config:
|
|
|
|
return config_json
|
|
|
|
with open(f'{self.dst_path}/config.json', 'w', encoding='utf-8') as f:
|
|
|
|
config = {
|
|
|
|
"llm_model": f"{self.dst_name}.mnn",
|
|
|
|
"llm_weight": f"{self.dst_name}.mnn.weight",
|
|
|
|
"backend_type": "cpu",
|
|
|
|
"thread_num": 4,
|
|
|
|
"precision": "low",
|
|
|
|
"memory": "low"
|
|
|
|
}
|
|
|
|
json.dump(config, f, ensure_ascii=False, indent=4)
|
|
|
|
return config_json
|
|
|
|
|
|
|
|
def imitate_quant(self):
|
|
|
|
def quant_dequant(linear, quant_bit = self.quant_bit, quant_block = self.quant_block):
|
|
|
|
weight = linear.weight.data
|
|
|
|
oc, ic = weight.shape
|
|
|
|
if quant_block == 0:
|
|
|
|
block_size = ic
|
|
|
|
else:
|
|
|
|
block_size = quant_block
|
|
|
|
block_num = ic // block_size
|
|
|
|
weight = weight.reshape(oc, block_num, block_size)
|
|
|
|
max_val = torch.max(weight, axis=-1, keepdims=True).values
|
|
|
|
min_val = torch.min(weight, axis=-1, keepdims=True).values
|
|
|
|
offset = 1 << (quant_bit - 1)
|
|
|
|
clip_max = offset - 1
|
|
|
|
clip_min = -offset
|
|
|
|
scale = (max_val - min_val) / (clip_max - clip_min)
|
|
|
|
q_weight = torch.round((weight - min_val) / scale) + clip_min
|
|
|
|
q_weight = torch.clip(q_weight, clip_min, clip_max)
|
|
|
|
dq_weight = (q_weight - clip_min) * scale + min_val
|
|
|
|
dq_weight = dq_weight.reshape(oc, ic).float()
|
|
|
|
linear.weight.data = dq_weight
|
|
|
|
return linear
|
|
|
|
with torch.no_grad():
|
|
|
|
for i in range(self.num_hidden_layers):
|
|
|
|
for name, child in self.blocks[i].self_attn.named_children():
|
|
|
|
if isinstance(child, torch.nn.Linear):
|
|
|
|
setattr(self.blocks[i].self_attn, name, quant_dequant(child))
|
|
|
|
for name, child in self.blocks[i].mlp.named_children():
|
|
|
|
if isinstance(child, torch.nn.Linear):
|
|
|
|
setattr(self.blocks[i].mlp, name, quant_dequant(child))
|
|
|
|
self.lm.lm = quant_dequant(self.lm.lm)
|
|
|
|
|
|
|
|
def unload_param(self):
|
|
|
|
self.unloaded_ops = {}
|
|
|
|
def build_faker(real, name):
|
|
|
|
faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name)
|
|
|
|
self.unloaded_ops[name] = real
|
|
|
|
return faker
|
|
|
|
# replace Linear layers with FakeLinear placeholders to reduce memory use and export time
|
|
|
|
with torch.no_grad():
|
|
|
|
for i in range(self.num_hidden_layers):
|
2024-11-18 14:37:45 +08:00
|
|
|
# different kv cache shape in different layers
|
|
|
|
if isinstance(self.num_attention_heads, list):
|
|
|
|
self.blocks[i].self_attn.export_fused_attn = True
|
2024-09-12 12:57:57 +08:00
|
|
|
for name, child in self.blocks[i].self_attn.named_children():
|
|
|
|
if isinstance(child, torch.nn.Linear):
|
|
|
|
setattr(self.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear'))
|
|
|
|
for name, child in self.blocks[i].mlp.named_children():
|
|
|
|
if isinstance(child, torch.nn.Linear):
|
|
|
|
setattr(self.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear'))
|
|
|
|
self.lm.lm = build_faker(self.lm.lm, f'/lm/lm_head/Linear')
|
|
|
|
|
|
|
|
@spinner_run(f'export model weight to ')
|
|
|
|
def onnx_load_param(self, onnx_path):
|
|
|
|
return OnnxRebuilder(onnx_path, self.unloaded_ops).rebuild()
|
|
|
|
|
|
|
|
@spinner_run(f'slim the graph of ')
|
|
|
|
def onnx_slim(self, onnx_model):
|
|
|
|
import onnxslim
|
|
|
|
model = onnxslim.slim(onnx_model)
|
|
|
|
onnx.save(model, onnx_model)
|
|
|
|
return onnx_model
|
|
|
|
|
|
|
|
    @spinner_run(f'export onnx model to ')
    def export_onnx(self):
        # unload linear weight to save export memory
        self.unload_param()
        model = self
        self.seq_len = 3
        self.token_len = 0
        input_ids = torch.arange(3, dtype=torch.long)
        attention_mask = self.get_attention_mask()
        position_ids = self.get_position_ids()
        onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx'
        input_ids = self.embedding(input_ids)
        past_key_values = torch.zeros(self.past_kv_shape)

        # export to onnx
        torch.onnx.export(
            model, (input_ids, attention_mask, position_ids, past_key_values),
            onnx_model,
            input_names=[
                'input_ids', 'attention_mask', 'position_ids', 'past_key_values'
            ],
            output_names=['logits', 'presents'],
            dynamic_axes=self.model_dynamic_axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15)
        return onnx_model

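    # The dummy inputs above use seq_len = 3 only for tracing; the axes declared in
    # self.model_dynamic_axes are what let the exported graph accept arbitrary sequence lengths.
    # Optional sanity check (a sketch, not part of the export flow; assumes `onnxruntime` is
    # installed and an exported file at a hypothetical path):
    #
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession('./model/onnx/llm.onnx')
    #   print([(i.name, i.shape) for i in sess.get_inputs()])
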
    def awq_quant(self):
        self.awq_quantizer = AwqQuantizer(self)
        self.awq_quantizer.quantize()
        self.is_awq_quantized = True

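    # AWQ here refers to activation-aware weight quantization: AwqQuantizer (expected to be
    # defined elsewhere in this file) searches per-channel scales from calibration activations
    # before the weights are quantized, which generally preserves accuracy better than plain
    # min/max quantization of the weights alone.
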
    def export(self, export_type):
        if self.awq:
            self.awq_quant()
        export_mnn = export_type == 'mnn'
        # export tokenizer
        self.export_tokenizer()
        if export_mnn and self.tie_word_embeddings:
            pass # mnn tie_word_embeddings needn't export embedding
        else:
            self.export_embed()
        if self.visual:
            visual_onnx = self.export_visual()
            #if self.onnx_slim:
            #    visual_onnx = self.onnx_slim(visual_onnx)
            if export_mnn:
                MNNConveter(visual_onnx, None, self).export(quant_bit=self.visual.quant_bit)
        # export graph to llm.onnx
        onnx_model = self.export_onnx()
        if self.onnx_slim:
            self.onnx_slim(onnx_model)
        if export_mnn:
            # convert onnx to mnn and quant weight
            MNNConveter(onnx_model, self.unloaded_ops, self).export()
            # delete onnx file
            if os.path.exists(onnx_model):
                import glob
                try:
                    for file in glob.glob(f'{self.onnx_path}/*'):
                        os.remove(file)
                    os.rmdir(self.onnx_path)
                except Exception as e:
                    print(f"remove onnx error: {e}")
        else:
            # export weight to llm.onnx.data
            self.onnx_load_param(onnx_model)
        # export llm_config.json and config.json
        self.export_config(export_mnn)

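    # Rough sketch of the destination directory after an `--export mnn` run (file names assume a
    # dst_name of 'llm'; exact contents depend on the model and options):
    #   tokenizer.txt                from export_tokenizer
    #   an embedding weight file     from export_embed (skipped when tie_word_embeddings applies)
    #   llm.mnn / llm.mnn.weight     from MNNConveter
    #   llm_config.json, config.json from export_config
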
    @spinner_run(f'export tokenizer to ')
    def export_tokenizer(self):
        # load tokenizer file
        tokenizer_model = os.path.join(self.tokenizer_path, 'tokenizer.model')
        ice_text_model = os.path.join(self.tokenizer_path, 'ice_text.model')
        try:
            import sentencepiece as spm
            if os.path.exists(tokenizer_model):
                self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
            elif os.path.exists(ice_text_model):
                self.sp_model = spm.SentencePieceProcessor(ice_text_model)
            else:
                self.sp_model = None
        except:
            self.sp_model = None
        merge_file = os.path.join(self.path, 'merges.txt')
        if os.path.exists(merge_file):
            self.merge_txt = merge_file
        else:
            self.merge_txt = None
        # TOKENIZER MAGIC NUMBER
        MAGIC_NUMBER = 430
        # TOKENIZER TYPE
        SENTENCEPIECE = 0; TIKTOKEN = 1; BERT = 2; HUGGINGFACE = 3
        def write_line(fp, *args):
            for arg in args:
                for token in arg:
                    fp.write(str(token) + ' ')
            fp.write('\n')
        def write_header(fp, type, specials, prefix = []):
            fp.write(f'{MAGIC_NUMBER} {type}\n')
            fp.write(f'{len(specials)} {len(self.stop_ids)} {len(prefix)}\n')
            write_line(fp, specials, self.stop_ids, prefix)

        file_path = os.path.join(self.dst_path, "tokenizer.txt")
        special_list = list(self.tokenizer.added_tokens_decoder.keys())
        if hasattr(self.tokenizer, 'special_tokens'):
            for k, v in self.tokenizer.special_tokens.items():
                special_list.append(v)
        if hasattr(self.tokenizer, 'gmask_token_id'):
            special_list.append(self.tokenizer.gmask_token_id)
        vocab_list = []
        prefix_list = []
        if hasattr(self.tokenizer, 'get_prefix_tokens'):
            prefix_list = self.tokenizer.get_prefix_tokens()
        if len(prefix_list) == 0:
            test_txt = 'A'
            ids = self.tokenizer.encode(test_txt)
            get_txt = self.tokenizer.decode(ids[-1])
            if len(ids) > 1 and get_txt == test_txt:
                prefix_list += ids[:-1]

        if self.sp_model is not None:
            # sentencepiece
            NORMAL = 1; UNKNOWN = 2; CONTROL = 3
            USER_DEFINED = 4; UNUSED = 5; BYTE = 6
            for i in range(self.sp_model.GetPieceSize()):
                token = self.sp_model.IdToPiece(i)
                score = self.sp_model.GetScore(i)
                token_type = NORMAL
                if self.sp_model.IsUnknown(i):
                    token_type = UNKNOWN
                elif self.sp_model.IsControl(i):
                    token_type = CONTROL
                elif self.sp_model.IsUnused(i):
                    token_type = UNUSED
                elif self.sp_model.IsByte(i):
                    token_type = BYTE
                if self.path == 'Chatglm_6b':
                    if '<n>' in token: token = '\n'
                    if '<|tab|>' in token: token = '\t'
                    if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')])
                if '▁' in token: token = token.replace('▁', ' ')
                token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
                vocab_list.append(f'{token_encode} {score} {token_type}\n')
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, SENTENCEPIECE, special_list, prefix_list)
                fp.write(f'{len(vocab_list)}\n')
                for vocab in vocab_list:
                    fp.write(vocab)
        elif hasattr(self.tokenizer, 'mergeable_ranks'):
            # tiktoken
            vocab_list = []
            for k, v in self.tokenizer.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + "\n"
                vocab_list.append(line)
            if hasattr(self.tokenizer, 'special_tokens'):
                for k, v in self.tokenizer.special_tokens.items():
                    line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n"
                    vocab_list.append(line)
            if hasattr(self.tokenizer, 'added_tokens_decoder'):
                for k, v in self.tokenizer.added_tokens_decoder.items():
                    line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n"
                    vocab_list.append(line)
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, TIKTOKEN, special_list, prefix_list)
                fp.write(f'{len(vocab_list)}\n')
                for vocab in vocab_list:
                    fp.write(vocab)
        elif self.merge_txt is not None:
            # huggingface tokenizer
            merge_list = []
            vocab = self.tokenizer.get_vocab()
            special_list = list(self.tokenizer.added_tokens_decoder.keys())
            vocab_list = ['<unk>' for i in range(len(vocab))]
            # load vocab
            for k, v in vocab.items():
                vocab_list[int(v)] = k
            # load merge
            with open(self.merge_txt, 'rt') as merge:
                for line in merge.readlines():
                    merge_list.append(line)
            # write to tokenizer.txt
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, HUGGINGFACE, special_list)
                fp.write(f'{len(vocab_list)} {len(merge_list)}\n')
                for v in vocab_list:
                    fp.write(v + '\n')
                for m in merge_list:
                    fp.write(m)
        else:
            # tiktoken or bert
            if 'bert' in type(self.tokenizer).__name__.lower():
                tokenizer_type = BERT
            else:
                tokenizer_type = TIKTOKEN
            # bert tokenizer
            def unicode_to_byte(u: int):
                if u >= 256 and u <= 288:
                    return u - 256
                if u >= 289 and u <= 322:
                    return u - 162
                if u == 323:
                    return 173
                if u == 65372: # |
                    return 124
                if u == 9601: # _
                    return 95
                return u
            vocab = self.tokenizer.get_vocab()
            vocab_list = ['<unk>' for i in range(len(vocab))]
            for k, v in vocab.items():
                try:
                    vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]).decode('utf-8', errors='ignore')
                except:
                    vocab_list[int(v)] = k
            special_list = list(self.tokenizer.added_tokens_decoder.keys())
            with open(file_path, "w", encoding="utf8") as fp:
                write_header(fp, tokenizer_type, special_list)
                fp.write(f'{len(vocab_list)}\n')
                for v in vocab_list:
                    line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n"
                    fp.write(line)
        return file_path

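# Layout of the tokenizer.txt written above: the first line is '430 <type>' (magic number and
# tokenizer type), the second line holds the counts of special tokens, stop ids and prefix ids,
# the third line lists those ids, the next line gives the vocabulary size (plus the merge count in
# the huggingface case), and the remaining lines are the vocabulary entries (base64-encoded
# tokens, with score/type columns in the sentencepiece case and merge rules appended for
# huggingface tokenizers).
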
class EmbeddingExporter(LlmExporter):
    def __init__(self, args):
        super().__init__(args)
        self.dst_name = 'embedding'

    def word_embed(self, input_ids):
        return self.word_embeddings(input_ids.view(1, -1))

    def bge_forward(self, inputs_embeds, position_ids, attention_mask):
        # bert absolute position
        inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings
        hidden_states = self.embedding_layernorm(embeddings)
        for i in range(self.num_hidden_layers):
            hidden_states = self.blocks[i](hidden_states, attention_mask)[0]
        sentence_embeddings = hidden_states[:, 0]
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings

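    # bge_forward pools by taking the hidden state of the first ([CLS]) token and L2-normalizing
    # it, so downstream cosine similarity can be computed as a plain dot product.
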
    def gte_forward(self, inputs_embeds, position_ids, attention_mask):
        # rope position
        inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size)
        freqs = position_ids.float().reshape(-1, 1) * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1)
        attention_bias = 1 - attention_mask.float()
        hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings)
        for i in range(self.num_hidden_layers):
            hidden_states = self.blocks[i](hidden_states, attention_bias, rope_embeds)[0]
        sentence_embeddings = hidden_states[:, 0]
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings

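    # gte_forward builds rotary position embeddings on the fly: cos/sin tables derived from
    # self.inv_freq are stacked and handed to every block, while the padding mask is converted
    # into a bias tensor (1 - mask) that the attention layers consume.
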
    def forward(self, inputs_embeds, position_ids, attention_mask):
        if self.model_type == 'bert':
            return self.bge_forward(inputs_embeds, position_ids, attention_mask)
        if self.model_type == 'new':
            return self.gte_forward(inputs_embeds, position_ids, attention_mask)
        raise RuntimeError(f'Unsupported embedding model: {self.model_type}!')

    def response(self, query):
        self.eval()
        input_ids = self.tokenizer(query)['input_ids']
        self.seq_len = len(input_ids)
        input_ids = torch.tensor(input_ids)
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        inputs_embeds = self.word_embed(input_ids)
        res = self.forward(inputs_embeds, position_ids, attention_mask)
        # print(res)
        return res

    @spinner_run(f'load pretrained model ')
    def load_model(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.config = AutoConfig.from_pretrained(model_path)
        self.config._attn_implementation = 'eager'
        self.model = AutoModel.from_config(self.config)
        transformer = self.model.encoder
        self.model_type = self.config.model_type
        self.lm_ = self.model.pooler
        self.embed_ = self.model.embeddings
        self.word_embeddings = self.embed_.word_embeddings
        self.token_type_embeddings = self.embed_.token_type_embeddings.weight.data[0]
        self.embedding_layernorm = self.embed_.LayerNorm
        if hasattr(self.embed_, 'position_embeddings'):
            self.position_embeddings = self.embed_.position_embeddings
        self.hidden_size = self.word_embeddings.weight.shape[-1]
        self.blocks = transformer.layer
        if self.model_type == 'new':
            self.inv_freq = self.embed_.rotary_emb.inv_freq
        # some wrapper
        self.stop_ids = []
        self.num_hidden_layers = len(self.blocks)
        self.embed = self.embed_
        self.lm = self.lm_
        # some config for export
        self.model_dynamic_axes = {
            "input_ids" : { 1: "seq_len" },
            "position_ids" : { 1: "seq_len" },
            "attention_mask" : { 3: "seq_len" }
        }
        self.attention_mask_type = 'int'
        self.llm_config = {
            'hidden_size' : self.hidden_size,
            'layer_nums' : self.num_hidden_layers,
            'attention_mask': self.attention_mask_type,
            'key_value_shape': [],
            "prompt_template": self.build_prompt('%s'),
            'is_visual': False
        }
        return model_path

    @spinner_run(f'export onnx model to ')
    def export_onnx(self):
        model = self.eval()
        self.seq_len = 3
        input_ids = torch.arange(3, dtype=torch.long)
        position_ids = self.get_position_ids()
        attention_mask = self.get_attention_mask()
        inputs_embeds = self.word_embed(input_ids)
        onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx'
        torch.onnx.export(
            model, (inputs_embeds, position_ids, attention_mask),
            onnx_model,
            input_names=[
                'input_ids',
                'position_ids',
                'attention_mask'
            ],
            output_names=['sentence_embeddings'],
            dynamic_axes=self.model_dynamic_axes,
            do_constant_folding=True,
            opset_version=15)
        return onnx_model

    def export(self, export_type):
        export_mnn = 'mnn' in export_type
        self.export_tokenizer()
        self.export_config(export_mnn)
        self.export_embed()
        onnx_model = self.export_onnx()
        if self.onnx_slim:
            self.onnx_slim(onnx_model)
        if export_mnn:
            MNNConveter(onnx_model, None, self).export()

    def build_prompt(self, query):
        if self.model_type == 'bert':
            return f'[CLS]{query}[SEP]'
        if self.model_type == 'new':
            return f'<s> {query}</s>'

    def get_position_ids(self) -> torch.Tensor:
        return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)

    def get_attention_mask(self) -> torch.Tensor:
        return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)

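    # For the embedding export above, the mask is all ones with shape [1, 1, 1, seq_len] and the
    # position ids simply run 0..seq_len-1; both grow along the 'seq_len' dynamic axis declared in
    # load_model.

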
def export(path,
           type = None,
           lora_path = None,
           dst_path = './model',
           export = 'onnx',
           onnx_slim = False,
           quant_bit = 4,
           quant_block = 128,
           lm_quant_bit = None):
    args = argparse.Namespace()
    for k, v in {
        'path': path,
        'type': type,
        'lora_path': lora_path,
        'dst_path': dst_path,
        'export': export,
        'onnx_slim': onnx_slim,
        'quant_bit': quant_bit,
        'quant_block': quant_block,
        'lm_quant_bit': lm_quant_bit
    }.items():
        setattr(args, k, v)
    if 'bge' in path:
        llm_exporter = EmbeddingExporter(args)
    else:
        llm_exporter = LlmExporter(args)
    # export
    llm_exporter.export(export)

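# Example use of the helper above (a sketch; the model path is hypothetical):
#
#   export('../Qwen2-1.5B-Instruct', dst_path='./model', export='mnn',
#          quant_bit=4, quant_block=128)
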
def main():
    parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--path', type=str, required=True,
                        help='path(`str` or `os.PathLike`):\nCan be either:'
                             '\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]'
                             '\n\t- A path to a *directory* cloned from a repo, like `../chatglm-6b`.')
    parser.add_argument('--type', type=str, default=None,
                        help='type(`str`, *optional*):'
                             '\n\tThe pretrained llm model type.'
                        )
    parser.add_argument('--tokenizer_path', type=str, default=None, help='tokenizer path, default is `None`, meaning the `--path` value is used.')
    parser.add_argument('--lora_path', type=str, default=None, help='lora path, default is `None`, meaning no lora is applied.')
    parser.add_argument('--dst_path', type=str, default='./model', help='export onnx/mnn model to path, default is `./model`.')
    parser.add_argument('--verbose', action='store_true', help='Whether or not to print verbose output.')
    parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
    parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
    parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.')
    parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
    parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block size, default is 128; 0 means channel-wise quant.')
    parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
    parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path; if invalid, pymnn is used.')
    parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
    parser.add_argument('--awq', action='store_true', help='Whether or not to use awq quant.')
    parser.add_argument('--sym', action='store_true', help='Whether or not to use symmetric quant (without zeropoint), default is False.')

    args = parser.parse_args()

    model_path = args.path
    model_type = args.type

    if 'gte' in model_path or 'bge' in model_path:
        llm_exporter = EmbeddingExporter(args)
    else:
        llm_exporter = LlmExporter(args)

    # some actions
    if args.test is not None:
        llm_exporter.response(args.test)

    if args.export is not None:
        llm_exporter.export(args.export)

if __name__ == '__main__':
    main()
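
# Example invocation (a sketch; the script name and model directory are illustrative):
#
#   python llmexport.py --path ../Qwen2-1.5B-Instruct --export mnn \
#       --quant_bit 4 --quant_block 128 --onnx_slim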