2020-03-12 09:45:16 +08:00
|
|
|
import MNN
|
|
|
|
|
import MNN.var as var
|
|
|
|
|
c_train = MNN.c_train
|
|
|
|
|
nn = c_train.cnn
|
|
|
|
|
F = MNN.expr
|
|
|
|
|
data = c_train.data
|
2020-02-17 19:41:45 +08:00
|
|
|
import time
|
|
|
|
|
|
2020-03-12 09:45:16 +08:00
|
|
|
class Net(MNN.train.Module):
|
2020-02-17 19:41:45 +08:00
|
|
|
def __init__(self):
|
|
|
|
|
super(Net, self).__init__()
|
2020-03-12 09:45:16 +08:00
|
|
|
self.conv1 = nn.conv(1, 20, [5, 5])
|
|
|
|
|
self.conv2 = nn.conv(20, 50, [5, 5])
|
|
|
|
|
self.fc1 = nn.linear(800, 500)
|
|
|
|
|
self.fc2 = nn.linear(500, 10)
|
2020-02-17 19:41:45 +08:00
|
|
|
|
|
|
|
|
def forward(self, x):
|
2020-03-12 09:45:16 +08:00
|
|
|
x = F.relu(self.conv1(x))
|
|
|
|
|
x = F.max_pool(x, [2, 2], [2, 2])
|
|
|
|
|
x = F.relu(self.conv2(x))
|
|
|
|
|
x = F.max_pool(x, [2, 2], [2, 2])
|
|
|
|
|
x = F.convert(x, F.NCHW)
|
|
|
|
|
x = F.reshape(x, [0, -1])
|
|
|
|
|
x = F.relu(self.fc1(x))
|
2020-02-17 19:41:45 +08:00
|
|
|
x = self.fc2(x)
|
2020-03-12 09:45:16 +08:00
|
|
|
x = F.softmax(x, 1)
|
2020-02-17 19:41:45 +08:00
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def testFunc(loader, net):
|
|
|
|
|
loader.reset()
|
|
|
|
|
net.train(False)
|
2020-03-12 09:45:16 +08:00
|
|
|
iter_number = loader.iter_number()
|
2020-02-17 19:41:45 +08:00
|
|
|
correct = 0
|
2020-03-12 09:45:16 +08:00
|
|
|
for i in range(0, iter_number):
|
2020-02-17 19:41:45 +08:00
|
|
|
example = loader.next()[0]
|
|
|
|
|
data = example[0][0]
|
|
|
|
|
label = example[1][0]
|
|
|
|
|
|
2020-03-12 09:45:16 +08:00
|
|
|
data = F.cast(data, F.float) * var.float(1.0/255.0)
|
2020-02-17 19:41:45 +08:00
|
|
|
predict = net(data)
|
2020-03-12 09:45:16 +08:00
|
|
|
predict = F.argmax(predict, 1)
|
|
|
|
|
accu = F.reduce_sum(F.equal(predict, F.cast(label, F.int)), [], False)
|
2020-02-17 19:41:45 +08:00
|
|
|
correct += accu.read()[0]
|
2020-03-12 09:45:16 +08:00
|
|
|
print("Accu: ", correct * 100.0 / loader.size(), "%")
|
2020-02-17 19:41:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def trainFunc(loader, net, opt):
|
|
|
|
|
loader.reset()
|
|
|
|
|
net.train()
|
|
|
|
|
t0 = time.time()
|
2020-03-12 09:45:16 +08:00
|
|
|
iter_number = loader.iter_number()
|
|
|
|
|
for i in range(0, iter_number):
|
2020-02-17 19:41:45 +08:00
|
|
|
example = loader.next()[0]
|
|
|
|
|
data = example[0][0]
|
|
|
|
|
label = example[1][0]
|
|
|
|
|
|
2020-03-12 09:45:16 +08:00
|
|
|
data = F.cast(data, F.float) * var.float(1.0/255.0)
|
2020-02-17 19:41:45 +08:00
|
|
|
predict = net(data)
|
2020-03-12 09:45:16 +08:00
|
|
|
target = F.one_hot(F.cast(label, F.int), var.int(10), var.float(1.0), var.float(0.0))
|
|
|
|
|
loss = c_train.loss.CrossEntropy(predict, target)
|
2020-02-17 19:41:45 +08:00
|
|
|
opt.step(loss)
|
|
|
|
|
if i % 100 == 0:
|
|
|
|
|
print(loss.read())
|
|
|
|
|
t1 = time.time()
|
|
|
|
|
cost = t1 - t0
|
|
|
|
|
print("Epoch cost: %.3f" %cost)
|
|
|
|
|
F.save(net.parameters(), "cache/temp.snapshot")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
net = Net()
|
2020-03-12 09:45:16 +08:00
|
|
|
opt = c_train.SGD(0.01, 0.9)
|
2020-02-17 19:41:45 +08:00
|
|
|
net.loadParameters(F.load("cache/temp.snapshot"))
|
|
|
|
|
opt.append(net.parameters())
|
|
|
|
|
|
2020-03-12 09:45:16 +08:00
|
|
|
import sys
|
|
|
|
|
mnistDataPath = sys.argv[1]
|
|
|
|
|
|
|
|
|
|
mnistDataset = data.mnist.create(mnistDataPath, data.mnist.Train)
|
|
|
|
|
trainLoader = mnistDataset.create_loader(64, True, True, 0)
|
|
|
|
|
testmnistDataset = data.mnist.create(mnistDataPath, data.mnist.Test)
|
|
|
|
|
testLoader = mnistDataset.create_loader(10, True, False, 0)
|
2020-02-17 19:41:45 +08:00
|
|
|
|
|
|
|
|
F.setThreadNumber(4)
|
|
|
|
|
for epoch in range(0, 10):
|
|
|
|
|
trainFunc(trainLoader, net, opt)
|
|
|
|
|
# Save Model
|
|
|
|
|
fileName = 'cache/%d.mnist.mnn' %epoch
|
|
|
|
|
net.train(False)
|
2020-03-12 09:45:16 +08:00
|
|
|
predict = net.forward(F.placeholder([1, 1, 28, 28], F.NC4HW4))
|
2020-02-17 19:41:45 +08:00
|
|
|
print("Save to " + fileName)
|
|
|
|
|
F.save([predict], fileName)
|
|
|
|
|
testFunc(testLoader, net)
|