mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			147 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			147 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
| import glob
 | |
| import json
 | |
| import torch
 | |
| from safetensors import safe_open
 | |
| 
 | |
class GPTQWeight:
    """Named holder for the tensors that belong to one GPTQ-quantized layer.

    Tensors are attached dynamically through :meth:`add`; the code in this
    file attaches ``qweight`` and ``scales``.
    """

    def __init__(self, name):
        """Create an empty holder identified by ``name``."""
        self.name = name

    def __repr__(self) -> str:
        """Return ``'<name>, <qweight shape>, <scales shape>'`` or ``'None'``.

        ``'None'`` is returned while no ``qweight`` has been attached. A
        missing ``scales`` tensor is reported as ``None`` instead of raising.
        """
        qweight = getattr(self, 'qweight', None)
        if qweight is None:
            return 'None'
        scales = getattr(self, 'scales', None)
        scales_shape = None if scales is None else scales.shape
        return f'{self.name}, {qweight.shape}, {scales_shape}'

    def add(self, name, tensor):
        """Attach ``tensor`` to this holder as the attribute ``name``.

        NOTE: this is a plain ``setattr``, so ``name`` must not collide with
        existing attributes or methods (e.g. ``'name'``, ``'weight'``).
        """
        setattr(self, name, tensor)

    def weight(self, idx):
        """Return block ``idx`` of ``qweight``.

        A 2-D ``qweight`` of shape ``(ic, oc)`` is viewed as
        ``(ic // 16, 16, oc)`` and indexed along the first axis; any other
        rank is indexed directly.  The stored ``qweight`` is left untouched so
        that other consumers still see the original 2-D tensor.
        """
        qweight = self.qweight
        shape = qweight.shape
        if len(shape) == 2:
            ic, oc = shape
            # Reshape locally instead of overwriting self.qweight.
            qweight = qweight.reshape(ic // 16, 16, oc)
        return qweight[idx]

    def scale(self, idx):
        """Return entry ``idx`` (first axis) of the attached ``scales``."""
        return self.scales[idx]
 | |
| 
 | |
class MNNWeight:
    """Describes one Convolution weight referenced by an MNN graph.

    ``external`` is indexed as ``[offset, weight_size, scale_size, ...]``
    (see :meth:`offset`, :meth:`weight_size`, :meth:`scale_size`).
    """

    def __init__(self, name, external, weight_elements):
        """Record the op ``name`` and ``external`` entry and infer the bit width.

        ``weight_elements`` is the number of weight values of the op.  When
        the ratio of elements to stored weight bytes rounds to 2 the weight is
        treated as 4-bit, otherwise as 8-bit.
        """
        self.name = name
        self.external = external
        if round(weight_elements / external[1]) == 2:
            self.quant_bits = 4
            self.a_min = -8
        else:
            self.quant_bits = 8
            self.a_min = -128
        self.parse_name()

    def __repr__(self) -> str:
        return f'{self.layer_id}.{self.op_id}.{self.block_id}, {self.external}'

    def parse_name(self):
        """Split ``self.name`` on ``'/'`` into layer, op and block identifiers.

        * More than four parts: ``layer_id`` is the text after the first
          ``'.'`` of part 1 and ``op_id`` is ``'<part2>.<part3>'``.
        * Otherwise: ``layer_id`` is ``-1`` and ``op_id`` is part 2, falling
          back to the whole name when there are fewer than three parts (this
          previously raised ``IndexError``).

        ``block_id`` is always whatever follows the last ``'__'`` in the final
        part.
        """
        parts = self.name.split('/')
        self.block_id = parts[-1].split('__')[-1]
        if len(parts) > 4:
            self.layer_id = parts[1].split('.')[1]
            self.op_id = f'{parts[2]}.{parts[3]}'
        else:
            self.layer_id = -1
            self.op_id = parts[2] if len(parts) > 2 else self.name

    def key(self):
        """Return the lookup key: ``op_id`` alone or ``'<layer_id>.<op_id>'``."""
        if self.layer_id == -1:
            return self.op_id
        return f'{self.layer_id}.{self.op_id}'

    def offset(self):
        """Return ``external[0]``."""
        return self.external[0]

    def weight_size(self):
        """Return ``external[1]``."""
        return self.external[1]

    def scale_size(self):
        """Return ``external[2]``."""
        return self.external[2]
 | |
| 
 | |
class GPTQ:
    """Collects GPTQ ``qweight``/``scales`` tensors and patches them into an
    MNN external weight file.

    ``weight_dict`` maps a layer key (see :meth:`prefix`) to the
    ``GPTQWeight`` holding that layer's tensors.
    """

    def __init__(self, gptq_path):
        """Load every ``*.safetensors`` file found directly in ``gptq_path``."""
        self.weight_dict = {}
        self.load(gptq_path)

    def load(self, path):
        """(Re)build ``weight_dict`` from all ``*.safetensors`` files in ``path``.

        Files are visited in sorted order so the result does not depend on the
        directory listing order.  The dictionary is reset once here, not per
        file, so tensors from every shard are kept.
        """
        self.weight_dict = {}
        for shard in sorted(glob.glob(f'{path}/*.safetensors')):
            self.load_safetensor(shard)

    def prefix(self, name):
        """Split a tensor name on ``'.'`` into ``(layer_key, suffix)``.

        * Two parts whose first contains ``'lm_head'``: returned as-is.
        * Fewer than five parts: ``(None, None)`` (tensor is ignored).
        * Otherwise: parts 2-4 joined with ``'.'`` and the last part.
        """
        splits = name.split('.')
        if len(splits) == 2 and 'lm_head' in splits[0]:
            return splits[0], splits[1]
        if len(splits) < 5:
            return None, None
        return '.'.join(splits[2:5]), splits[-1]

    def get(self, key : str):
        """Return the holder stored under ``key``, or ``None`` if absent."""
        return self.weight_dict.get(key)

    def load_safetensor(self, tensor):
        """Merge the ``qweight`` and ``scales`` tensors of one file into
        ``weight_dict``.

        ``tensor`` is the path of the safetensors file.  Existing entries are
        kept, so calling this for several shards accumulates all of them.
        """
        if not hasattr(self, 'weight_dict'):
            # Instance was created without running __init__.
            self.weight_dict = {}
        with safe_open(tensor, framework="pt") as f:
            for k in f.keys():
                p, s = self.prefix(k)
                if p is None or s not in ('qweight', 'scales'):
                    continue
                holder = self.weight_dict.get(p)
                if holder is None:
                    holder = self.weight_dict[p] = GPTQWeight(p)
                holder.add(s, f.get_tensor(k))

    @staticmethod
    def weight_reorder(qweight, bits=4, group_size=128):
        """Unpack a 2-D packed integer tensor and re-pack it as ``uint8``.

        Each element of ``qweight`` (shape ``(rows, oc)``) is split into
        ``32 // bits`` fields of ``bits`` bits, lowest bits first, giving a
        ``(rows * 32 // bits, oc)`` matrix that is then transposed to
        ``(oc, rows * 32 // bits)``.

        * ``bits == 8``: that transposed matrix is returned as ``uint8``.
        * otherwise: consecutive value pairs (row-major) are merged into one
          byte as ``first * 16 + second`` and a 1-D ``uint8`` tensor is
          returned.

        ``group_size`` is accepted for interface compatibility but unused.
        """
        oc = qweight.shape[-1]
        # One shift amount per packed field, broadcast over rows and columns.
        shifts = torch.arange(
            0, 32, bits, dtype=torch.int32, device=qweight.device
        ).view(1, -1, 1)
        # Mask after shifting so sign bits from the arithmetic shift are dropped.
        fields = (qweight.unsqueeze(1) >> shifts) & ((1 << bits) - 1)
        weight = fields.reshape(-1, oc).transpose(1, 0).to(torch.uint8)
        if bits == 8:
            return weight
        pairs = weight.reshape(-1, 2)
        return pairs[:, 0] * 16 + pairs[:, 1]

    def apply(self, graph_path, weight_path):
        """Overwrite matching Convolution weights in an MNN weight file.

        ``graph_path`` is the MNN graph as JSON; ``weight_path`` is the
        external weight file, modified in place.  For every Convolution op
        with an ``external`` entry whose key is present in ``weight_dict``,
        the reordered weight bytes are written at the end of the op's weight
        slot and the float32 scales directly after them.  Ops without a
        matching GPTQ entry are skipped.

        Raises:
            ValueError: if the reordered weight does not leave room for a
                header in its slot, or the scale bytes do not exactly fill
                the scale slot.  Both checks run before anything is written
                for that op.
        """
        # parse mnn graph
        with open(graph_path, 'rt') as graph_file:
            mnn_graph = json.load(graph_file)
        mnn_weights = []
        for op in mnn_graph['oplists']:
            if op['type'] != 'Convolution':
                continue
            main = op['main']
            external = main.get('external')
            if not external:
                # Nothing stored in the external file for this op.
                continue
            common = main['common']
            weight_elements = common['outputCount'] * common['inputCount']
            mnn_weights.append(MNNWeight(op['name'], external, weight_elements))

        # patch mnn weight file in place
        with open(weight_path, 'r+b') as external_weight:
            for mnn_weight in mnn_weights:
                gptq_weight = self.get(mnn_weight.key())
                if gptq_weight is None:
                    continue

                weight = GPTQ.weight_reorder(gptq_weight.qweight, mnn_weight.quant_bits)
                weight_bytes = weight.numpy().tobytes()
                # Bytes at the start of the slot that are not weight data are
                # left untouched.
                header_len = mnn_weight.weight_size() - len(weight_bytes)
                if header_len <= 0:
                    raise ValueError(
                        f'{mnn_weight.key()}: reordered weight is {len(weight_bytes)} bytes '
                        f'but the slot holds {mnn_weight.weight_size()}'
                    )

                scale = gptq_weight.scales.float().transpose(1, 0)
                scale_size = mnn_weight.scale_size()
                if scale.numel() * scale.element_size() < scale_size:
                    # The slot is larger than the plain scales: store a zero
                    # next to every scale (presumably an asymmetric
                    # offset/scale layout -- TODO confirm against MNN).
                    scale = torch.stack([torch.zeros_like(scale), scale], dim=-1)
                scale_bytes = scale.numpy().tobytes()
                if len(scale_bytes) != scale_size:
                    raise ValueError(
                        f'{mnn_weight.key()}: scale data is {len(scale_bytes)} bytes '
                        f'but the slot holds {scale_size}'
                    )

                external_weight.seek(mnn_weight.offset() + header_len)
                external_weight.write(weight_bytes)
                external_weight.write(scale_bytes)