R语言 决策树及其实现

一颗决策树包含一个根结点、若干个内部结点和若干个叶结点;叶结点对应于决策结果,其他每个结点则对应于一个属性测试;每个结点包含的样本集合根据属性测试的结果被划分到子结点中;根结点包含样本全集。从根结点到叶结点的路径对应于了一个判定测试序列。

目的:为了产生一颗泛化能力强,即处理未见示例能力强的据决策树。

特别注意几点:

1)通常所说的属性是离散,若属性是连续,则要把属性离散化,最简单的是是采用二分法(找划分点)

2)缺失值处理

决策树是一个递归过程,以下三种情形会导致递归返回:

1)当前结点包含的样本属于同一类别,无需划分;

2)当前属性集为空,或是所有样本在所有属性上取值相同,无法划分;

3)当前结点包含的样本集合为空,不能划分。

信息增益:一般而言,信息增益越大,则意味着使用属性a来划分所获得的“纯度提升”越大

增益率:与信息增益的原理一样,但增益率可以校正存在偏向于选择取值较多的特征的问题


剪枝处理

1)预剪枝

在决策树生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点。


2)后剪枝

先从训练集生成一颗完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。

R语言实现 

  library(C50); library(rpart); library(party); library(rpart.plot)
  library(caret)
  
  # 加载数据
  car <- read.table('./data/car.data', sep = ',')
  colnames(car) <- c('buy', 'main', 'doors', 'capacity', 'lug_boot', 'safety', 'accept')
  
  # 数据集分为测试和训练
  ind <- createDataPartition(car$accept, times = 1, p = 0.75, list = FALSE)
  carTR <- car[ind, ]
  carTE <- car[-ind, ]
  
  # 建立模型

  # 决策树
  # rpart包
  # 在rpart包中有函数rpart.control预剪枝,prune后剪枝
  #
  # 预剪枝:
  # rpart.control对树进行一些设置  
  # minsplit是最小分支节点数,这里指大于等于20,那么该节点会继续分划下去,否则停止  
  # minbucket:树中叶节点包含的最小样本数  
  # maxdepth:决策树最大深度 
  # xval:交叉验证的次数
  # cp全称为complexity parameter,指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度
  #
  # 后剪枝:
  # 主要是调节参数是cp
  # prune函数可以实现最小代价复杂度剪枝法,对于CART的结果,每个节点均输出一个对应的cp
  # prune函数通过设置cp参数来对决策树进行修剪,cp为复杂度系数
  tc <- rpart.control(minsplit = 20, minbucket = 20, maxdepth = 10, xval = 5, cp = 0.005) # 预剪枝
  rpart.model <- rpart(accept ~ ., data = carTR, control = tc)
  rpart.model <- prune(rpart.model, 
                       cp = rpart.model$cptable[which.min(rpart.model$cptable[,"xerror"]),"CP"]) # 后剪枝
  rpart.plot(rpart.model, under = TRUE, faclen = 0, cex = 0.5, main = "决策树") # 画图
  
  # C5.0
  # C5.0包
  c5.0.model <- C5.0(accept ~ ., data = carTR) # C5.0
  plot(c5.0.model)
  
  # 使用ctree函数实现条件推理决策树算法
  # party包
  ctree.model <- ctree(accept ~ ., data = carTR)
  
  # 预测结果,并构建混淆矩阵,查看准确率
  # 构建result,存放预测结果
  result <- data.frame(arithmetic = c('C5.0', 'CART', 'ctree'), errTR = rep(0, 3),errTE = rep(0, 3))
  
  for (i in 1:3) {
    # 预测结果
    carTR_predict <- predict(switch(i, c5.0.model, rpart.model, ctree.model), newdata = carTR,
                             type = switch(i, 'class', 'class', 'response'))
    carTE_predict <- predict(switch(i, c5.0.model, rpart.model, ctree.model), newdata = carTE,
                             type = switch(i, 'class', 'class', 'response'))
    # 混淆矩阵
    tableTR <- table(actual = carTR$accept, predict = carTR_predict)
    tableTE <- table(actual = carTE$accept, predict = carTE_predict)
    
    # 计算误差矩阵
    result[i, 2] <- paste(round((sum(tableTR) - sum(diag(tableTR)))*100/sum(tableTR), 2), '%')
    result[i, 3] <- paste(round((sum(tableTE) - sum(diag(tableTE)))*100/sum(tableTE), 2), '%')
  }
  #查看误差率
> result
  arithmetic  errTR  errTE
1       C5.0 1.16 % 3.25 %
2       CART 5.94 % 7.89 %
3      ctree 4.47 %  5.8 %

点赞