geekdoc-python-zh/docs/askpython/load-and-plot-mnist-dataset...

155 lines
4.4 KiB
Markdown
Raw Permalink Normal View History

2024-10-20 12:24:46 +08:00
# 如何在 Python 中加载和绘制 MNIST 数据集?
> 原文:<https://www.askpython.com/python/examples/load-and-plot-mnist-dataset-in-python>
本教程讲述了在 Python 中加载 MNIST 数据集的步骤。 **MNIST 数据集**是一个手写数字的大型数据库。它通常用于训练各种图像处理系统。
*MNIST 是美国国家标准技术研究院数据库的缩写。*
该数据集用于训练模型来识别手写数字。这可用于扫描信件上的手写 pin 码。
MNIST 包含了从 T2 0 到 9 的 70000 张 28×28 的手写数字图像。
## 为什么 MNIST 数据集如此受欢迎?
MNIST 受欢迎的原因有很多,这些是:
* MNSIT 数据集**公开提供。**
* 这些数据在使用前几乎不需要处理。
* 这是一个**庞大的**数据集。
此外,该数据集通常用于图像处理和机器学习课程。
## 在 Python 中加载 MNIST 数据集
在本教程中,我们将学习 MNIST 数据集。我们还将了解如何用 python 加载 MNIST 数据集。
### 1.在 Python 中加载数据集
让我们从将数据集加载到 python 笔记本开始。加载数据最简单的方法是通过 Keras。
```py
from keras.datasets import mnist
```
MNIST 数据集由训练数据和测试数据组成。每个图像存储在 28X28 中,相应的输出是图像中的数字。
我们可以通过观察训练和测试数据的形状来验证这一点。
要将数据加载到变量中,请使用:
```py
(train_X, train_y), (test_X, test_y) = mnist.load_data()
```
要打印训练和测试向量的形状,请使用:
```py
print('X_train: ' + str(train_X.shape))
print('Y_train: ' + str(train_y.shape))
print('X_test: ' + str(test_X.shape))
print('Y_test: ' + str(test_y.shape))
```
我们得到以下输出:
```py
X_train: (60000, 28, 28)
Y_train: (60000,)
X_test: (10000, 28, 28)
Y_test: (10000,)
```
由此我们可以得出以下关于 MNIST 数据集的结论:
* 训练集包含 60k 图像,测试集包含 10k 图像。
* 训练输入向量的尺寸为**【60000 X 28 X 28】。**
* 训练输出向量的大小为**【60000 X 1】。**
* 每个单独的输入向量的大小为**【28 X 28】。**
* 每个单独的输出向量的维数为[ **1]**
### 2.绘制 MNIST 数据集
让我们尝试显示 MNIST 数据集中的图像。从导入 **[Matplotlib](https://www.askpython.com/python-modules/matplotlib/python-matplotlib) 开始。**
```py
from matplotlib import pyplot
```
要绘制数据,请使用以下代码:
```py
from matplotlib import pyplot
for i in range(9):
pyplot.subplot(330 + 1 + i)
pyplot.imshow(train_X[i], cmap=pyplot.get_cmap('gray'))
pyplot.show()
```
输出结果如下:
![Mnist Dataset](img/9715f1f5b262dcab77029814738d0345.png)
Mnist Dataset
## 用 Python 加载和绘制 MNIST 数据集的完整代码
本教程的完整代码如下所示:
```py
from keras.datasets import mnist
from matplotlib import pyplot
#loading
(train_X, train_y), (test_X, test_y) = mnist.load_data()
#shape of dataset
print('X_train: ' + str(train_X.shape))
print('Y_train: ' + str(train_y.shape))
print('X_test: ' + str(test_X.shape))
print('Y_test: ' + str(test_y.shape))
#plotting
from matplotlib import pyplot
for i in range(9):
pyplot.subplot(330 + 1 + i)
pyplot.imshow(train_X[i], cmap=pyplot.get_cmap('gray'))
pyplot.show()
```
## 下一步是什么?
现在您已经导入了 MNIST 数据集,您可以将其用于影像分类。
当谈到图像分类的任务时,没有什么可以击败卷积神经网络(CNN)。CNN 包含**卷积层、汇聚层、扁平化层**。
让我们看看每一层都做了什么。
### 1.卷积层
卷积层使用较小的像素过滤器过滤图像。这将减小图像的大小,而不会丢失像素之间的关系。
### 2.汇集层
池层的主要工作是减少卷积后图像的空间大小。
池层通过选择像素内的最大值、平均值或和值来减少参数的数量。
**最大池化**是最常用的池化技术。
### 3.展平层
展平层将多维像素向量表示为一维像素向量。
## 结论
本教程是关于将 MNIST 数据集加载到 python 中的。我们研究了 MNIST 数据集,并简要讨论了可用于 MNIST 数据集图像分类的 CNN 网络。
如果你想进一步了解 Python 中的图像处理,请通读这篇教程,学习如何使用 OpenCV 在 Python 中[读取图像。](https://www.askpython.com/python-modules/read-images-in-python-opencv)