Pyhton实现决策树算法 MNIST数据集

Pyhton实现决策树算法 MNIST数据集

决策树是一种比较接近人类思维方式的算法,将样本通过每个特征值的信息增益进行划分,从而保证每个划分之后的结果信息熵的消减量达到最大。具体的原理请大家自己查找相关资料。

sklearn实现代码如下, 准确率可以达到90%左右。

from sklearn import tree
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
train_num = 10000
test_num = 100

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

if __name__ == '__main__':
    # 获得一个决策树分类器
    clf = tree.DecisionTreeClassifier()
    # 拟合
    clf.fit(x_train[:train_num], y_train[:train_num])
    # 预测
    prediction = clf.predict(x_test[:test_num])

    accurancy = np.sum(np.equal(prediction, y_test[:test_num])) / test_num
    print('prediction : ', prediction)
    print('accurancy : ', accurancy)

点赞