mirror of https://github.com/alibaba/MNN.git
[Python:Refractor] Add Image train Demo and move dirs
This commit is contained in:
parent
53663bcc1b
commit
c4fe1da194
|
|
@ -0,0 +1,10 @@
|
||||||
|
# PC 上安装
|
||||||
|
仅支持python3
|
||||||
|
|
||||||
|
1、进入pymnn/pip_package 目录
|
||||||
|
2、python3 build_deps.py
|
||||||
|
3、sudo python3 setup.py install
|
||||||
|
|
||||||
|
# 示例
|
||||||
|
1、v2接口/推理:pymnn/examples/MNNEngineDemo
|
||||||
|
2、训练/ 压缩:pymnn/examples/MNNTrain
|
||||||
|
|
@ -0,0 +1,85 @@
|
||||||
|
import MNN
|
||||||
|
import MNN.var as var
|
||||||
|
c_train = MNN.c_train
|
||||||
|
nn = c_train.cnn
|
||||||
|
F = MNN.expr
|
||||||
|
data = c_train.data
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sys
|
||||||
|
class Net(MNN.train.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
modelFile = sys.argv[1]
|
||||||
|
print(modelFile)
|
||||||
|
varMap = F.load_dict(modelFile)
|
||||||
|
inputVar = varMap['input']
|
||||||
|
outputVar = varMap['MobilenetV2/Logits/AvgPool']
|
||||||
|
self.net = c_train.load_module([inputVar], [outputVar], True)
|
||||||
|
self.fc = nn.conv(1280, 4, [1, 1])
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.net(x)
|
||||||
|
x = self.fc(x)
|
||||||
|
x = F.softmax(F.reshape(F.convert(x, F.NCHW), [0, -1]))
|
||||||
|
return x
|
||||||
|
|
||||||
|
scale = [0.00784314, 0.00784314, 0.00784314, 0.00784314]
|
||||||
|
mean = [127.5, 127.5, 127.5, 0]
|
||||||
|
|
||||||
|
imageConfig = data.image.config(MNN.cv.BGR, 224, 224, scale, mean, [1.0, 1.0], False)
|
||||||
|
picturePath = sys.argv[2]
|
||||||
|
print(picturePath)
|
||||||
|
txtPath = sys.argv[3]
|
||||||
|
imageDataset = data.image.image_label(picturePath, txtPath, imageConfig, False)
|
||||||
|
imageLoader = imageDataset.create_loader(10, True, True, 0)
|
||||||
|
|
||||||
|
def trainFunc(loader, net, opt):
|
||||||
|
loader.reset()
|
||||||
|
net.train(True)
|
||||||
|
t0 = time.time()
|
||||||
|
iter_number = loader.iter_number()
|
||||||
|
for i in range(0, iter_number):
|
||||||
|
example = loader.next()[0]
|
||||||
|
data = example[0][0]
|
||||||
|
label = F.reshape(example[1][0], [-1])
|
||||||
|
data = F.convert(data, F.NC4HW4)
|
||||||
|
predict = net(data)
|
||||||
|
target = F.one_hot(F.cast(label, F.int), var.int(4), var.float(1.0), var.float(0.0))
|
||||||
|
loss = c_train.loss.CrossEntropy(predict, target)
|
||||||
|
if i % 10 == 0:
|
||||||
|
print(i, loss.read(), iter_number)
|
||||||
|
opt.step(loss)
|
||||||
|
t1 = time.time()
|
||||||
|
cost = t1 - t0
|
||||||
|
print("Epoch cost: %.3f" %cost)
|
||||||
|
F.save(net.parameters(), "cache/temp.snapshot")
|
||||||
|
|
||||||
|
def testFunc(loader, net):
|
||||||
|
loader.reset()
|
||||||
|
net.train(False)
|
||||||
|
iter_number = loader.iter_number()
|
||||||
|
correct = 0
|
||||||
|
for i in range(0, iter_number):
|
||||||
|
example = loader.next()[0]
|
||||||
|
data = example[0][0]
|
||||||
|
label = F.reshape(example[1][0], [-1])
|
||||||
|
data = F.convert(data, F.NC4HW4)
|
||||||
|
predict = net(data)
|
||||||
|
predict = F.argmax(predict, 1)
|
||||||
|
accu = F.reduce_sum(F.equal(predict, F.cast(label, F.int)), [], False)
|
||||||
|
correct += accu.read()[0]
|
||||||
|
print("Accu: ", correct * 100.0 / loader.size(), "%")
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
net.loadParameters(F.load("cache/temp.snapshot"))
|
||||||
|
opt = c_train.SGD(0.0001, 0.9);
|
||||||
|
opt.append(net.parameters())
|
||||||
|
F.setThreadNumber(4)
|
||||||
|
testTxt = sys.argv[4]
|
||||||
|
testDataset = data.image.image_label(picturePath, testTxt, imageConfig, False)
|
||||||
|
testLoader = testDataset.create_loader(10, True, False, 0)
|
||||||
|
|
||||||
|
for epoch in range(0, 10):
|
||||||
|
testFunc(testLoader, net)
|
||||||
|
trainFunc(imageLoader, net, opt)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
import MNN
|
||||||
|
import MNN.var as var
|
||||||
|
c_train = MNN.c_train
|
||||||
|
nn = c_train.cnn
|
||||||
|
F = MNN.expr
|
||||||
|
data = c_train.data
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sys
|
||||||
|
modelFile = sys.argv[1]
|
||||||
|
print(modelFile)
|
||||||
|
|
||||||
|
varMap = F.load_dict(modelFile)
|
||||||
|
inputVar = varMap['sub_7']
|
||||||
|
outputVar = varMap['ResizeBilinear_3']
|
||||||
|
net = c_train.load_module([inputVar], [outputVar], True)
|
||||||
|
c_train.compress.quantize(net, 8, c_train.compress.PerChannel, c_train.compress.MovingAverage)
|
||||||
|
checkNet = c_train.load_module([inputVar], [outputVar], False)
|
||||||
|
|
||||||
|
scale = [0.00784314, 0.00784314, 0.00784314, 0.00784314]
|
||||||
|
mean = [127.5, 127.5, 127.5, 0]
|
||||||
|
|
||||||
|
imageConfig = data.image.config(MNN.cv.BGR, 257, 257, scale, mean, [1.0, 1.0], False)
|
||||||
|
picturePath = sys.argv[2]
|
||||||
|
print(picturePath)
|
||||||
|
imageDataset = data.image.image_no_label(picturePath, imageConfig)
|
||||||
|
imageLoader = imageDataset.create_loader(5, True, True, 0)
|
||||||
|
|
||||||
|
def trainFunc(loader, net, checkNet, opt):
|
||||||
|
loader.reset()
|
||||||
|
net.train(True)
|
||||||
|
t0 = time.time()
|
||||||
|
iter_number = loader.iter_number()
|
||||||
|
for i in range(0, iter_number):
|
||||||
|
example = loader.next()[0]
|
||||||
|
data = example[0][0]
|
||||||
|
data = F.convert(data, F.NC4HW4)
|
||||||
|
p0 = net(data)
|
||||||
|
p1 = checkNet(data)
|
||||||
|
p0 = F.reshape(F.convert(p0, F.NCHW), [0, -1])
|
||||||
|
p1 = F.reshape(F.convert(p1, F.NCHW), [0, -1])
|
||||||
|
loss = c_train.loss.MSE(p0, p1)
|
||||||
|
opt.step(loss)
|
||||||
|
if i % 10 == 0:
|
||||||
|
print(loss.read())
|
||||||
|
t1 = time.time()
|
||||||
|
cost = t1 - t0
|
||||||
|
print("Epoch cost: %.3f" %cost)
|
||||||
|
F.save(net.parameters(), "cache/temp.snapshot")
|
||||||
|
|
||||||
|
|
||||||
|
opt = c_train.SGD(0.000000000, 0.9);
|
||||||
|
opt.append(net.parameters())
|
||||||
|
|
||||||
|
for epoch in range(0, 1):
|
||||||
|
trainFunc(imageLoader, net, checkNet, opt)
|
||||||
|
|
||||||
|
net.train(False)
|
||||||
|
testInput = F.placeholder([1, 3, 257, 257], F.NC4HW4)
|
||||||
|
testInput.set_name("data")
|
||||||
|
testOutput = net(testInput)
|
||||||
|
testOutput.set_name("prob");
|
||||||
|
quanName = "temp.quan.mnn"
|
||||||
|
print("Save to " + quanName)
|
||||||
|
F.save([testOutput], quanName)
|
||||||
|
|
@ -1,74 +1,65 @@
|
||||||
import MNN.train as train
|
import MNN
|
||||||
import MNNPy.train
|
import MNN.var as var
|
||||||
import MNN.train.cnn as nn
|
c_train = MNN.c_train
|
||||||
import MNN.expr as F
|
nn = c_train.cnn
|
||||||
|
F = MNN.expr
|
||||||
|
data = c_train.data
|
||||||
import time
|
import time
|
||||||
import MNN.train.data as data
|
|
||||||
|
|
||||||
class Net(MNNPy.train.Module):
|
class Net(MNN.train.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self.conv1 = nn.Conv(1, 20, [5, 5])
|
self.conv1 = nn.conv(1, 20, [5, 5])
|
||||||
self.conv2 = nn.Conv(20, 50, [5, 5])
|
self.conv2 = nn.conv(20, 50, [5, 5])
|
||||||
self.fc1 = nn.Linear(800, 500)
|
self.fc1 = nn.linear(800, 500)
|
||||||
self.fc2 = nn.Linear(500, 10)
|
self.fc2 = nn.linear(500, 10)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = F.Relu(self.conv1(x))
|
x = F.relu(self.conv1(x))
|
||||||
x = F.MaxPool(x, [2, 2], [2, 2])
|
x = F.max_pool(x, [2, 2], [2, 2])
|
||||||
x = F.Relu(self.conv2(x))
|
x = F.relu(self.conv2(x))
|
||||||
x = F.MaxPool(x, [2, 2], [2, 2])
|
x = F.max_pool(x, [2, 2], [2, 2])
|
||||||
x = F.Convert(x, F.NCHW)
|
x = F.convert(x, F.NCHW)
|
||||||
x = F.Reshape(x, [0, -1])
|
x = F.reshape(x, [0, -1])
|
||||||
x = F.Relu(self.fc1(x))
|
x = F.relu(self.fc1(x))
|
||||||
x = self.fc2(x)
|
x = self.fc2(x)
|
||||||
x = F.Softmax(x, 1)
|
x = F.softmax(x, 1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def initFloat(value):
|
|
||||||
res = F.Input([], F.NCHW, F.float)
|
|
||||||
res.write([value])
|
|
||||||
res.fix(F.Const)
|
|
||||||
return res
|
|
||||||
def initInt(value):
|
|
||||||
res = F.Input([], F.NCHW, F.int)
|
|
||||||
res.write([value])
|
|
||||||
res.fix(F.Const)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def testFunc(loader, net):
|
def testFunc(loader, net):
|
||||||
loader.reset()
|
loader.reset()
|
||||||
net.train(False)
|
net.train(False)
|
||||||
iterNumber = loader.iterNumber()
|
iter_number = loader.iter_number()
|
||||||
correct = 0
|
correct = 0
|
||||||
for i in range(0, iterNumber):
|
for i in range(0, iter_number):
|
||||||
example = loader.next()[0]
|
example = loader.next()[0]
|
||||||
data = example[0][0]
|
data = example[0][0]
|
||||||
label = example[1][0]
|
label = example[1][0]
|
||||||
|
|
||||||
data = F.Multiply(F.Cast(data, F.float), initFloat(1.0/255.0))
|
data = F.cast(data, F.float) * var.float(1.0/255.0)
|
||||||
predict = net(data)
|
predict = net(data)
|
||||||
predict = F.ArgMax(predict, 1)
|
predict = F.argmax(predict, 1)
|
||||||
accu = F.ReduceSum(F.Equal(predict, F.Cast(label, F.int)), [], False)
|
accu = F.reduce_sum(F.equal(predict, F.cast(label, F.int)), [], False)
|
||||||
correct += accu.read()[0]
|
correct += accu.read()[0]
|
||||||
print(correct * 1.0 / loader.size())
|
print("Accu: ", correct * 100.0 / loader.size(), "%")
|
||||||
|
|
||||||
|
|
||||||
def trainFunc(loader, net, opt):
|
def trainFunc(loader, net, opt):
|
||||||
loader.reset()
|
loader.reset()
|
||||||
net.train()
|
net.train()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
iterNumber = loader.iterNumber()
|
iter_number = loader.iter_number()
|
||||||
for i in range(0, iterNumber):
|
for i in range(0, iter_number):
|
||||||
example = loader.next()[0]
|
example = loader.next()[0]
|
||||||
data = example[0][0]
|
data = example[0][0]
|
||||||
label = example[1][0]
|
label = example[1][0]
|
||||||
|
|
||||||
data = F.Multiply(F.Cast(data, F.float), initFloat(1.0/255.0))
|
data = F.cast(data, F.float) * var.float(1.0/255.0)
|
||||||
predict = net(data)
|
predict = net(data)
|
||||||
target = F.OneHot(F.Cast(label, F.int), initInt(10), initFloat(1.0), initFloat(0.0))
|
target = F.one_hot(F.cast(label, F.int), var.int(10), var.float(1.0), var.float(0.0))
|
||||||
loss = train.loss.CrossEntropy(predict, target)
|
loss = c_train.loss.CrossEntropy(predict, target)
|
||||||
opt.step(loss)
|
opt.step(loss)
|
||||||
if i % 100 == 0:
|
if i % 100 == 0:
|
||||||
print(loss.read())
|
print(loss.read())
|
||||||
|
|
@ -79,14 +70,17 @@ def trainFunc(loader, net, opt):
|
||||||
|
|
||||||
|
|
||||||
net = Net()
|
net = Net()
|
||||||
opt = train.SGD(0.01, 0.9)
|
opt = c_train.SGD(0.01, 0.9)
|
||||||
net.loadParameters(F.load("cache/temp.snapshot"))
|
net.loadParameters(F.load("cache/temp.snapshot"))
|
||||||
opt.append(net.parameters())
|
opt.append(net.parameters())
|
||||||
|
|
||||||
mnistDataset = data.mnist.create("/Users/jiangxiaotang/data/mnist", data.mnist.Train)
|
import sys
|
||||||
trainLoader = mnistDataset.createLoader(64, True, True, 0)
|
mnistDataPath = sys.argv[1]
|
||||||
testmnistDataset = data.mnist.create("/Users/jiangxiaotang/data/mnist", data.mnist.Test)
|
|
||||||
testLoader = mnistDataset.createLoader(10, True, False, 0)
|
mnistDataset = data.mnist.create(mnistDataPath, data.mnist.Train)
|
||||||
|
trainLoader = mnistDataset.create_loader(64, True, True, 0)
|
||||||
|
testmnistDataset = data.mnist.create(mnistDataPath, data.mnist.Test)
|
||||||
|
testLoader = mnistDataset.create_loader(10, True, False, 0)
|
||||||
|
|
||||||
F.setThreadNumber(4)
|
F.setThreadNumber(4)
|
||||||
for epoch in range(0, 10):
|
for epoch in range(0, 10):
|
||||||
|
|
@ -94,7 +88,7 @@ for epoch in range(0, 10):
|
||||||
# Save Model
|
# Save Model
|
||||||
fileName = 'cache/%d.mnist.mnn' %epoch
|
fileName = 'cache/%d.mnist.mnn' %epoch
|
||||||
net.train(False)
|
net.train(False)
|
||||||
predict = net.forward(F.Input([1, 1, 28, 28], F.NC4HW4))
|
predict = net.forward(F.placeholder([1, 1, 28, 28], F.NC4HW4))
|
||||||
print("Save to " + fileName)
|
print("Save to " + fileName)
|
||||||
F.save([predict], fileName)
|
F.save([predict], fileName)
|
||||||
testFunc(testLoader, net)
|
testFunc(testLoader, net)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
from _mnncengine import *
|
||||||
|
from . import train
|
||||||
|
from . import tools
|
||||||
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from . import mnn, mnnops, mnnquant, mnnvisual, mnnconvert
|
||||||
|
from . import utils
|
||||||
|
from . import mnn_fb
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue