# mirror of https://github.com/alibaba/MNN.git
import torch


def repack_low_bits(x, iNeedBits, block_size):
    # Pack iNeedBits-wide unsigned codes into a dense uint8 stream, most
    # significant bits first. x has shape [block_number, block_size]; the
    # result has shape [block_number, block_size * iNeedBits // 8], so
    # block_size * iNeedBits is assumed to be a multiple of 8.
    v = []
    device = x.device
    block_number = x.shape[0]
    count = block_size * iNeedBits // 8
    for i in range(0, count):
        v.append(torch.zeros([block_number, 1], dtype=torch.uint8, device=device))
    iOffset = 0
    cMask = (1 << iNeedBits) - 1
    index = 0
    for i in range(0, block_size):
        p0 = x[:, i:i+1]
        # Left shift that aligns this code below the bits already used in the
        # current byte; negative means the code straddles a byte boundary.
        uShift = 8 - iNeedBits - (iOffset % 8)
        if uShift < 0:
            # High bits finish the current byte, low bits start the next one
            # (uint8 arithmetic wraps, so the second shift self-truncates).
            v[index + iOffset // 8] |= ((p0 & cMask) >> (0 - uShift))
            v[index + iOffset // 8 + 1] |= ((p0 & cMask) << (8 + uShift))
        else:
            v[index + iOffset // 8] |= ((p0 & cMask) << uShift)
        iOffset += iNeedBits
        if iOffset % 8 == 0:
            index += iOffset // 8
            iOffset = 0
    return torch.cat(v, dim=1)

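# Worked example (illustration only, not part of the original file): with
# iNeedBits=4 and block_size=2, the code pair (3, 5) packs into the single
# byte 0x35 = 53:
#
#   x = torch.tensor([[3, 5]], dtype=torch.uint8)
#   repack_low_bits(x, 4, 2)  # -> tensor([[53]], dtype=torch.uint8)
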
def quant(weight, quant_bit, quant_block, symmetric, awq):
    # Best-effort move to an accelerator; the math below is device-agnostic.
    try:
        if torch.cuda.is_available():
            weight = weight.cuda()
        if torch.backends.mps.is_available():
            weight = weight.to('mps')
    except Exception:
        print('Failed to move weight to GPU, falling back to CPU')
    oc, ic = weight.shape
    if quant_block == 0:
        block_size = ic
    else:
        block_size = quant_block
    # Halve the block size until it divides the input channel count; integer
    # division keeps block_size an int and guarantees the loop terminates.
    while ic % block_size != 0:
        block_size //= 2
    block_num = ic // block_size
    weight = weight.reshape(oc, block_num, block_size)
    offset = 1 << (quant_bit - 1)
    clip_max = offset - 1
    if symmetric:
        # Symmetric: one scale per block; dequant is w ~= scale * (q - offset).
        clip_min = -clip_max
        abs_max, _ = torch.max(torch.abs(weight), dim=-1, keepdim=True)
        scale = abs_max / clip_max
        q_weight = torch.round(weight / scale)
        q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
        alpha = scale.flatten()
    else:
        # Asymmetric: a (zero, scale) pair per block; dequant is
        # w ~= scale * (q - offset) + zero.
        clip_min = -offset
        max_val, _ = torch.max(weight, dim=-1, keepdim=True)
        min_val, _ = torch.min(weight, dim=-1, keepdim=True)
        scale = (max_val - min_val) / (clip_max - clip_min)

        if awq:
            # AWQ-style: round the zero point onto the quantization grid too.
            q_weight = torch.round(weight / scale) - torch.round(min_val / scale) + clip_min
            zeros = (torch.round(min_val / scale) - clip_min) * scale
        else:
            q_weight = torch.round((weight - min_val) / scale) + clip_min
            zeros = min_val - scale * clip_min
        q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
        alpha = torch.stack([zeros.flatten(), scale.flatten()], dim=-1).flatten()

    if quant_bit < 8 and 8 % quant_bit == 0:
        # 1/2/4-bit: a whole number of codes fits per byte, high code first.
        group_size = 8 // quant_bit
        q_weight = q_weight.reshape(-1, group_size)
        multipliers = [2 ** (quant_bit * (group_size - 1 - i)) for i in range(group_size)]
        multipliers = torch.tensor(multipliers).to(q_weight.device)
        q_weight = (q_weight * multipliers).sum(dim=1).to(torch.uint8)
    elif quant_bit < 8:
        # 3/5/6/7-bit: codes straddle byte boundaries, pack them bit by bit.
        q_weight = repack_low_bits(q_weight.reshape((block_num * oc, block_size)), quant_bit, block_size)

    if q_weight.device != torch.device('cpu'):
        return q_weight.cpu(), alpha.float().cpu()
    return q_weight, alpha.float()
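
# Minimal smoke test (an illustrative sketch, not part of the original file):
# quantize a random fp32 weight to symmetric 4-bit with block size 128 and
# check the packed sizes.
if __name__ == '__main__':
    w = torch.randn(256, 1024)
    q, alpha = quant(w, quant_bit=4, quant_block=128, symmetric=True, awq=False)
    # Two 4-bit codes per byte -> half as many bytes as weights, and one
    # fp32 scale per (output channel, block) pair.
    assert q.numel() == w.numel() // 2
    assert alpha.numel() == 256 * (1024 // 128)
    print('packed bytes:', tuple(q.shape), 'scales:', tuple(alpha.shape))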