Pyhton实现决策树算法 MNIST数据集
决策树是一种比较接近人类思维方式的算法,将样本通过每个特征值的信息增益进行划分,从而保证每个划分之后的结果信息熵的消减量达到最大。具体的原理请大家自己查找相关资料。
sklearn实现代码如下, 准确率可以达到90%左右。
from sklearn import tree
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
train_num = 10000
test_num = 100
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
if __name__ == '__main__':
# 获得一个决策树分类器
clf = tree.DecisionTreeClassifier()
# 拟合
clf.fit(x_train[:train_num], y_train[:train_num])
# 预测
prediction = clf.predict(x_test[:test_num])
accurancy = np.sum(np.equal(prediction, y_test[:test_num])) / test_num
print('prediction : ', prediction)
print('accurancy : ', accurancy)