5.8 KiB
Python 中的手写数字识别
原文:https://www.askpython.com/python/examples/handwritten-digit-recognition
你好,初学者!今天在本教程中,我们将学习如何从 sklearn 数据集中已经可用的 MNIST 数据集中识别手写数字。为了识别数字,我们将利用卷积神经网络 (CNN)。
我们先从了解 CNN 是什么开始。
什么是卷积神经网络?
CNN 是基于多层感知器的计算任务的最重要的神经网络模型之一。这些模型在图像处理方面表现尤为出色。例如,手写的识别。手写识别是神经网络最基本和最优秀的用途之一。CNN 模型经过多层训练,可以做出正确的预测
卷积神经网络用例
CNN 在图像处理等领域发挥着重要作用。它对探测和预测有很大的影响。它甚至被用于纳米技术,如制造半导体。在这里,它被用来检测材料中的缺陷。如果 CNN 与 Keras 或 Tensorflow 一起使用,与各种分类算法相比,它给出了最高的准确度。与任何其他数据集相比,CNN 和反向传播架构使 MNIST 数据集具有最高的准确性。通过研究,CNN 每天都在开发新的应用。在德国,提出了使用 CNN 的交通标志识别模型。
手写数字识别数据集的加载和准备
我们将要使用的数据集包含大约60000 张训练图像和 10000 张测试图像。然后我们将数据分别分成训练和测试数据集。
x_train和x_test包含图像的像素代码,而y_test和y_train包含来自0–9的标签,其代表数字,因为数字可以从 0 到 9 变化。
现在,我们需要检查数据集的形状是否可以在 CNN 模型中使用。数据的大小被观察为 (60000,28,28) ,这意味着 60000 个大小为 28×28 像素的图像。
但是为了使用 Keras API,我们需要一个 4 维数组数据集,因此我们需要将 3 维数据转换成 4 维数据集。
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)
下一步是标准化数据,首先将数据转换为浮点型,然后除以 255(最大 RGB 码–最小 RGB 码)。
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
构建模型
在本教程中,用户将利用 Keras API 来构建模型,为了做到这一点,我们将从 Keras 导入顺序模型并添加多层,如下所示:
- Conv2D
- 最大池化
- 变平
- 拒绝传统社会的人
- 稠密的
脱落层负责对抗过度拟合,展平层将 2D 阵列展平为 1D 阵列。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Dropout, Flatten, MaxPooling2D
model = Sequential()
model.add(Conv2D(28, kernel_size=(3,3), input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation=tf.nn.relu))
model.add(Dropout(0.2))
model.add(Dense(10,activation=tf.nn.softmax))
编译和拟合模型
所以现在我们创造了一个非优化的空 CNN 。然后,我们设置一个具有给定的损失函数的优化器,该优化器利用一个度量,并且通过使用创建的训练数据集来拟合模型。ADAM 优化器优于其他类似的优化器。
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x=x_train,y=y_train, epochs=10)
培训过程的结果如下。
Training Result Handwritten Digit Recog
在使用evaluate函数评估模型时,我们观察到的准确度为 98.4% 。
可视化结果
我们的最后一步是可视化训练模型的结果,并在subplots的帮助下绘制它们。同样的代码和输出如下所示。我们可以看到结果相当准确。
import matplotlib.pyplot as plt
plt.style.use('seaborn')
plt.figure(figsize=(10,10))
plt.subplot(4,4,1)
image_index = 2853
predict = x_test[image_index].reshape(28,28)
pred = model.predict(x_test[image_index].reshape(1, 28, 28, 1))
plt.imshow(x_test[image_index].reshape(28, 28),cmap='Greys')
plt.title("Predicted Label: "+str(pred.argmax()))
plt.subplot(4,4,2)
image_index = 2000
predict = x_test[image_index].reshape(28,28)
pred = model.predict(x_test[image_index].reshape(1, 28, 28, 1))
plt.imshow(x_test[image_index].reshape(28, 28),cmap='Greys')
plt.title("Predicted Label: "+str(pred.argmax()))
plt.subplot(4,4,3)
image_index = 1500
predict = x_test[image_index].reshape(28,28)
pred = model.predict(x_test[image_index].reshape(1, 28, 28, 1))
plt.imshow(x_test[image_index].reshape(28, 28),cmap='Greys')
plt.title("Predicted Label: "+str(pred.argmax()))
plt.subplot(4,4,4)
image_index = 1345
predict = x_test[image_index].reshape(28,28)
pred = model.predict(x_test[image_index].reshape(1, 28, 28, 1))
plt.imshow(x_test[image_index].reshape(28, 28),cmap='Greys')
plt.title("Predicted Label: "+str(pred.argmax()))
Handwritten Digit Recog Output
结论
在本教程中,我们构建了自己的 CNN 集成手写数字识别模型。而且精度出来还挺不错的!
感谢您的阅读!

