mirror of https://github.com/alibaba/MNN.git
				
				
				
			[Python:Refractor] Add Image train Demo and move dirs
This commit is contained in:
		
							parent
							
								
									53663bcc1b
								
							
						
					
					
						commit
						c4fe1da194
					
				|  | @ -0,0 +1,10 @@ | |||
| # PC 上安装 | ||||
| 仅支持python3 | ||||
| 
 | ||||
| 1、进入pymnn/pip_package 目录 | ||||
| 2、python3 build_deps.py | ||||
| 3、sudo python3 setup.py install | ||||
| 
 | ||||
| # 示例 | ||||
| 1、v2接口/推理:pymnn/examples/MNNEngineDemo | ||||
| 2、训练/ 压缩:pymnn/examples/MNNTrain | ||||
|  | @ -0,0 +1,85 @@ | |||
| import MNN | ||||
| import MNN.var as var | ||||
| c_train = MNN.c_train | ||||
| nn = c_train.cnn | ||||
| F = MNN.expr | ||||
| data = c_train.data | ||||
| import time | ||||
| 
 | ||||
| import sys | ||||
| class Net(MNN.train.Module): | ||||
|     def __init__(self): | ||||
|         super(Net, self).__init__() | ||||
|         modelFile = sys.argv[1] | ||||
|         print(modelFile) | ||||
|         varMap = F.load_dict(modelFile) | ||||
|         inputVar = varMap['input'] | ||||
|         outputVar = varMap['MobilenetV2/Logits/AvgPool'] | ||||
|         self.net = c_train.load_module([inputVar], [outputVar], True) | ||||
|         self.fc = nn.conv(1280, 4, [1, 1]) | ||||
|     def forward(self, x): | ||||
|         x = self.net(x) | ||||
|         x = self.fc(x) | ||||
|         x = F.softmax(F.reshape(F.convert(x, F.NCHW), [0, -1])) | ||||
|         return x | ||||
| 
 | ||||
| scale = [0.00784314, 0.00784314, 0.00784314, 0.00784314] | ||||
| mean = [127.5, 127.5, 127.5, 0] | ||||
| 
 | ||||
| imageConfig = data.image.config(MNN.cv.BGR, 224, 224, scale, mean, [1.0, 1.0], False) | ||||
| picturePath = sys.argv[2] | ||||
| print(picturePath) | ||||
| txtPath = sys.argv[3] | ||||
| imageDataset = data.image.image_label(picturePath, txtPath, imageConfig, False) | ||||
| imageLoader = imageDataset.create_loader(10, True, True, 0) | ||||
| 
 | ||||
| def trainFunc(loader, net, opt): | ||||
|     loader.reset() | ||||
|     net.train(True) | ||||
|     t0 = time.time() | ||||
|     iter_number = loader.iter_number() | ||||
|     for i in range(0, iter_number): | ||||
|         example = loader.next()[0] | ||||
|         data = example[0][0] | ||||
|         label = F.reshape(example[1][0], [-1]) | ||||
|         data = F.convert(data, F.NC4HW4) | ||||
|         predict = net(data) | ||||
|         target = F.one_hot(F.cast(label, F.int), var.int(4), var.float(1.0), var.float(0.0)) | ||||
|         loss = c_train.loss.CrossEntropy(predict, target) | ||||
|         if i % 10 == 0: | ||||
|             print(i, loss.read(), iter_number) | ||||
|         opt.step(loss) | ||||
|     t1 = time.time() | ||||
|     cost = t1 - t0 | ||||
|     print("Epoch cost: %.3f" %cost) | ||||
|     F.save(net.parameters(), "cache/temp.snapshot") | ||||
| 
 | ||||
| def testFunc(loader, net): | ||||
|     loader.reset() | ||||
|     net.train(False) | ||||
|     iter_number = loader.iter_number() | ||||
|     correct = 0 | ||||
|     for i in range(0, iter_number): | ||||
|         example = loader.next()[0] | ||||
|         data = example[0][0] | ||||
|         label = F.reshape(example[1][0], [-1]) | ||||
|         data = F.convert(data, F.NC4HW4) | ||||
|         predict = net(data) | ||||
|         predict = F.argmax(predict, 1) | ||||
|         accu = F.reduce_sum(F.equal(predict, F.cast(label, F.int)), [], False) | ||||
|         correct += accu.read()[0] | ||||
|     print("Accu: ", correct * 100.0 / loader.size(), "%") | ||||
| 
 | ||||
| net = Net() | ||||
| net.loadParameters(F.load("cache/temp.snapshot")) | ||||
| opt = c_train.SGD(0.0001, 0.9); | ||||
| opt.append(net.parameters()) | ||||
| F.setThreadNumber(4) | ||||
| testTxt = sys.argv[4] | ||||
| testDataset = data.image.image_label(picturePath, testTxt, imageConfig, False) | ||||
| testLoader = testDataset.create_loader(10, True, False, 0) | ||||
| 
 | ||||
| for epoch in range(0, 10): | ||||
|     testFunc(testLoader, net) | ||||
|     trainFunc(imageLoader, net, opt) | ||||
| 
 | ||||
|  | @ -0,0 +1,65 @@ | |||
| import MNN | ||||
| import MNN.var as var | ||||
| c_train = MNN.c_train | ||||
| nn = c_train.cnn | ||||
| F = MNN.expr | ||||
| data = c_train.data | ||||
| import time | ||||
| 
 | ||||
| import sys | ||||
| modelFile = sys.argv[1] | ||||
| print(modelFile) | ||||
| 
 | ||||
| varMap = F.load_dict(modelFile) | ||||
| inputVar = varMap['sub_7'] | ||||
| outputVar = varMap['ResizeBilinear_3'] | ||||
| net = c_train.load_module([inputVar], [outputVar], True) | ||||
| c_train.compress.quantize(net, 8,  c_train.compress.PerChannel, c_train.compress.MovingAverage) | ||||
| checkNet = c_train.load_module([inputVar], [outputVar], False) | ||||
| 
 | ||||
| scale = [0.00784314, 0.00784314, 0.00784314, 0.00784314] | ||||
| mean = [127.5, 127.5, 127.5, 0] | ||||
| 
 | ||||
| imageConfig = data.image.config(MNN.cv.BGR, 257, 257, scale, mean, [1.0, 1.0], False) | ||||
| picturePath = sys.argv[2] | ||||
| print(picturePath) | ||||
| imageDataset = data.image.image_no_label(picturePath, imageConfig) | ||||
| imageLoader = imageDataset.create_loader(5, True, True, 0) | ||||
| 
 | ||||
| def trainFunc(loader, net, checkNet, opt): | ||||
|     loader.reset() | ||||
|     net.train(True) | ||||
|     t0 = time.time() | ||||
|     iter_number = loader.iter_number() | ||||
|     for i in range(0, iter_number): | ||||
|         example = loader.next()[0] | ||||
|         data = example[0][0] | ||||
|         data = F.convert(data, F.NC4HW4) | ||||
|         p0 = net(data) | ||||
|         p1 = checkNet(data) | ||||
|         p0 = F.reshape(F.convert(p0, F.NCHW), [0, -1]) | ||||
|         p1 = F.reshape(F.convert(p1, F.NCHW), [0, -1]) | ||||
|         loss = c_train.loss.MSE(p0, p1) | ||||
|         opt.step(loss) | ||||
|         if i % 10 == 0: | ||||
|             print(loss.read()) | ||||
|     t1 = time.time() | ||||
|     cost = t1 - t0 | ||||
|     print("Epoch cost: %.3f" %cost) | ||||
|     F.save(net.parameters(), "cache/temp.snapshot") | ||||
| 
 | ||||
| 
 | ||||
| opt = c_train.SGD(0.000000000, 0.9); | ||||
| opt.append(net.parameters()) | ||||
| 
 | ||||
| for epoch in range(0, 1): | ||||
|     trainFunc(imageLoader, net, checkNet, opt) | ||||
| 
 | ||||
| net.train(False) | ||||
| testInput = F.placeholder([1, 3, 257, 257], F.NC4HW4) | ||||
| testInput.set_name("data") | ||||
| testOutput = net(testInput) | ||||
| testOutput.set_name("prob"); | ||||
| quanName = "temp.quan.mnn" | ||||
| print("Save to " + quanName) | ||||
| F.save([testOutput], quanName) | ||||
|  | @ -1,74 +1,65 @@ | |||
| import MNN.train as train | ||||
| import MNNPy.train | ||||
| import MNN.train.cnn as nn | ||||
| import MNN.expr as F | ||||
| import MNN | ||||
| import MNN.var as var | ||||
| c_train = MNN.c_train | ||||
| nn = c_train.cnn | ||||
| F = MNN.expr | ||||
| data = c_train.data | ||||
| import time | ||||
| import MNN.train.data as data | ||||
| 
 | ||||
| class Net(MNNPy.train.Module): | ||||
| class Net(MNN.train.Module): | ||||
|     def __init__(self): | ||||
|         super(Net, self).__init__() | ||||
|         self.conv1 = nn.Conv(1, 20, [5, 5]) | ||||
|         self.conv2 = nn.Conv(20, 50, [5, 5]) | ||||
|         self.fc1 = nn.Linear(800, 500) | ||||
|         self.fc2 = nn.Linear(500, 10) | ||||
|         self.conv1 = nn.conv(1, 20, [5, 5]) | ||||
|         self.conv2 = nn.conv(20, 50, [5, 5]) | ||||
|         self.fc1 = nn.linear(800, 500) | ||||
|         self.fc2 = nn.linear(500, 10) | ||||
| 
 | ||||
|     def forward(self, x): | ||||
|         x = F.Relu(self.conv1(x)) | ||||
|         x = F.MaxPool(x, [2, 2], [2, 2]) | ||||
|         x = F.Relu(self.conv2(x)) | ||||
|         x = F.MaxPool(x, [2, 2], [2, 2]) | ||||
|         x = F.Convert(x, F.NCHW) | ||||
|         x = F.Reshape(x, [0, -1]) | ||||
|         x = F.Relu(self.fc1(x)) | ||||
|         x = F.relu(self.conv1(x)) | ||||
|         x = F.max_pool(x, [2, 2], [2, 2]) | ||||
|         x = F.relu(self.conv2(x)) | ||||
|         x = F.max_pool(x, [2, 2], [2, 2]) | ||||
|         x = F.convert(x, F.NCHW) | ||||
|         x = F.reshape(x, [0, -1]) | ||||
|         x = F.relu(self.fc1(x)) | ||||
|         x = self.fc2(x) | ||||
|         x = F.Softmax(x, 1) | ||||
|         x = F.softmax(x, 1) | ||||
|         return x | ||||
| 
 | ||||
| 
 | ||||
| def initFloat(value): | ||||
|     res = F.Input([], F.NCHW, F.float) | ||||
|     res.write([value]) | ||||
|     res.fix(F.Const) | ||||
|     return res | ||||
| def initInt(value): | ||||
|     res = F.Input([], F.NCHW, F.int) | ||||
|     res.write([value]) | ||||
|     res.fix(F.Const) | ||||
|     return res | ||||
| 
 | ||||
| def testFunc(loader, net): | ||||
|     loader.reset() | ||||
|     net.train(False) | ||||
|     iterNumber = loader.iterNumber() | ||||
|     iter_number = loader.iter_number() | ||||
|     correct = 0 | ||||
|     for i in range(0, iterNumber): | ||||
|     for i in range(0, iter_number): | ||||
|         example = loader.next()[0] | ||||
|         data = example[0][0] | ||||
|         label = example[1][0] | ||||
| 
 | ||||
|         data = F.Multiply(F.Cast(data, F.float), initFloat(1.0/255.0)) | ||||
|         data = F.cast(data, F.float) * var.float(1.0/255.0) | ||||
|         predict = net(data) | ||||
|         predict = F.ArgMax(predict, 1) | ||||
|         accu = F.ReduceSum(F.Equal(predict, F.Cast(label, F.int)), [], False) | ||||
|         predict = F.argmax(predict, 1) | ||||
|         accu = F.reduce_sum(F.equal(predict, F.cast(label, F.int)), [], False) | ||||
|         correct += accu.read()[0] | ||||
|     print(correct * 1.0 / loader.size()) | ||||
|     print("Accu: ", correct * 100.0 / loader.size(), "%") | ||||
| 
 | ||||
| 
 | ||||
| def trainFunc(loader, net, opt): | ||||
|     loader.reset() | ||||
|     net.train() | ||||
|     t0 = time.time() | ||||
|     iterNumber = loader.iterNumber() | ||||
|     for i in range(0, iterNumber): | ||||
|     iter_number = loader.iter_number() | ||||
|     for i in range(0, iter_number): | ||||
|         example = loader.next()[0] | ||||
|         data = example[0][0] | ||||
|         label = example[1][0] | ||||
| 
 | ||||
|         data = F.Multiply(F.Cast(data, F.float), initFloat(1.0/255.0)) | ||||
|         data = F.cast(data, F.float) * var.float(1.0/255.0) | ||||
|         predict = net(data) | ||||
|         target = F.OneHot(F.Cast(label, F.int), initInt(10), initFloat(1.0), initFloat(0.0)) | ||||
|         loss = train.loss.CrossEntropy(predict, target) | ||||
|         target = F.one_hot(F.cast(label, F.int), var.int(10), var.float(1.0), var.float(0.0)) | ||||
|         loss = c_train.loss.CrossEntropy(predict, target) | ||||
|         opt.step(loss) | ||||
|         if i % 100 == 0: | ||||
|             print(loss.read()) | ||||
|  | @ -79,14 +70,17 @@ def trainFunc(loader, net, opt): | |||
| 
 | ||||
| 
 | ||||
| net = Net() | ||||
| opt = train.SGD(0.01, 0.9) | ||||
| opt = c_train.SGD(0.01, 0.9) | ||||
| net.loadParameters(F.load("cache/temp.snapshot")) | ||||
| opt.append(net.parameters()) | ||||
| 
 | ||||
| mnistDataset = data.mnist.create("/Users/jiangxiaotang/data/mnist", data.mnist.Train) | ||||
| trainLoader = mnistDataset.createLoader(64, True, True, 0) | ||||
| testmnistDataset = data.mnist.create("/Users/jiangxiaotang/data/mnist", data.mnist.Test) | ||||
| testLoader = mnistDataset.createLoader(10, True, False, 0) | ||||
| import sys | ||||
| mnistDataPath = sys.argv[1] | ||||
| 
 | ||||
| mnistDataset = data.mnist.create(mnistDataPath, data.mnist.Train) | ||||
| trainLoader = mnistDataset.create_loader(64, True, True, 0) | ||||
| testmnistDataset = data.mnist.create(mnistDataPath, data.mnist.Test) | ||||
| testLoader = mnistDataset.create_loader(10, True, False, 0) | ||||
| 
 | ||||
| F.setThreadNumber(4) | ||||
| for epoch in range(0, 10): | ||||
|  | @ -94,7 +88,7 @@ for epoch in range(0, 10): | |||
|     # Save Model | ||||
|     fileName = 'cache/%d.mnist.mnn' %epoch | ||||
|     net.train(False) | ||||
|     predict = net.forward(F.Input([1, 1, 28, 28], F.NC4HW4)) | ||||
|     predict = net.forward(F.placeholder([1, 1, 28, 28], F.NC4HW4)) | ||||
|     print("Save to " + fileName) | ||||
|     F.save([predict], fileName) | ||||
|     testFunc(testLoader, net) | ||||
|  |  | |||
|  | @ -0,0 +1,4 @@ | |||
| from _mnncengine import * | ||||
| from . import train | ||||
| from . import tools | ||||
| 
 | ||||
|  | @ -0,0 +1,3 @@ | |||
| from . import mnn, mnnops, mnnquant, mnnvisual, mnnconvert  | ||||
| from . import utils | ||||
| from . import mnn_fb | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
		Reference in New Issue