mirror of https://github.com/alibaba/MNN.git
70 lines
3.0 KiB
Python
70 lines
3.0 KiB
Python
|
|
import os
|
||
|
|
import onnx
|
||
|
|
|
||
|
|
class OnnxRebuilder:
    """Replace ``FakeLinear`` placeholder nodes in an ONNX model with real ops.

    Each ``FakeLinear`` node is matched by its ``name`` attribute against
    ``weight_ops``. The matching module's weight (and optional bias) are
    written to an external-data file next to the model. The node is then
    rewritten as ``MatMul``, or ``MatMul`` + ``Add`` when it has a bias.
    The model file at ``onnx_path`` is overwritten by ``rebuild``.
    """

    def __init__(self, onnx_path, weight_ops):
        """Load the model that will be rebuilt.

        Args:
            onnx_path: path of the ONNX model. It is overwritten by
                ``rebuild``, and weights are written to ``f'{onnx_path}.data'``.
            weight_ops: mapping from a ``FakeLinear`` name to a linear module
                exposing ``in_features``, ``out_features``, ``weight`` and
                ``bias``. Presumably these are ``torch.nn.Linear`` modules;
                verify against the caller.
        """
        self.weight_ops = weight_ops
        self.onnx_model = onnx.load(onnx_path)
        self.dst_path = onnx_path
        self.onnx_weight_path = f'{onnx_path}.data'
        # Byte offset at which the next tensor will be written in the weight file.
        self.onnx_weight_offset = 0
        # Handle on the external weight file; only open while rebuild() runs.
        self.onnx_weight = None

    def make_external(self, name, data, shape):
        """Append ``data`` to the weight file and register it as an initializer.

        Must be called while ``rebuild`` has the weight file open. The raw
        bytes of ``data`` are written verbatim and the tensor is declared as
        FLOAT. ``build_weight`` always passes float32 arrays.

        Args:
            name: initializer name.
            data: array-like object with a ``tobytes()`` method.
            shape: dimensions recorded on the initializer.
        """
        # Write the payload and remember where it landed in the file.
        length = self.onnx_weight.write(data.tobytes())
        # The .data file sits next to the model, so its basename is the location.
        location = os.path.basename(self.onnx_weight_path)
        offset = self.onnx_weight_offset
        self.onnx_weight_offset += length

        tensor = onnx.TensorProto()
        tensor.name = name
        tensor.data_type = onnx.TensorProto.FLOAT
        tensor.dims.extend(shape)
        # External-data descriptor: no raw_data in the proto, only a file reference.
        tensor.data_location = onnx.TensorProto.EXTERNAL
        for key, value in {"location": location, "offset": offset, "length": length}.items():
            entry = tensor.external_data.add()
            entry.key = key
            entry.value = str(value)
        self.onnx_model.graph.initializer.append(tensor)

    def build_weight(self, name, has_bias, ic, oc):
        """Export the weight (and optional bias) of ``weight_ops[name]``.

        The weight is written transposed, with shape ``[ic, oc]``, so it can
        serve as the right-hand operand of ``MatMul``. The bias is written
        with shape ``[oc]``.

        Args:
            name: key into ``weight_ops``; also used as the prefix of the
                initializer names.
            has_bias: truthy when the FakeLinear node declares a bias.
            ic: expected input features.
            oc: expected output features.

        Returns:
            ``(weight_name, bias_name)``. ``bias_name`` is returned even when
            ``has_bias`` is false, but no initializer is created for it then.

        Raises:
            AssertionError: if ``name`` is unknown, or if the module does not
                match the declared ``ic`` / ``oc`` / bias.
        """
        assert name in self.weight_ops, f'no weight op registered for {name!r}'
        linear = self.weight_ops[name]
        assert (linear.in_features == ic and
                linear.out_features == oc and
                (linear.bias is not None) == bool(has_bias)), \
            f'{name}: module does not match FakeLinear(ic={ic}, oc={oc}, has_bias={bool(has_bias)})'
        weight_name, bias_name = f'{name}_weight', f'{name}_bias'
        # The weight is stored [oc, ic]; MatMul(x, W) needs [ic, oc].
        # .cpu() also allows exporting weights that live on an accelerator.
        weight = linear.weight.data.transpose(1, 0).flatten().float().cpu().numpy()
        self.make_external(weight_name, weight, [ic, oc])
        if has_bias:
            bias = linear.bias.data.flatten().float().cpu().numpy()
            self.make_external(bias_name, bias, [oc])
        return weight_name, bias_name

    def _replace_fake_linear(self, node):
        """Export one FakeLinear node's weights and return its replacement nodes."""
        from onnx import helper
        attributes = {a.name: a for a in node.attribute}
        name = attributes['name'].s.decode('utf-8')
        has_bias = attributes['has_bias'].i
        ic = attributes['in_features'].i
        oc = attributes['out_features'].i
        weight, bias = self.build_weight(name, has_bias, ic, oc)
        if not has_bias:
            # fakelinear -> matmul
            return [helper.make_node('MatMul', [node.input[0], weight], node.output, name)]
        # fakelinear -> matmul + add
        middle_tensor = f'{name}_matmul'
        return [
            helper.make_node('MatMul', [node.input[0], weight], [middle_tensor], name),
            helper.make_node('Add', [middle_tensor, bias], node.output, f'{name}/Add'),
        ]

    def rebuild(self):
        """Convert every FakeLinear node, write the external weights and save.

        Returns:
            Path of the external weight file written next to the model.
        """
        new_nodes = []
        # The weight file is truncated on each call, so offsets restart at 0.
        self.onnx_weight_offset = 0
        # `with` closes the weight file even if a node fails to convert
        # (e.g. a shape-mismatch assertion raised by build_weight).
        with open(self.onnx_weight_path, 'wb') as weight_file:
            self.onnx_weight = weight_file
            for node in self.onnx_model.graph.node:
                if node.op_type == 'FakeLinear':
                    new_nodes.extend(self._replace_fake_linear(node))
                else:
                    new_nodes.append(node)
        del self.onnx_model.graph.node[:]
        self.onnx_model.graph.node.extend(new_nodes)
        onnx.save(self.onnx_model, self.dst_path)
        return self.onnx_weight_path
|