接下来的内容,是关于如何实际应用之前编写的ANN,来完成手写数字识别的任务。
准备
首先,需要下载数据集,以用于训练和测试。这里使用缩小版的mnist数据集。训练集有100条数据,测试集有10条数据。大家可以去这里下载。
https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_test_10.csv
https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_train_100.csv
Coding
导入需要的库。
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import scipy.special
from NN import NN
导入数据。
train_file = open("./mnist_train_100.csv", 'r')
train_list = train_file.readlines()
train_file.close()
test_file = open("./mnist_test_10.csv", 'r')
test_list = test_file.readlines()
test_file.close()
初始化模型。
input_nodes = 784
hidden_nodes = 100
output_nodes = 10
learning_rate = 0.3
nn = NN(input_nodes, hidden_nodes, output_nodes, learning_rate)
开始训练。epoch为迭代次数,因为数据集比较小,每轮迭代都会用上整个训练集。这里就只迭代一次好了。
数据集每一行为一条数据,第一个值是标签,也就是这条数据所代表的数字。接下来784(28*28)个值则是每一个像素点,范围是0~255。这里等比例缩小每个像素的值,把范围缩到0~0.99,再加上0.01。最终范围是0.01到1,避免了值为0的“死值”。
标签的值也不能太极端,我们对对应数字的节点的期望值为0.99,而其他节点的期望值为0.01。
epoch = 1
for i in range(epoch):
for record in train_list:
values = record.split(',')
inputs = (np.asfarray(values[1:]) / 255.0 * 0.99) + 0.01
labels = np.zeros(output_nodes) + 0.01
labels[int(values[0])] = 0.99
nn.train(inputs, labels)
这里大家可以尝试使用完整的mnist数据集来训练。也可以尝试更改迭代次数。