# ID3 Decision Tree (Java Implementation)

Table 1

| outlook | temperature | humidity | windy | play |
|---|---|---|---|---|
| sunny | hot | high | FALSE | no |
| sunny | hot | high | TRUE | no |
| overcast | hot | high | FALSE | yes |
| rainy | mild | high | FALSE | yes |
| rainy | cool | normal | FALSE | yes |
| rainy | cool | normal | TRUE | no |
| overcast | cool | normal | TRUE | yes |
| sunny | mild | high | FALSE | no |
| sunny | cool | normal | FALSE | yes |
| rainy | mild | normal | FALSE | yes |
| sunny | mild | normal | TRUE | yes |
| overcast | mild | high | TRUE | yes |
| overcast | hot | normal | FALSE | yes |
| rainy | mild | high | TRUE | no |

### ID3 Algorithm

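Table 2

| attribute | value | play = yes | play = no |
|---|---|---|---|
| outlook | sunny | 2 | 3 |
| outlook | overcast | 4 | 0 |
| outlook | rainy | 3 | 2 |
| temperature | hot | 2 | 2 |
| temperature | mild | 4 | 2 |
| temperature | cool | 3 | 1 |
| humidity | high | 3 | 4 |
| humidity | normal | 6 | 1 |
| windy | FALSE | 6 | 2 |
| windy | TRUE | 3 | 3 |
| play | (all days) | 9 | 5 |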
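Table 2 is a per-attribute tally of Table 1. The sketch below recomputes it. It assumes the rows are held in memory as a `String[][]` with `play` in the last column. The class `CountTable` and its `count` method are hypothetical and not part of the article's program.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class CountTable {
    // Table 1, one row per day: outlook, temperature, humidity, windy, play
    static final String[][] DATA = {
        {"sunny", "hot", "high", "FALSE", "no"},
        {"sunny", "hot", "high", "TRUE", "no"},
        {"overcast", "hot", "high", "FALSE", "yes"},
        {"rainy", "mild", "high", "FALSE", "yes"},
        {"rainy", "cool", "normal", "FALSE", "yes"},
        {"rainy", "cool", "normal", "TRUE", "no"},
        {"overcast", "cool", "normal", "TRUE", "yes"},
        {"sunny", "mild", "high", "FALSE", "no"},
        {"sunny", "cool", "normal", "FALSE", "yes"},
        {"rainy", "mild", "normal", "FALSE", "yes"},
        {"sunny", "mild", "normal", "TRUE", "yes"},
        {"overcast", "mild", "high", "TRUE", "yes"},
        {"overcast", "hot", "normal", "FALSE", "yes"},
        {"rainy", "mild", "high", "TRUE", "no"}
    };
    static final String[] ATTRS = {"outlook", "temperature", "humidity", "windy"};
    static final int PLAY = 4; // column of the decision attribute

    // For attribute column col, map each value to {yes count, no count},
    // keeping values in the order they first appear in the data.
    static Map<String, int[]> count(int col) {
        Map<String, int[]> counts = new LinkedHashMap<>();
        for (String[] row : DATA) {
            int[] c = counts.computeIfAbsent(row[col], k -> new int[2]);
            if (row[PLAY].equals("yes")) c[0]++; else c[1]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        for (int col = 0; col < ATTRS.length; col++) {
            for (Map.Entry<String, int[]> e : count(col).entrySet()) {
                System.out.printf("%s=%s: yes %d, no %d%n",
                        ATTRS[col], e.getKey(), e.getValue()[0], e.getValue()[1]);
            }
        }
    }
}
```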

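When outlook=sunny, play happens on 2 of the 5 days and not on the other 3. The probability of playing is 2/5 and of not playing is 3/5, so entropy = -(2/5)·log2(2/5) - (3/5)·log2(3/5) = 0.971.

When outlook=overcast, play happens on all 4 days, so entropy = 0.

When outlook=rainy (3 yes, 2 no), entropy = 0.971.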
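The three values above follow from H = -Σ p_i·log2(p_i) over the yes/no proportions. The sketch below is a minimal version that takes raw counts as plain `int`s. The `Entropy` class is illustrative and not taken from the original program.

```java
class Entropy {
    // Shannon entropy (in bits) of a class distribution given as raw counts,
    // e.g. entropy(2, 3) for 2 "yes" and 3 "no". Zero counts are skipped,
    // which is the usual convention 0 * log2(0) = 0.
    static double entropy(int... counts) {
        int total = 0;
        for (int c : counts) total += c;
        double h = 0.0;
        for (int c : counts) {
            if (c == 0) continue;
            double p = (double) c / total;
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

    public static void main(String[] args) {
        System.out.printf("sunny    %.3f%n", entropy(2, 3));
        System.out.printf("overcast %.3f%n", entropy(4, 0));
        System.out.printf("rainy    %.3f%n", entropy(3, 2));
        System.out.printf("all days %.3f%n", entropy(9, 5));
    }
}
```

The last line gives the entropy of the whole data set before any split, 9 yes and 5 no, which is about 0.940.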

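Before any split, the whole set has 9 yes and 5 no, so its entropy is 0.940. Splitting on outlook gives an expected entropy of 5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693, so gain(outlook) = 0.940 - 0.693 = 0.247. The same computation gives gain(temperature) = 0.029, gain(humidity) = 0.152 and gain(windy) = 0.048.

gain(outlook) is the largest, meaning that splitting on outlook first makes the system's information entropy drop fastest. Outlook therefore becomes the root node of the decision tree.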
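The sketch below checks that outlook has the largest gain. It computes gain = H(parent) minus the size-weighted branch entropies for all four attributes, directly from the Table 2 counts. The `InfoGain` class and the hard-coded count arrays are assumptions for illustration.

The article's `calNodeEntropy`/`buildDT` code works differently: it minimizes the weighted branch entropy instead of maximizing the gain. Both choose the same attribute, because H(parent) is identical for every candidate.

```java
class InfoGain {
    // Entropy in bits of raw class counts; zero counts contribute nothing.
    static double entropy(int... counts) {
        int total = 0;
        for (int c : counts) total += c;
        double h = 0.0;
        for (int c : counts) {
            if (c == 0) continue;
            double p = (double) c / total;
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

    // Information gain of a split whose branches are given as {yes, no}:
    // H(parent) minus sum over branches of (branch size / total) * H(branch).
    static double gain(int[][] branches) {
        int yes = 0, no = 0;
        for (int[] b : branches) { yes += b[0]; no += b[1]; }
        int total = yes + no;
        double remainder = 0.0;
        for (int[] b : branches) {
            remainder += (double) (b[0] + b[1]) / total * entropy(b);
        }
        return entropy(yes, no) - remainder;
    }

    public static void main(String[] args) {
        String[] names = {"outlook", "temperature", "humidity", "windy"};
        int[][][] splits = {
            {{2, 3}, {4, 0}, {3, 2}},   // sunny, overcast, rainy
            {{2, 2}, {4, 2}, {3, 1}},   // hot, mild, cool
            {{3, 4}, {6, 1}},           // high, normal
            {{6, 2}, {3, 3}}            // FALSE, TRUE
        };
        int best = -1;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < names.length; i++) {
            double g = gain(splits[i]);
            System.out.printf("gain(%s) = %.3f%n", names[i], g);
            if (g > bestGain) { bestGain = g; best = i; }
        }
        System.out.println("root: " + names[best]);
    }
}
```

Running `main` should print gains of about 0.247, 0.029, 0.152 and 0.048 and then select outlook as the root.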

### Java Implementation

The training data (Table 1) in ARFF format, as read from `weather.nominal.arff`:

```
@relation weather.symbolic

@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}

@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
```
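The program recognizes the header lines with the regular expression `@attribute(.*)[{](.*?)[}]`, its `patternString`. Group 1 is the attribute name and group 2 is the comma-separated list of values. The sketch below isolates just that step; the class `ArffHeader` is hypothetical.

Two limitations follow from the pattern. It is case-sensitive, so a file written with `@ATTRIBUTE` (which the ARFF format permits) would not match. Numeric attributes, which have no `{...}` list, are not handled at all.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class ArffHeader {
    // Same regular expression as patternString in the ID3 class.
    static final Pattern ATTR = Pattern.compile("@attribute(.*)[{](.*?)[}]");

    // Returns [name, value1, value2, ...] for a nominal @attribute line,
    // or null if the line is not such a declaration.
    static List<String> parse(String line) {
        Matcher m = ATTR.matcher(line);
        if (!m.find()) return null;
        List<String> out = new ArrayList<>();
        out.add(m.group(1).trim());
        for (String v : m.group(2).split(",")) out.add(v.trim());
        return out;
    }

    public static void main(String[] args) {
        System.out.println(parse("@attribute outlook {sunny, overcast, rainy}"));
        System.out.println(parse("@relation weather.symbolic"));
    }
}
```

The first call prints `[outlook, sunny, overcast, rainy]`. The `@relation` line has no braces and yields `null`.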

```java
package dt;

// Requires dom4j (org.dom4j) on the classpath.
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;

public class ID3 {
    private ArrayList<String> attribute = new ArrayList<String>(); // attribute names
    private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // possible values of each attribute
    private ArrayList<String[]> data = new ArrayList<String[]>(); // raw data rows
    int decatt; // index of the decision attribute in the attribute list
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";

    Document xmldoc;
    Element root;

    public ID3() {
        xmldoc = DocumentHelper.createDocument();
        root = xmldoc.addElement("root");
        root.addElement("DecisionTree").addAttribute("value", "null");
    }

    public static void main(String[] args) {
        ID3 inst = new ID3();
        inst.readARFF(new File("/home/orisun/test/weather.nominal.arff"));
        inst.setDec("play");
        // Candidate attributes: every column except the decision attribute
        LinkedList<Integer> ll = new LinkedList<Integer>();
        for (int i = 0; i < inst.attribute.size(); i++) {
            if (i != inst.decatt)
                ll.add(i);
        }
        // Start from all rows of the data set
        ArrayList<Integer> al = new ArrayList<Integer>();
        for (int i = 0; i < inst.data.size(); i++) {
            al.add(i);
        }
        inst.buildDT(inst.root.element("DecisionTree"), al, ll);
        inst.writeXML("dt.xml"); // output path chosen for illustration
    }

    // Read an ARFF file and fill attribute, attributevalue and data
    public void readARFF(File file) {
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                    attribute.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    attributevalue.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        // Skip blank lines (line == "" compares references, not contents)
                        if (line.trim().isEmpty())
                            continue;
                        String[] row = line.split(",");
                        for (int i = 0; i < row.length; i++)
                            row[i] = row[i].trim();
                        data.add(row);
                    }
                }
            }
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }

    // Set the decision attribute by index
    public void setDec(int n) {
        if (n < 0 || n >= attribute.size()) {
            System.err.println("Invalid decision attribute.");
            System.exit(2);
        }
        decatt = n;
    }

    // Set the decision attribute by name
    public void setDec(String name) {
        int n = attribute.indexOf(name);
        setDec(n);
    }

    // Entropy (bits) of a count vector, e.g. {2, 3} for 2 "yes" and 3 "no"
    public double getEntropy(int[] arr) {
        int sum = 0;
        for (int a : arr)
            sum += a;
        return getEntropy(arr, sum);
    }

    // Entropy of a count vector whose total is sum,
    // computed as (sum*log2(sum) - sum_i arr[i]*log2(arr[i])) / sum
    public double getEntropy(int[] arr, int sum) {
        if (sum == 0)
            return 0.0; // empty branch: avoid 0/0 = NaN
        double entropy = 0.0;
        for (int i = 0; i < arr.length; i++) {
            if (arr[i] > 0) // 0 * log(0) is taken as 0
                entropy -= arr[i] * Math.log(arr[i]) / Math.log(2);
        }
        entropy += sum * Math.log(sum) / Math.log(2);
        entropy /= sum;
        return entropy;
    }

    // True if every row in subset has the same decision value
    public boolean infoPure(ArrayList<Integer> subset) {
        String value = data.get(subset.get(0))[decatt];
        for (int i = 1; i < subset.size(); i++) {
            String next = data.get(subset.get(i))[decatt];
            // equals compares contents; == would test whether both refer to the same object
            if (!value.equals(next))
                return false;
        }
        return true;
    }

    // For the rows in subset (row indices), the weighted entropy of the
    // decision attribute after splitting on attribute number index
    public double calNodeEntropy(ArrayList<Integer> subset, int index) {
        int sum = subset.size();
        double entropy = 0.0;
        int[][] info = new int[attributevalue.get(index).size()][];
        for (int i = 0; i < info.length; i++)
            info[i] = new int[attributevalue.get(decatt).size()];
        int[] count = new int[attributevalue.get(index).size()];
        for (int i = 0; i < sum; i++) {
            int n = subset.get(i);
            String nodevalue = data.get(n)[index];
            int nodeind = attributevalue.get(index).indexOf(nodevalue);
            count[nodeind]++;
            String decvalue = data.get(n)[decatt];
            int decind = attributevalue.get(decatt).indexOf(decvalue);
            info[nodeind][decind]++;
        }
        for (int i = 0; i < info.length; i++) {
            entropy += getEntropy(info[i], count[i]) * count[i] / sum;
        }
        return entropy;
    }

    // Most frequent decision value among the given rows
    private String majority(ArrayList<Integer> subset) {
        ArrayList<String> decvalues = attributevalue.get(decatt);
        int[] count = new int[decvalues.size()];
        for (int row : subset)
            count[decvalues.indexOf(data.get(row)[decatt])]++;
        int best = 0;
        for (int i = 1; i < count.length; i++)
            if (count[i] > count[best])
                best = i;
        return decvalues.get(best);
    }

    // Build the decision tree below ele. subset holds the row indices that
    // reach this node; selatt holds the attributes still available.
    // The parent element is passed in directly: looking it up with
    // root.selectNodes("//" + name) plus a value match attaches children to
    // the wrong node when the same attribute/value pair occurs in two branches.
    public void buildDT(Element ele, ArrayList<Integer> subset, LinkedList<Integer> selatt) {
        if (infoPure(subset) || selatt.isEmpty()) {
            ele.setText(majority(subset));
            return;
        }
        int minIndex = -1;
        double minEntropy = Double.MAX_VALUE;
        for (int att : selatt) {
            if (att == decatt)
                continue;
            double entropy = calNodeEntropy(subset, att);
            if (entropy < minEntropy) {
                minIndex = att;
                minEntropy = entropy;
            }
        }
        String nodeName = attribute.get(minIndex);
        // Copy, so that removals in one branch do not affect its siblings
        LinkedList<Integer> rest = new LinkedList<Integer>(selatt);
        rest.remove(Integer.valueOf(minIndex));
        for (String val : attributevalue.get(minIndex)) {
            Element child = ele.addElement(nodeName).addAttribute("value", val);
            ArrayList<Integer> al = new ArrayList<Integer>();
            for (int row : subset) {
                if (data.get(row)[minIndex].equals(val))
                    al.add(row);
            }
            if (al.isEmpty())
                child.setText(majority(subset)); // no rows: use the parent's majority class
            else
                buildDT(child, al, rest);
        }
    }

    // Write the XML document to a file
    public void writeXML(String filename) {
        try {
            OutputFormat format = OutputFormat.createPrettyPrint(); // indented output
            XMLWriter output = new XMLWriter(new FileOutputStream(filename), format);
            output.write(xmldoc);
            output.close();
        } catch (IOException e) {
            System.out.println(e.getMessage());
        }
    }
}
```
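`getEntropy` never computes probabilities explicitly. For class counts $n_1, \dots, n_k$ with $N = \sum_i n_i$, it uses the algebraically equivalent form below. `calNodeEntropy` then returns the weighted sum over the values $v$ of the split attribute $A$. Because $H(S)$ is the same for every candidate attribute, choosing the smallest weighted entropy is the same as choosing the largest information gain.

$$
H = -\sum_{i=1}^{k}\frac{n_i}{N}\log_2\frac{n_i}{N}
  = -\frac{1}{N}\sum_{i=1}^{k} n_i\log_2 n_i + \frac{\log_2 N}{N}\sum_{i=1}^{k} n_i
  = \frac{1}{N}\Big(N\log_2 N - \sum_{i=1}^{k} n_i\log_2 n_i\Big)
$$

$$
\mathrm{gain}(S, A) = H(S) - \sum_{v \in \mathrm{values}(A)} \frac{|S_v|}{|S|}\,H(S_v)
$$

For outlook=sunny ($n = 2, 3$; $N = 5$): $(5\log_2 5 - 2\log_2 2 - 3\log_2 3)/5 = (11.610 - 2 - 4.755)/5 \approx 0.971$, matching the value above.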

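The generated decision tree, as written by `writeXML`:

```xml
<?xml version="1.0" encoding="UTF-8"?>

<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>
```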
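The XML only stores the tree. To classify a new day, walk down from `DecisionTree`. At each level, take the child element whose tag is the split attribute and whose `value` attribute equals the sample's value. Stop at an element with no element children; its text is the predicted class.

The sketch below does this with the JDK's built-in DOM parser (`javax.xml.parsers`, `org.w3c.dom`) rather than dom4j, so it needs no extra jar. The class `TreeClassifier`, its methods and the inlined tree string are hypothetical, and `Map.of` assumes Java 9 or later.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

class TreeClassifier {
    // The tree for the weather data, inlined so the example is self-contained.
    static final String SAMPLE_XML = "<root><DecisionTree value=\"null\">"
            + "<outlook value=\"sunny\">"
            + "<humidity value=\"high\">no</humidity>"
            + "<humidity value=\"normal\">yes</humidity>"
            + "</outlook>"
            + "<outlook value=\"overcast\">yes</outlook>"
            + "<outlook value=\"rainy\">"
            + "<windy value=\"TRUE\">no</windy>"
            + "<windy value=\"FALSE\">yes</windy>"
            + "</outlook>"
            + "</DecisionTree></root>";

    // Parse the XML text and return its <DecisionTree> element.
    static Element load(String xml) {
        try {
            Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                    .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
            return (Element) doc.getElementsByTagName("DecisionTree").item(0);
        } catch (Exception e) {
            throw new IllegalArgumentException("cannot parse tree XML", e);
        }
    }

    // True if e has at least one child element, i.e. it is an inner node.
    static boolean hasElementChild(Element e) {
        NodeList kids = e.getChildNodes();
        for (int i = 0; i < kids.getLength(); i++)
            if (kids.item(i).getNodeType() == Node.ELEMENT_NODE) return true;
        return false;
    }

    // Follow the child whose tag is the split attribute and whose "value"
    // matches the sample; a node without element children is a leaf whose
    // text is the predicted class. Returns null if no branch matches.
    static String classify(Element node, Map<String, String> sample) {
        NodeList kids = node.getChildNodes();
        for (int i = 0; i < kids.getLength(); i++) {
            if (kids.item(i).getNodeType() != Node.ELEMENT_NODE) continue;
            Element child = (Element) kids.item(i);
            if (child.getAttribute("value").equals(sample.get(child.getTagName()))) {
                return hasElementChild(child) ? classify(child, sample)
                                              : child.getTextContent().trim();
            }
        }
        return null;
    }

    public static void main(String[] args) {
        Element tree = load(SAMPLE_XML);
        Map<String, String> day = Map.of("outlook", "rainy", "temperature", "cool",
                "humidity", "high", "windy", "FALSE");
        System.out.println(classify(tree, day));
    }
}
```

For the rainy, not-windy day in `main`, the walk goes outlook=rainy, then windy=FALSE, and returns `yes`. Temperature never appears in the tree, so its value is simply ignored.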

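Original author: B树
Original article: https://blog.csdn.net/B_H_L/article/details/9214681
This article is reposted from the web and shared only to spread knowledge. If it infringes any rights, please contact the blogger to have it removed.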