TensorFlow Tutorial: Decision Trees and Random Forests

A quick Baidu search suggests there isn't much code on this topic. I found one working example (it looks fairly old, and may stop working in future TensorFlow versions).

"""Random Forest.

Implement the Random Forest algorithm with TensorFlow, and apply it to
classify handwritten digit images. This example uses the MNIST database of
handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""

from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources

# # Block that ignores all GPUs. Random forests don't benefit from GPU
# # acceleration, which may be part of why they're less popular here.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Import MNIST data (downloads the dataset to MNIST_data/ on first run)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

tf.reset_default_graph()  # needed when re-running cells in Jupyter-style sessions

# Parameters
num_steps = 1000  # total number of training steps
batch_size = 1024  # number of samples per batch
num_classes = 10  # the 10 digits, i.e. the number of target classes
num_features = 784  # each image has 28*28 pixels
num_trees = 100  # number of trees in the forest
max_nodes = 10000  # maximum number of nodes per tree

# Placeholders for input X and target Y
X = tf.placeholder(tf.float32, shape=[None, num_features])
# For random forests, labels must be integer class ids (not one-hot), hence:
Y = tf.placeholder(tf.int32, shape=[None])

# Random forest hyperparameters
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()

# Build the forest
forest_graph = tensor_forest.RandomForestGraphs(hparams)
# Get the training graph and loss
train_op = forest_graph.training_graph(X, Y)  # op that grows the trees on each batch
loss_op = forest_graph.training_loss(X, Y)  # training "loss"; in the log below it tracks tree growth, bottoming out at -(max_nodes + 1)

# Measure the accuracy
# inference_graph returns a tuple (probabilities, tree_paths, variance);
# only the class probabilities are needed here
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(
    tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))  # boolean vector of correct predictions
accuracy_op = tf.reduce_mean(
    tf.cast(correct_prediction, tf.float32))  # cast booleans to floats and average -> accuracy

# Initialize the variables (i.e. assign their default values) and forest resources
init_vars = tf.group(tf.global_variables_initializer(),
                     resources.initialize_resources(resources.shared_resources()))

# Start TensorFlow session
sess = tf.Session()

# Run the initializer
sess.run(init_vars)

# Training
for i in range(1, num_steps + 1):
    # Prepare data: fetch the next batch of MNIST images and integer labels
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
    if i % 50 == 0 or i == 1:
        acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
        print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

# Test the model's accuracy on the held-out test set
test_x, test_y = mnist.test.images, mnist.test.labels
print("Test Accuracy:", sess.run(
    accuracy_op, feed_dict={X: test_x, Y: test_y}))
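For intuition, the three accuracy ops above boil down to the following NumPy computation (a standalone sketch with made-up class probabilities standing in for what `infer_op` would produce; not part of the original script):

```python
import numpy as np

# Fake class-probability output for 4 samples over 3 classes
# (stand-in for the probabilities that infer_op returns).
probs = np.array([[0.1, 0.8, 0.1],
                  [0.7, 0.2, 0.1],
                  [0.2, 0.3, 0.5],
                  [0.3, 0.4, 0.3]])
labels = np.array([1, 0, 2, 0])  # integer class ids, like Y

predictions = probs.argmax(axis=1)            # tf.argmax(infer_op, 1)
correct = (predictions == labels)             # tf.equal(...)
accuracy = correct.astype(np.float32).mean()  # tf.reduce_mean(tf.cast(...))
print(accuracy)  # 0.75
```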

Output (judging by the log, this run had num_steps set to 5000):

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Constructing forest with params = 
INFO:tensorflow:{'num_trees': 100, 'max_nodes': 10000, 'bagging_fraction': 1.0, 'feature_bagging_fraction': 1.0, 'num_splits_to_consider': 28, 'max_fertile_nodes': 0, 'split_after_samples': 250, 'valid_leaf_threshold': 1, 'dominate_method': 'bootstrap', 'dominate_fraction': 0.99, 'model_name': 'all_dense', 'split_finish_name': 'basic', 'split_pruning_name': 'none', 'collate_examples': False, 'checkpoint_stats': False, 'use_running_stats_method': False, 'initialize_average_splits': False, 'inference_tree_paths': False, 'param_file': None, 'split_name': 'less_or_equal', 'early_finish_check_every_samples': 0, 'prune_every_samples': 0, 'num_classes': 10, 'num_features': 784, 'bagged_num_features': 784, 'bagged_features': None, 'regression': False, 'num_outputs': 1, 'num_output_columns': 11, 'base_random_seed': 0, 'leaf_model_type': 0, 'stats_model_type': 0, 'finish_type': 0, 'pruning_type': 0, 'split_type': 0}
Step 1, Loss: -1.720000, Acc: 0.512695
Step 50, Loss: -256.399994, Acc: 0.917969
Step 100, Loss: -542.260010, Acc: 0.942383
Step 150, Loss: -830.960022, Acc: 0.940430
Step 200, Loss: -1118.959961, Acc: 0.955078
Step 250, Loss: -1403.420044, Acc: 0.958984
Step 300, Loss: -1684.959961, Acc: 0.961914
Step 350, Loss: -1964.099976, Acc: 0.961914
Step 400, Loss: -2236.679932, Acc: 0.971680
Step 450, Loss: -2511.560059, Acc: 0.964844
Step 500, Loss: -2782.419922, Acc: 0.969727
Step 550, Loss: -3045.379883, Acc: 0.965820
Step 600, Loss: -3310.080078, Acc: 0.974609
Step 650, Loss: -3568.800049, Acc: 0.972656
Step 700, Loss: -3827.020020, Acc: 0.977539
Step 750, Loss: -4081.060059, Acc: 0.981445
Step 800, Loss: -4331.720215, Acc: 0.971680
Step 850, Loss: -4576.040039, Acc: 0.981445
Step 900, Loss: -4818.740234, Acc: 0.982422
Step 950, Loss: -5061.339844, Acc: 0.976562
Step 1000, Loss: -5297.240234, Acc: 0.985352
Step 1050, Loss: -5529.259766, Acc: 0.973633
Step 1100, Loss: -5757.000000, Acc: 0.982422
Step 1150, Loss: -5988.939941, Acc: 0.984375
Step 1200, Loss: -6210.520020, Acc: 0.978516
Step 1250, Loss: -6429.000000, Acc: 0.986328
Step 1300, Loss: -6646.479980, Acc: 0.988281
Step 1350, Loss: -6861.580078, Acc: 0.987305
Step 1400, Loss: -7070.700195, Acc: 0.988281
Step 1450, Loss: -7282.620117, Acc: 0.985352
Step 1500, Loss: -7488.779785, Acc: 0.982422
Step 1550, Loss: -7688.120117, Acc: 0.984375
Step 1600, Loss: -7887.859863, Acc: 0.990234
Step 1650, Loss: -8085.560059, Acc: 0.991211
Step 1700, Loss: -8279.139648, Acc: 0.984375
Step 1750, Loss: -8469.519531, Acc: 0.997070
Step 1800, Loss: -8659.879883, Acc: 0.990234
Step 1850, Loss: -8843.639648, Acc: 0.985352
Step 1900, Loss: -9023.559570, Acc: 0.993164
Step 1950, Loss: -9203.799805, Acc: 0.993164
Step 2000, Loss: -9381.019531, Acc: 0.994141
Step 2050, Loss: -9556.320312, Acc: 0.994141
Step 2100, Loss: -9725.719727, Acc: 0.988281
Step 2150, Loss: -9875.740234, Acc: 0.997070
Step 2200, Loss: -9970.719727, Acc: 0.994141
Step 2250, Loss: -9998.120117, Acc: 0.995117
Step 2300, Loss: -10001.000000, Acc: 0.992188
Step 2350, Loss: -10001.000000, Acc: 0.988281
Step 2400, Loss: -10001.000000, Acc: 0.997070
Step 2450, Loss: -10001.000000, Acc: 0.995117
Step 2500, Loss: -10001.000000, Acc: 0.996094
Step 2550, Loss: -10001.000000, Acc: 0.994141
Step 2600, Loss: -10001.000000, Acc: 0.992188
Step 2650, Loss: -10001.000000, Acc: 0.993164
Step 2700, Loss: -10001.000000, Acc: 0.998047
Step 2750, Loss: -10001.000000, Acc: 0.994141
Step 2800, Loss: -10001.000000, Acc: 0.997070
Step 2850, Loss: -10001.000000, Acc: 0.993164
Step 2900, Loss: -10001.000000, Acc: 0.994141
Step 2950, Loss: -10001.000000, Acc: 0.996094
Step 3000, Loss: -10001.000000, Acc: 0.998047
Step 3050, Loss: -10001.000000, Acc: 0.991211
Step 3100, Loss: -10001.000000, Acc: 0.996094
Step 3150, Loss: -10001.000000, Acc: 0.994141
Step 3200, Loss: -10001.000000, Acc: 0.994141
Step 3250, Loss: -10001.000000, Acc: 0.995117
Step 3300, Loss: -10001.000000, Acc: 0.995117
Step 3350, Loss: -10001.000000, Acc: 0.992188
Step 3400, Loss: -10001.000000, Acc: 0.994141
Step 3450, Loss: -10001.000000, Acc: 0.994141
Step 3500, Loss: -10001.000000, Acc: 0.994141
Step 3550, Loss: -10001.000000, Acc: 0.991211
Step 3600, Loss: -10001.000000, Acc: 0.994141
Step 3650, Loss: -10001.000000, Acc: 0.994141
Step 3700, Loss: -10001.000000, Acc: 0.991211
Step 3750, Loss: -10001.000000, Acc: 0.994141
Step 3800, Loss: -10001.000000, Acc: 0.992188
Step 3850, Loss: -10001.000000, Acc: 0.995117
Step 3900, Loss: -10001.000000, Acc: 0.991211
Step 3950, Loss: -10001.000000, Acc: 0.990234
Step 4000, Loss: -10001.000000, Acc: 0.991211
Step 4050, Loss: -10001.000000, Acc: 0.988281
Step 4100, Loss: -10001.000000, Acc: 0.995117
Step 4150, Loss: -10001.000000, Acc: 0.991211
Step 4200, Loss: -10001.000000, Acc: 0.993164
Step 4250, Loss: -10001.000000, Acc: 0.992188
Step 4300, Loss: -10001.000000, Acc: 0.996094
Step 4350, Loss: -10001.000000, Acc: 0.995117
Step 4400, Loss: -10001.000000, Acc: 0.994141
Step 4450, Loss: -10001.000000, Acc: 0.992188
Step 4500, Loss: -10001.000000, Acc: 0.992188
Step 4550, Loss: -10001.000000, Acc: 0.996094
Step 4600, Loss: -10001.000000, Acc: 0.997070
Step 4650, Loss: -10001.000000, Acc: 0.993164
Step 4700, Loss: -10001.000000, Acc: 0.997070
Step 4750, Loss: -10001.000000, Acc: 0.990234
Step 4800, Loss: -10001.000000, Acc: 0.993164
Step 4850, Loss: -10001.000000, Acc: 0.992188
Step 4900, Loss: -10001.000000, Acc: 0.994141
Step 4950, Loss: -10001.000000, Acc: 0.992188
Step 5000, Loss: -10001.000000, Acc: 0.993164
Test Accuracy: 0.963

Hmm, both the speed and the accuracy are much worse than a convolutional network.
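Since tf.contrib.tensor_forest is tied to old TensorFlow releases, the same kind of experiment is easier to reproduce today with scikit-learn's RandomForestClassifier. A sketch on the small built-in 8x8 digits dataset rather than full MNIST, to keep it quick:

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# 8x8 grayscale digits: 64 features per image instead of MNIST's 784
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)

# 100 trees, mirroring num_trees in the TensorFlow example
clf = RandomForestClassifier(n_estimators=100, random_state=0)
clf.fit(X_train, y_train)
print("Test accuracy:", clf.score(X_test, y_test))
```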


The forest contains 100 classification trees. Each tree assigns X to one of the ten classes; the forest's prediction is the class that receives the most votes.

As for how exactly the trees are grown and pruned, TensorFlow has already encapsulated all of that.
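The voting step described above can be sketched in a few lines (hypothetical per-tree votes, just to illustrate majority voting):

```python
from collections import Counter

# Hypothetical predictions from a small forest for one input image:
# each tree votes for one of the digit classes 0-9.
tree_votes = [7, 7, 1, 7, 9, 7, 1, 7]

# The forest's prediction is the class with the most votes.
forest_prediction = Counter(tree_votes).most_common(1)[0][0]
print(forest_prediction)  # 7
```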

  1. Randomly sample a subset of the 784 pixels and evaluate how much each sampled pixel reduces the impurity of the split (i.e. how cleanly it separates the classes); the more effective a pixel is, the better, and splits are chosen so that each tree is optimized.
  2. Build 100 different trees this way and combine them into a forest.
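Step 1 relies on an impurity measure; Gini impurity is the usual choice for classification trees (a minimal sketch of the idea, not necessarily the exact criterion tensor_forest uses internally):

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum(p_k^2). 0 means a pure node."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

# A split on a good feature separates the classes cleanly...
print(gini([3, 3, 3, 3]))  # 0.0 (pure node)
# ...while a bad one leaves the classes mixed.
print(gini([3, 8, 3, 8]))  # 0.5 (maximally impure for two classes)
```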
    Original author: 质乎
    Original article: https://zhuanlan.zhihu.com/p/41509731