ID3算法的Java实现
ID3算法是经典的决策树学习生成算法。ID3算法的核心是在决策树各个节点上运用信息增益准则选择特征,递归地构建决策树。具体方法是:从根节点(root node)开始,对节点计算所有可能的特征的信息增益,选择信息增益最大的特征作为节点的特征,由该特征的不同取值建立子节点;再对子节点递归地调用以上方法,构建决策树;直到所有特征的信息增益均很小或者没有特征可以选取为止,最后得到一个决策树。要理解ID3算法,需要先了解一些基本的信息论概念,包括信息量、熵、后验熵、条件熵。
Java代码
/**
 * A node of the decision tree built by DecisionTree.
 * Internal nodes are labeled with the attribute they split on; leaf nodes are
 * labeled with a target-class value (see DecisionTree.buildTree).
 * NOTE(review): the original header said "C4.5", but the surrounding article
 * and the gain computation implement ID3 (plain information gain, no gain ratio).
 * @author zhenhua.chen
 * @date 2013-3-1 10:47:37 AM
 */
public class TreeNode {
private String nodeName; // node label: split-attribute name for internal nodes, class value for leaves
private List<String> splitAttributes; // distinct values of the split attribute, one child per value; null for leaves
private ArrayList<TreeNode> childrenNodes; // child nodes, in the same order as splitAttributes
private ArrayList<ArrayList<String>> dataSet; // rows of the data set routed to this node
private ArrayList<String> arrributeSet; // attribute names of dataSet's columns ("arrribute" sic — kept for compatibility)
public TreeNode(){
childrenNodes = new ArrayList<TreeNode>();
}
public String getNodeName() {
return nodeName;
}
public void setNodeName(String nodeName) {
this.nodeName = nodeName;
}
public List<String> getSplitAttributes() {
return splitAttributes;
}
public void setSplitAttributes(List<String> splitAttributes) {
this.splitAttributes = splitAttributes;
}
public ArrayList<TreeNode> getChildrenNodes() {
return childrenNodes;
}
public void setChildrenNodes(ArrayList<TreeNode> childrenNodes) {
this.childrenNodes = childrenNodes;
}
public ArrayList<ArrayList<String>> getDataSet() {
return dataSet;
}
public void setDataSet(ArrayList<ArrayList<String>> dataSet) {
this.dataSet = dataSet;
}
public ArrayList<String> getArrributeSet() {
return arrributeSet;
}
public void setArrributeSet(ArrayList<String> arrributeSet) {
this.arrributeSet = arrributeSet;
}
}
决策树算法:
/**
 * Builds, prints and level-traverses an ID3 decision tree.
 * (The original header said "C4.5", but the split criterion below is plain
 * information gain, i.e. ID3.)
 * @author zhenhua.chen
 * @date 2013-3-1 4:42:07 PM
 */
public class DecisionTree {

    /**
     * Recursively builds a decision tree over the given data set.
     *
     * @param dataSet      rows of string values; the LAST column is the target class
     * @param attributeSet attribute names, parallel to dataSet's columns
     *                     (the last entry names the target attribute)
     * @return root of the (sub)tree built for this data set
     */
    public TreeNode buildTree(ArrayList<ArrayList<String>> dataSet, ArrayList<String> attributeSet) {
        TreeNode node = new TreeNode();
        node.setDataSet(dataSet);
        node.setArrributeSet(attributeSet);

        int targetColumn = attributeSet.size() - 1;
        // The target column's entropy is loop-invariant: compute it once
        // (the original recomputed it for every candidate attribute).
        double targetEntropy = ComputeUtil.computeEntropy(dataSet, targetColumn);

        // Pick the attribute with the largest information gain.
        int index = -1;
        double maxGain = 0;
        for (int i = 0; i < targetColumn; i++) {
            double gain = targetEntropy - ComputeUtil.computeConditinalEntropy(dataSet, i);
            if (gain > maxGain) {
                index = i;
                maxGain = gain;
            }
        }

        // BUG FIX: the original indexed attributeSet with -1 (and crashed) when
        // no attribute had positive gain — e.g. an already-pure data set, or
        // only the target attribute left. Emit a majority-class leaf instead.
        if (index == -1) {
            node.setNodeName(majorityClass(dataSet, targetColumn));
            return node;
        }

        // Distinct values of the chosen attribute: one branch per value.
        ArrayList<String> splitAttributes = ComputeUtil.getTypes(dataSet, index);
        node.setSplitAttributes(splitAttributes);
        node.setNodeName(attributeSet.get(index));

        for (int i = 0; i < splitAttributes.size(); i++) {
            ArrayList<ArrayList<String>> splitDataSet = ComputeUtil.getDataSet(dataSet, index, splitAttributes.get(i));
            int desColumn = splitDataSet.get(0).size() - 1; // target column within the subset
            ArrayList<String> desAttributes = ComputeUtil.getTypes(splitDataSet, desColumn);
            TreeNode childNode;
            if (desAttributes.size() == 1) {
                // Subset is pure: make a leaf labeled with the single class value.
                childNode = new TreeNode();
                childNode.setNodeName(desAttributes.get(0));
            } else {
                // Drop the used attribute (both its name and its column), then recurse.
                ArrayList<String> newAttributeSet = new ArrayList<String>();
                for (String s : attributeSet) {
                    if (!s.equals(attributeSet.get(index))) {
                        newAttributeSet.add(s);
                    }
                }
                ArrayList<ArrayList<String>> newDataSet = new ArrayList<ArrayList<String>>();
                for (ArrayList<String> data : splitDataSet) {
                    ArrayList<String> tmp = new ArrayList<String>();
                    for (int j = 0; j < data.size(); j++) {
                        if (j != index) {
                            tmp.add(data.get(j));
                        }
                    }
                    newDataSet.add(tmp);
                }
                childNode = buildTree(newDataSet, newAttributeSet);
            }
            node.getChildrenNodes().add(childNode);
        }
        return node;
    }

    /**
     * Returns the most frequent value of the target column; used to label a
     * leaf when no further split is possible.
     */
    private String majorityClass(ArrayList<ArrayList<String>> dataSet, int targetColumn) {
        String best = null;
        int bestCount = -1;
        for (String cls : ComputeUtil.getTypes(dataSet, targetColumn)) {
            int count = ComputeUtil.getDataSet(dataSet, targetColumn, cls).size();
            if (count > bestCount) {
                best = cls;
                bestCount = count;
            }
        }
        return best;
    }

    /**
     * Depth-first print of the tree, one separator line per node.
     * BUG FIX: the original labeled leaves "分裂节点" here although searchTree
     * labels them "叶子节点"; the two traversals are now consistent.
     * @param root subtree to print
     */
    public void printTree(TreeNode root) {
        System.out.println("----------------");
        if (null != root.getSplitAttributes()) {
            System.out.print("分裂节点:" + root.getNodeName());
            for (String attr : root.getSplitAttributes()) {
                System.out.print("(" + attr + ") ");
            }
        } else {
            System.out.print("叶子节点:" + root.getNodeName());
        }
        if (null != root.getChildrenNodes()) {
            for (TreeNode node : root.getChildrenNodes()) {
                printTree(node);
            }
        }
    }

    /**
     * Level-order (breadth-first) traversal: prints split nodes with their
     * branch values and leaves with their class label.
     * @param root tree to traverse
     */
    public void searchTree(TreeNode root) {
        Queue<TreeNode> queue = new LinkedList<TreeNode>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (null != node.getSplitAttributes()) {
                System.out.print("分裂节点:" + node.getNodeName() + "; ");
                for (String attr : node.getSplitAttributes()) {
                    System.out.print(" (" + attr + ") ");
                }
            } else {
                System.out.print("叶子节点:" + node.getNodeName() + "; ");
            }
            if (null != node.getChildrenNodes()) {
                for (TreeNode nod : node.getChildrenNodes()) {
                    queue.offer(nod);
                }
            }
        }
    }
}
一些util代码:
/**
 * Static helpers for the ID3 decision tree: distinct-value enumeration,
 * value counting, data-set splitting, entropy and conditional entropy.
 * (The original header said "C4.5"; these are the plain information-gain
 * building blocks used by ID3.)
 * @author zhenhua.chen
 * @date 2013-3-1 10:48:47 AM
 */
public class ComputeUtil {

    /** Cached ln(2), used to convert natural log to log base 2 in entropy sums. */
    private static final double LN2 = Math.log(2);

    /**
     * Returns the distinct values of the given column, in first-seen order.
     *
     * @param dataSet     rows of string values
     * @param columnIndex column to enumerate
     * @return distinct values of that column
     */
    public static ArrayList<String> getTypes(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        ArrayList<String> list = new ArrayList<String>();
        for (ArrayList<String> data : dataSet) {
            String value = data.get(columnIndex);
            if (!list.contains(value)) {
                list.add(value);
            }
        }
        return list;
    }

    /**
     * Counts how often each distinct value occurs in the given column.
     *
     * @return map from column value to its occurrence count
     */
    public static Map<String, Integer> getTypeCounts(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (ArrayList<String> data : dataSet) {
            String key = data.get(columnIndex);
            // Single lookup instead of the original containsKey + get pair.
            Integer count = map.get(key);
            map.put(key, count == null ? 1 : count + 1);
        }
        return map;
    }

    /**
     * Returns the subset of rows whose columnIndex-th value equals attribueClass
     * (the subset induced by one branch of a split).
     */
    public static ArrayList<ArrayList<String>> getDataSet(ArrayList<ArrayList<String>> dataSet, int columnIndex, String attribueClass) {
        ArrayList<ArrayList<String>> splitDataSet = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> data : dataSet) {
            if (data.get(columnIndex).equals(attribueClass)) {
                splitDataSet.add(data);
            }
        }
        return splitDataSet;
    }

    /**
     * Shannon entropy (base 2) of the value distribution in the given column.
     *
     * @return entropy in bits; 0 for a pure column
     */
    public static double computeEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        Map<String, Integer> counts = getTypeCounts(dataSet, columnIndex);
        double total = dataSet.size();
        double entropy = 0;
        // entrySet iteration avoids the per-key map lookup (and the useless
        // casts) of the original keySet/Iterator version.
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            double prob = entry.getValue() / total;
            entropy -= prob * Math.log(prob) / LN2;
        }
        return entropy;
    }

    /**
     * Conditional entropy H(target | column): the probability-weighted entropy
     * of the target (last) column within each subset induced by the given column.
     * Method name typo ("Conditinal") kept for caller compatibility.
     */
    public static double computeConditinalEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        double total = dataSet.size();
        double conditionalEntropy = 0;
        for (String value : getTypes(dataSet, columnIndex)) {
            ArrayList<ArrayList<String>> subset = getDataSet(dataSet, columnIndex, value);
            // subset cannot be empty: value was observed in this column, so the
            // original's size() > 0 guard (which would have NPE'd anyway on an
            // empty subset) is unnecessary.
            int desColumn = subset.get(0).size() - 1; // target class is the last column
            double probY = subset.size() / total;
            // The "posterior entropy" the original computed inline is exactly
            // the entropy of the target column over the subset; reuse it.
            conditionalEntropy += probY * computeEntropy(subset, desColumn);
        }
        return conditionalEntropy;
    }
}
测试代码:
/**
 * Driver: reads a tab-separated training file (first line: attribute names,
 * last column: target class), builds the ID3 tree and prints it level by level.
 */
public class Test {
    public static void main(String[] args) {
        File f = new File("D:/test.txt");
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(f));
            // Header line: attribute names (the last one names the target attribute).
            String str = reader.readLine();
            // BUG FIX: the original called str.split() without checking for an
            // empty file, which would throw NullPointerException.
            if (str == null) {
                System.err.println("empty input file: " + f);
                return;
            }
            ArrayList<String> attributeList = new ArrayList<String>();
            for (String attribute : str.split("\t")) {
                attributeList.add(attribute);
            }
            // Remaining lines: one training example per row.
            ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>();
            while ((str = reader.readLine()) != null) {
                ArrayList<String> row = new ArrayList<String>();
                for (String cell : str.split("\t")) {
                    row.add(cell);
                }
                dataSet.add(row);
            }
            DecisionTree dt = new DecisionTree();
            TreeNode root = dt.buildTree(dataSet, attributeList);
            // dt.printTree(root);
            dt.searchTree(root);
        } catch (IOException e) {
            // FileNotFoundException is an IOException, so one catch replaces
            // the original's nested try/catch pair.
            e.printStackTrace();
        } finally {
            // BUG FIX: the original never closed the reader (resource leak).
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}