mirror of https://github.com/alibaba/MNN.git
79 lines · 3.1 KiB · Python

import torch


def repack_low_bits(x, iNeedBits, block_size):
    # Pack sub-byte values (iNeedBits bits each, MSB-first) from every row of
    # x into a dense uint8 stream. Used for bit widths that do not divide 8
    # evenly (e.g. 3, 5, 6, 7 bits).
    v = []
    device = x.device
    block_number = x.shape[0]
    # Number of packed output bytes per row.
    count = block_size * iNeedBits // 8
    for i in range(0, count):
        v.append(torch.zeros([block_number, 1], dtype=torch.uint8, device=device))
    iOffset = 0
    cMask = (1 << iNeedBits) - 1
    index = 0
    for i in range(0, block_size):
        p0 = x[:, i:i+1]
        uShift = 8 - iNeedBits - (iOffset % 8)
        if uShift < 0:
            # The value straddles a byte boundary: split it across two bytes.
            v[index + iOffset // 8] |= ((p0 & cMask) >> (0 - uShift))
            v[index + (iOffset // 8) + 1] |= ((p0 & cMask) << (8 + uShift))
        else:
            v[index + iOffset // 8] |= ((p0 & cMask) << uShift)
        iOffset += iNeedBits
        if iOffset % 8 == 0:
            # A whole number of bytes is full; advance to the next byte group.
            index += iOffset // 8
            iOffset = 0
    return torch.cat(v, axis=1)

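# Illustrative example (values chosen for this sketch): with iNeedBits = 3 and
# block_size = 8, each row packs into 8 * 3 // 8 = 3 bytes. uShift runs
# 5, 2, -1, 4, 1, -2, 3, 0, so the third value straddles bytes 0 and 1; a row
# of all-0b001 codes packs to the bytes 0x24 0x92 0x49.
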
def quant(weight, quant_bit, quant_block, symmetric, awq):
    # Block-wise quantization of a 2-D weight (oc, ic) to quant_bit bits.
    # Returns packed uint8 codes and a flat float tensor of per-block
    # parameters: scale only (symmetric) or interleaved (zero, scale) pairs.
    try:
        if torch.cuda.is_available():
            weight = weight.cuda()
        if torch.backends.mps.is_available():
            weight = weight.to('mps')
    except Exception:
        print('Failed to move weight to GPU, fallback to CPU')
    oc, ic = weight.shape
    # quant_block == 0 means per-channel quantization (one block per row);
    # otherwise shrink the requested block size until it divides ic evenly.
    if quant_block == 0:
        block_size = ic
    else:
        block_size = quant_block
    while ic % block_size != 0:
        block_size /= 2
    block_size = int(block_size)
    block_num = ic // block_size
    weight = weight.reshape(oc, block_num, block_size)
    offset = 1 << (quant_bit - 1)
    clip_max = offset - 1
    if symmetric:
        # Symmetric: one scale per block, zero-point fixed at 0.
        clip_min = -clip_max
        abs_max, _ = torch.max(torch.abs(weight), axis=-1, keepdims=True)
        scale = abs_max / clip_max
        q_weight = torch.round(weight / scale)
        q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
        alpha = scale.flatten()
    else:
        # Asymmetric: a (zero-point, scale) pair per block.
        clip_min = -offset
        max_val, _ = torch.max(weight, axis=-1, keepdims=True)
        min_val, _ = torch.min(weight, axis=-1, keepdims=True)
        scale = (max_val - min_val) / (clip_max - clip_min)

        if awq:
            # AWQ-style rounding: round the weight and the zero-point on the
            # integer grid separately.
            q_weight = torch.round(weight / scale) - torch.round(min_val / scale) + clip_min
            zeros = (torch.round(min_val / scale) - clip_min) * scale
        else:
            q_weight = torch.round((weight - min_val) / scale) + clip_min
            zeros = min_val - scale * clip_min
        q_weight = (torch.clamp(q_weight.flatten(), clip_min, clip_max) + offset).to(torch.uint8)
        alpha = torch.stack([zeros.flatten(), scale.flatten()], axis=-1).flatten()

    if quant_bit < 8 and 8 % quant_bit == 0:
        # Bit widths that divide 8 (4, 2, 1): pack group_size codes per byte.
        group_size = 8 // quant_bit
        q_weight = q_weight.reshape(-1, group_size)
        multipliers = [2 ** (quant_bit * (group_size - 1 - i)) for i in range(group_size)]
        multipliers = torch.tensor(multipliers).to(q_weight.device)
        q_weight = (q_weight * multipliers).sum(axis=1).to(torch.uint8)
    elif quant_bit < 8:
        # Other sub-byte widths (3, 5, 6, 7): pack across byte boundaries.
        q_weight = repack_low_bits(q_weight.reshape((block_num * oc, block_size)), quant_bit, block_size)

    if q_weight.device != torch.device('cpu'):
        return q_weight.cpu(), alpha.float().cpu()
    return q_weight, alpha.float()

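# A minimal usage sketch, assuming a small random weight (the shapes and
# values here are illustrative only): quantize an 8x16 weight to 4 bits,
# asymmetric, block size 16, without AWQ rounding.
if __name__ == '__main__':
    w = torch.randn(8, 16)
    q, alpha = quant(w, quant_bit=4, quant_block=16, symmetric=False, awq=False)
    print(q.shape, q.dtype)  # 128 codes, two 4-bit codes per byte -> 64 bytes
    print(alpha.shape)       # 8 blocks * (zero, scale) -> 16 floats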