ID3算法的java实现

ID3算法的java实现

ID3算法是经典的决策树学习生成算法。ID3算法的核心是在决策树各个节点上运用信息增益准则选择特征,递归地构建决策树。具体方法是:从根节点(root node)开始,对节点计算所有可能的特征的信息增益,选择信息增益最大的特征作为节点的特征,由该特征的不同取值建立子节点;再对子节点递归地调用以上方法,构建决策树;直到所有特征的信息增益均很小或者没有特征可以选取为止,最后得到一棵决策树。要理解ID3算法,需要先了解一些基本的信息论概念,包括信息量、熵、后验熵、条件熵。

java代码

/**
 * Decision tree node (used by the ID3 tree builder in this article).
 * Holds the node's display name (split attribute name, or class label for a
 * leaf), the distinct values the node splits on, its child nodes, and the
 * data/attribute sets routed to this node during construction.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 上午10:47:37
 */
public class TreeNode {
    private String nodeName;                      // split-attribute name, or class label for leaves
    private List<String> splitAttributes;         // distinct values of the chosen split attribute
    private ArrayList<TreeNode> childrenNodes = new ArrayList<TreeNode>(); // children, one per split value
    private ArrayList<ArrayList<String>> dataSet; // data rows assigned to this node
    private ArrayList<String> arrributeSet;       // all attribute names of the data set (sic: historical typo kept for compatibility)

    public TreeNode() {
        // childrenNodes is initialized at its declaration above.
    }

    public String getNodeName() { return nodeName; }

    public void setNodeName(String nodeName) { this.nodeName = nodeName; }

    public List<String> getSplitAttributes() { return splitAttributes; }

    public void setSplitAttributes(List<String> splitAttributes) { this.splitAttributes = splitAttributes; }

    public ArrayList<TreeNode> getChildrenNodes() { return childrenNodes; }

    public void setChildrenNodes(ArrayList<TreeNode> childrenNodes) { this.childrenNodes = childrenNodes; }

    public ArrayList<ArrayList<String>> getDataSet() { return dataSet; }

    public void setDataSet(ArrayList<ArrayList<String>> dataSet) { this.dataSet = dataSet; }

    public ArrayList<String> getArrributeSet() { return arrributeSet; }

    public void setArrributeSet(ArrayList<String> arrributeSet) { this.arrributeSet = arrributeSet; }
}

决策树算法:

/**
 * Builds and traverses an ID3 decision tree.
 * The last column of every data row is treated as the target (class) attribute.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 下午4:42:07
 */
public class DecisionTree {

    /**
     * Recursively builds the decision tree for the given data set.
     *
     * @param dataSet      rows of attribute values; last column is the class label
     * @param attributeSet attribute names matching the columns of dataSet
     * @return the root of the (sub)tree; a leaf node carries a class label as its name
     */
    public TreeNode buildTree(ArrayList<ArrayList<String>> dataSet, ArrayList<String> attributeSet) {
        TreeNode node = new TreeNode();
        node.setDataSet(dataSet);
        node.setArrributeSet(attributeSet);

        int desColumn = attributeSet.size() - 1; // target (class) column index

        // Pick the attribute with maximal information gain:
        //   gain(A) = H(target) - H(target | A)
        int index = -1;
        double maxGain = 0;
        // Base entropy is loop-invariant: compute it once, not per attribute.
        double baseEntropy = ComputeUtil.computeEntropy(dataSet, desColumn);
        for (int i = 0; i < attributeSet.size() - 1; i++) {
            double gain = baseEntropy - ComputeUtil.computeConditinalEntropy(dataSet, i);
            if (gain > maxGain) {
                index = i;
                maxGain = gain;
            }
        }

        // BUGFIX: if no attribute yields positive gain (e.g. the data reaching this
        // node is already pure, or nothing is informative), the original code left
        // index == -1 and crashed on attributeSet.get(-1). Emit a majority-class
        // leaf instead.
        if (index == -1) {
            node.setNodeName(majorityClass(dataSet, desColumn));
            return node;
        }

        ArrayList<String> splitValues = ComputeUtil.getTypes(dataSet, index); // distinct values of the chosen attribute
        node.setSplitAttributes(splitValues);
        node.setNodeName(attributeSet.get(index));

        // For each value of the split attribute, either emit a leaf (pure subset)
        // or recurse on the subset with the chosen attribute removed.
        for (int i = 0; i < splitValues.size(); i++) {
            ArrayList<ArrayList<String>> splitDataSet = ComputeUtil.getDataSet(dataSet, index, splitValues.get(i));

            int subDesColumn = splitDataSet.get(0).size() - 1; // target column in the subset
            ArrayList<String> desAttributes = ComputeUtil.getTypes(splitDataSet, subDesColumn);
            TreeNode childNode = new TreeNode();
            if (desAttributes.size() == 1) {
                // Subset is pure: leaf labeled with the single remaining class.
                childNode.setNodeName(desAttributes.get(0));
            } else {
                // Drop the chosen attribute's name ...
                ArrayList<String> newAttributeSet = new ArrayList<String>();
                for (String s : attributeSet) {
                    if (!s.equals(attributeSet.get(index))) {
                        newAttributeSet.add(s);
                    }
                }
                // ... and its column, then recurse.
                ArrayList<ArrayList<String>> newDataSet = new ArrayList<ArrayList<String>>();
                for (ArrayList<String> data : splitDataSet) {
                    ArrayList<String> tmp = new ArrayList<String>();
                    for (int j = 0; j < data.size(); j++) {
                        if (j != index) {
                            tmp.add(data.get(j));
                        }
                    }
                    newDataSet.add(tmp);
                }
                childNode = buildTree(newDataSet, newAttributeSet);
            }
            node.getChildrenNodes().add(childNode);
        }
        return node;
    }

    /**
     * Returns the most frequent value in the given column (majority class label).
     * Uses only ComputeUtil helpers so no extra imports are needed.
     */
    private String majorityClass(ArrayList<ArrayList<String>> dataSet, int column) {
        String best = null;
        int bestCount = -1;
        for (String type : ComputeUtil.getTypes(dataSet, column)) {
            int count = ComputeUtil.getDataSet(dataSet, column, type).size();
            if (count > bestCount) {
                best = type;
                bestCount = count;
            }
        }
        return best;
    }

    /**
     * Prints the tree depth-first, one node per section.
     *
     * @param root subtree root to print
     */
    public void printTree(TreeNode root) {
        System.out.println("----------------");
        if (null != root.getSplitAttributes()) {
            System.out.print("分裂节点:" + root.getNodeName());
            for (String attr : root.getSplitAttributes()) {
                System.out.print("(" + attr + ") ");
            }
        } else {
            // BUGFIX: leaves were mislabeled "分裂节点" here; searchTree below
            // correctly calls them "叶子节点" — made consistent.
            System.out.print("叶子节点:" + root.getNodeName());
        }

        if (null != root.getChildrenNodes()) {
            for (TreeNode node : root.getChildrenNodes()) {
                printTree(node);
            }
        }

    }

    /**
     * Level-order (breadth-first) traversal of the tree, printing split nodes
     * with their split values and leaves with their class labels.
     *
     * @param root tree root to traverse
     */
    public void searchTree(TreeNode root) {
        Queue<TreeNode> queue = new LinkedList<TreeNode>();
        queue.offer(root);

        while (queue.size() != 0) {
            TreeNode node = queue.poll();
            if (null != node.getSplitAttributes()) {
                System.out.print("分裂节点:" + node.getNodeName() + "; ");
                for (String attr : node.getSplitAttributes()) {
                    System.out.print(" (" + attr + ") ");
                }
            } else {
                System.out.print("叶子节点:" + node.getNodeName() + "; ");
            }

            if (null != node.getChildrenNodes()) {
                for (TreeNode nod : node.getChildrenNodes()) {
                    queue.offer(nod);
                }
            }
        }
    }

}

一些util代码:

/**
 * Entropy / information-gain helpers for the ID3 decision-tree builder.
 * All methods treat the data set as rows of String attribute values; the
 * target (class) attribute is always the last column.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 上午10:48:47
 */
public class ComputeUtil {

    // Hoisted constant: converts natural log to log base 2.
    private static final double LOG2 = Math.log(2);

    /**
     * Returns the distinct values of the given column, in first-seen order.
     * (Insertion order matters: callers use it to order child nodes.)
     *
     * @param dataSet     data rows
     * @param columnIndex column to scan
     * @return distinct values in order of first appearance
     */
    public static ArrayList<String> getTypes(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        ArrayList<String> list = new ArrayList<String>();
        for (ArrayList<String> data : dataSet) {
            if (!list.contains(data.get(columnIndex))) {
                list.add(data.get(columnIndex));
            }
        }
        return list;
    }

    /**
     * Counts occurrences of each distinct value in the given column.
     *
     * @param dataSet     data rows
     * @param columnIndex column to count over
     * @return value -> occurrence count
     */
    public static Map<String, Integer> getTypeCounts(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (ArrayList<String> data : dataSet) {
            String key = data.get(columnIndex);
            Integer count = map.get(key); // single lookup instead of containsKey + get
            map.put(key, count == null ? 1 : count + 1);
        }
        return map;
    }

    /**
     * Returns the subset of rows whose value in the given column equals
     * attribueClass (i.e. the data subset after a split).
     *
     * @param dataSet       data rows
     * @param columnIndex   split column
     * @param attribueClass value to select (parameter name kept for compatibility)
     * @return matching rows (shared references, not copies)
     */
    public static ArrayList<ArrayList<String>> getDataSet(ArrayList<ArrayList<String>> dataSet, int columnIndex, String attribueClass) {
        ArrayList<ArrayList<String>> splitDataSet = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> data : dataSet) {
            if (data.get(columnIndex).equals(attribueClass)) {
                splitDataSet.add(data);
            }
        }

        return splitDataSet;
    }

    /**
     * Computes the Shannon entropy H (in bits) of the given column:
     * H = -sum(p_i * log2(p_i)). Returns 0 for an empty data set.
     *
     * @param dataSet     data rows
     * @param columnIndex column whose value distribution is measured
     * @return entropy in bits
     */
    public static double computeEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        return entropyFromCounts(getTypeCounts(dataSet, columnIndex), dataSet.size());
    }

    /**
     * Computes the conditional entropy H(target | attribute) of the target
     * column (the last column) given the attribute in columnIndex:
     * H(Y|X) = sum over values x of P(X=x) * H(Y | X=x).
     *
     * NOTE: method name keeps its historical typo because existing callers use
     * it; prefer {@link #computeConditionalEntropy}.
     *
     * @param dataSet     data rows (last column is the target)
     * @param columnIndex attribute column to condition on
     * @return conditional entropy in bits; 0 for an empty data set
     */
    public static double computeConditinalEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        // Robustness: the original dereferenced splitDataSet.get(0) behind an
        // ineffective size guard; guard the genuinely dangerous case instead.
        if (dataSet.isEmpty()) {
            return 0;
        }
        int desColumn = dataSet.get(0).size() - 1; // target column index (loop-invariant, hoisted)

        double conditionalEntropy = 0;
        for (String value : getTypeCounts(dataSet, columnIndex).keySet()) {
            ArrayList<ArrayList<String>> splitDataSet = getDataSet(dataSet, columnIndex, value);
            double probY = (double) splitDataSet.size() / (double) dataSet.size(); // P(X = value)
            // Posterior entropy of the target within this subset.
            double posteriorEntropy = entropyFromCounts(getTypeCounts(splitDataSet, desColumn), splitDataSet.size());
            conditionalEntropy += probY * posteriorEntropy;
        }
        return conditionalEntropy;
    }

    /**
     * Correctly-spelled alias for {@link #computeConditinalEntropy}.
     *
     * @param dataSet     data rows (last column is the target)
     * @param columnIndex attribute column to condition on
     * @return conditional entropy in bits
     */
    public static double computeConditionalEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        return computeConditinalEntropy(dataSet, columnIndex);
    }

    /**
     * Shared helper: entropy in bits from a value->count map and a total size.
     * Both public entropy methods previously duplicated this loop.
     */
    private static double entropyFromCounts(Map<String, Integer> counts, int total) {
        double entropy = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            double prob = entry.getValue() / (double) total;
            entropy -= prob * Math.log(prob) / LOG2;
        }
        return entropy;
    }
}

测试代码:

/**
 * Entry point: loads a tab-separated training file and prints the ID3 tree.
 * File format: first line holds the attribute names (last column is the class
 * attribute); each remaining line is one training example.
 */
public class Test {
    public static void main(String[] args) {
        File f = new File("D:/test.txt");
        BufferedReader reader = null;

        try {
            reader = new BufferedReader(new FileReader(f));

            // Header line: attribute names. Guard against an empty file —
            // the original called split() on a possibly-null line (NPE).
            String header = reader.readLine();
            if (header == null) {
                System.err.println("empty input file: " + f);
                return;
            }
            ArrayList<String> attributeList = new ArrayList<String>();
            for (String attr : header.split("\t")) {
                attributeList.add(attr);
            }

            // Remaining lines: one training example per line.
            ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>();
            String line;
            while ((line = reader.readLine()) != null) {
                ArrayList<String> row = new ArrayList<String>();
                for (String cell : line.split("\t")) {
                    row.add(cell);
                }
                dataSet.add(row);
            }

            DecisionTree dt = new DecisionTree();
            TreeNode root = dt.buildTree(dataSet, attributeList);
// dt.printTree(root); 
            dt.searchTree(root);

        } catch (IOException e) {
            // FileNotFoundException is a subclass of IOException: one catch
            // replaces the original's redundant nested try/catch.
            e.printStackTrace();
        } finally {
            // BUGFIX: the original never closed the reader (resource leak).
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}
点赞