Java Learning (Day 31)

Study source: 日撸 Java 三百行 (Days 61-70, Decision Trees and Ensemble Learning), Min Fan's blog on CSDN

  • I. Decision Trees
  • II. Construction Steps
  • III. Implementation
    • 1. Approach
    • 2. Code
    • 3. Screenshot of a Run
  • IV. Summary


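I. Decision Trees

A decision tree is a machine learning method. Algorithms for generating decision trees include ID3, C4.5 and C5.0. This post focuses on the ID3 algorithm.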

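A decision tree is a tree structure. Each internal node represents a test on one attribute, each branch represents one outcome of that test, and each leaf node represents a class label.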

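For example, in the figure below, Yes and No stand for the two classes "go out to play" and "do not go out to play". The attribute Outlook (the weather) is tested first, and it takes three values: Sunny, Overcast and Rain. If Outlook is Sunny, the next test is on the Humidity attribute. If Outlook is Overcast, the tree directly outputs the class Yes.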
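The classification path just described can be written as plain nested conditionals. This is a hand-written sketch, not a tree learned by ID3. Only the Outlook root, the Sunny to Humidity test and the Overcast to Yes leaf come from the text. The Humidity outcomes and the whole Rain branch are assumptions borrowed from the classic play-tennis example, and the class and method names are hypothetical.

```java
class WeatherTreeSketch {
    // Hypothetical hard-coded tree: each if/switch is an internal node (attribute test),
    // each case is a branch (test outcome), and each returned string is a leaf (class).
    static String classify(String outlook, String humidity, boolean windy) {
        switch (outlook) {
            case "Sunny":
                // From the text: Sunny leads to a Humidity test.
                // Assumed leaves: High humidity -> No, otherwise Yes.
                return humidity.equals("High") ? "No" : "Yes";
            case "Overcast":
                // From the text: Overcast is directly classified as Yes.
                return "Yes";
            case "Rain":
                // Assumed branch: windy -> No, otherwise Yes.
                return windy ? "No" : "Yes";
            default:
                throw new IllegalArgumentException("Unknown outlook: " + outlook);
        }
    }

    public static void main(String[] args) {
        System.out.println(classify("Overcast", "High", true)); // Yes
        System.out.println(classify("Sunny", "High", false));   // No (assumed leaf)
    }
}
```

ID3's job is to produce such a structure automatically from data rather than by hand, which is what the rest of this post builds.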



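II. Construction Steps

1. Splitting a node: when the attribute a node represents cannot yet give a decision, the node is split into 2 child nodes, or into n child nodes if the tree is not binary. In the example above, the test on Outlook splits the node into three children.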

2. Choosing the parent node: this is the question of why Outlook is chosen as the root of this tree. The ID3 algorithm mentioned earlier answers it.

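ID3: the entropy principle decides which attribute becomes the parent node and which node is split next. The attribute whose split leaves the lowest (conditional) entropy, that is, the largest information gain, is chosen. For a set of data, lower entropy means purer class labels. Entropy is defined as: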

$$Entropy = -\sum_{i=1}^{n} p(x_i)\log p(x_i)$$

where $p(x_i)$ is the probability that the random event takes the value $x_i$.
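As a minimal sketch, the formula can be computed directly from the class counts of a block by estimating $p(x_i)$ as count / total. The class and method names are hypothetical. This version uses base-2 logarithms, so results are in bits. The ID3 code later in this post uses `Math.log` (natural log) instead. That only rescales every entropy by the same constant, so the choice of attribute is unchanged.

```java
class EntropySketch {
    /**
     * Entropy = -sum_i p(x_i) log2 p(x_i), with p(x_i) = counts[i] / total.
     * Zero counts are skipped because 0 * log 0 is taken as 0.
     */
    static double entropy(int[] counts) {
        int total = 0;
        for (int c : counts) {
            total += c;
        }
        double result = 0;
        for (int c : counts) {
            if (c == 0) {
                continue;
            }
            double p = (double) c / total;
            // Convert the natural log to base 2: log2(p) = ln(p) / ln(2).
            result -= p * Math.log(p) / Math.log(2);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(entropy(new int[] {5, 5}));  // 1.0: a 50/50 block is maximally impure
        System.out.println(entropy(new int[] {10, 0})); // 0.0: a pure block
    }
}
```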

For example, consider the dataset weather.arff, whose header is:

@relation weather
@attribute Outlook {Sunny, Overcast, Rain}
@attribute Temperature {Hot, Mild, Cool}
@attribute Humidity {High, Normal, Low}
@attribute Windy {FALSE, TRUE}
@attribute Play {N, P}

There are 5 instances with Outlook = Sunny. Of these, 3 have Play = N and 2 have Play = P. Hence:


$$p(Play = N \mid Outlook = Sunny) = \frac{3}{5} \tag{1}$$

$$p(Play = P \mid Outlook = Sunny) = \frac{2}{5} \tag{2}$$

Equations (1) and (2) are exactly the $p(x_i)$ that appear in the entropy formula above.
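Plugging (1) and (2) into the entropy formula gives a worked value for the Sunny subset. The calculation uses base-2 logarithms, and the figures are rounded:

$$
\begin{aligned}
Entropy(Play \mid Outlook = Sunny) &= -\frac{3}{5}\log_2\frac{3}{5} - \frac{2}{5}\log_2\frac{2}{5} \\
&\approx 0.6 \times 0.737 + 0.4 \times 1.322 \\
&\approx 0.971
\end{aligned}
$$

With the natural logarithm used by `Math.log` in the code, the same quantity is about 0.673.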

III. Implementation

1. Approach

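As with building any tree, we first establish a parent node. In a decision tree, we compute the entropy criterion above for every available attribute and choose the attribute with the lowest conditional entropy as the parent node. Once an attribute has been chosen, it is removed from the set of available attributes and cannot be used again on that path. Each child is then built the same way from its own subset of instances, so the whole procedure is recursive.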
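The selection rule can be isolated from Weka. The sketch below is a hypothetical, self-contained version. Instead of reading an `Instances` object, it takes one count matrix per candidate attribute, where `countMatrix[v][c]` is the number of instances with value v and class c. It computes the same frequency-weighted sum as `conditionalEntropy` in the code below, using the natural log, and returns the index of the attribute with the minimal value. The class name and the toy counts are assumptions for illustration.

```java
class AttributeSelectionSketch {
    /** Entropy (natural log, as in the ID3 class) of one row of class counts. */
    static double entropy(double[] classCounts) {
        double total = 0;
        for (double c : classCounts) {
            total += c;
        }
        double result = 0;
        for (double c : classCounts) {
            if (c == 0) {
                continue; // 0 * log 0 is taken as 0
            }
            double p = c / total;
            result -= p * Math.log(p);
        }
        return result;
    }

    /** H(Class | A) = sum over values v of (|D_v| / |D|) * H(D_v). */
    static double conditionalEntropy(double[][] countMatrix) {
        double total = 0;
        for (double[] row : countMatrix) {
            for (double c : row) {
                total += c;
            }
        }
        double result = 0;
        for (double[] row : countMatrix) {
            double rowTotal = 0;
            for (double c : row) {
                rowTotal += c;
            }
            if (rowTotal == 0) {
                continue; // this value never occurs in the block
            }
            result += rowTotal / total * entropy(row);
        }
        return result;
    }

    /** Index of the attribute with minimal conditional entropy. */
    static int selectBest(double[][][] matrices) {
        int best = -1;
        double bestEntropy = Double.MAX_VALUE;
        for (int a = 0; a < matrices.length; a++) {
            double h = conditionalEntropy(matrices[a]);
            if (h < bestEntropy) {
                bestEntropy = h;
                best = a;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical counts for two attributes over the same 8 instances (2 classes).
        double[][] uselessSplit = {{2, 2}, {2, 2}}; // every value is 50/50 -> ln 2
        double[][] perfectSplit = {{4, 0}, {0, 4}}; // every value is pure -> 0
        System.out.println(selectBest(new double[][][] {uselessSplit, perfectSplit})); // 1
    }
}
```

Minimizing the conditional entropy is the same as maximizing information gain. The entropy of the parent block is identical for every candidate attribute, so subtracting it does not change which attribute wins.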


2. Code
package decisiontree;

import java.io.FileReader;
import java.util.Arrays;

import weka.core.*;

/**
 * The ID3 decision tree inductive algorithm.
 * @author Shi-Huai Wen
 */
public class ID3 {
    /**
     * The data.
     */
    Instances dataset;

    /**
     * Is this dataset pure (only one label)?
     */
    boolean pure;

    /**
     * The number of classes. For binary classification it is 2.
     */
    int numClasses;

    /**
     * Available instances. Other instances do not belong to this branch.
     */
    int[] availableInstances;

    /**
     * Available attributes. Other attributes have been selected in the path
     * from the root.
     */
    int[] availableAttributes;

    /**
     * The selected attribute.
     */
    int splitAttribute;

    /**
     * The children nodes.
     */
    ID3[] children;

    /**
     * My label. Inner nodes also have a label: a combination of attribute values
     * may never appear in the training data yet still occur in other data, and
     * such an instance then receives the label of the node where its path ends.
     */
    int label;

    /**
     * Small block cannot be split further.
     */
    static int smallBlockThreshold = 3;

    /**
     * *******************
     * The constructor.
     * @param paraFilename The given file.
     * *******************
     */
    public ID3(String paraFilename) {
        dataset = null;
        try {
            FileReader fileReader = new FileReader(paraFilename);
            dataset = new Instances(fileReader);
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
            // Without data nothing below can work, so stop here.
            System.exit(1);
        } // Of try

        dataset.setClassIndex(dataset.numAttributes() - 1);
        numClasses = dataset.classAttribute().numValues();

        availableInstances = new int[dataset.numInstances()];
        for (int i = 0; i < availableInstances.length; i++) {
            availableInstances[i] = i;
        } // Of for i
        availableAttributes = new int[dataset.numAttributes() - 1];
        for (int i = 0; i < availableAttributes.length; i++) {
            availableAttributes[i] = i;
        } // Of for i

        // Initialize.
        children = null;
        // Determine the label by simple voting.
        label = getMajorityClass(availableInstances);
        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    }// Of the first constructor

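    /**
     * *******************
     * The constructor.
     * @param paraDataset The given dataset.
     * @param paraAvailableInstances The indices of the instances in this block.
     * @param paraAvailableAttributes The indices of the attributes still available.
     * *******************
     */
    public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {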
        // Copy its reference instead of clone the availableInstances.
        dataset = paraDataset;
        availableInstances = paraAvailableInstances;
        availableAttributes = paraAvailableAttributes;

        // Initialize.
        children = null;
        // Determine the label by simple voting.
        label = getMajorityClass(availableInstances);
        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    }// Of the second constructor

    /**
     * *********************************
     * Is the given block pure?
     * @param paraBlock The block.
     * @return True if pure.
     * *********************************
     */
    public boolean pureJudge(int[] paraBlock) {
        pure = true;

        // Just compare with 0
        for (int i = 1; i < paraBlock.length; i++) {
            if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0]).classValue()) {
                pure = false;
            } // Of if
        } // Of for i

        return pure;
    }// Of pureJudge

    /**
     * *********************************
     * Compute the majority class of the given block for voting.
     * @param paraBlock The block.
     * @return The majority class.
     * *********************************
     */
    public int getMajorityClass(int[] paraBlock) {
        int[] tempClassCounts = new int[dataset.numClasses()];
        for (int i : paraBlock) {
            tempClassCounts[(int) dataset.instance(i).classValue()]++;
        } // Of foreach

        int resultMajorityClass = -1;
        int tempMaxCount = -1;

        for (int i = 0; i < tempClassCounts.length; i++) {
            if (tempMaxCount < tempClassCounts[i]) {
                resultMajorityClass = i;
                tempMaxCount = tempClassCounts[i];
            } // Of if
        } // Of for i

        return resultMajorityClass;
    }// Of getMajorityClass

    /**
     * *********************************
     * Select the best attribute.
     * @return The best attribute index.
     * *********************************
     */
    public int selectBestAttribute() {
        splitAttribute = -1;
        double tempMinimalEntropy = 10000;
        double tempEntropy;
        for (int availableAttribute : availableAttributes) {
            tempEntropy = conditionalEntropy(availableAttribute);
            if (tempMinimalEntropy > tempEntropy) {
                tempMinimalEntropy = tempEntropy;
                splitAttribute = availableAttribute;
            } // Of if
        } // Of foreach
        return splitAttribute;
    }// Of selectBestAttribute

    /**
     * *********************************
     * Compute the conditional entropy of an attribute.
     * @param paraAttribute The given attribute.
     * @return The entropy.
     * *********************************
     */
    public double conditionalEntropy(int paraAttribute) {
        // Step 1. Statistics.
        int tempNumClasses = dataset.numClasses();
        int tempNumValues = dataset.attribute(paraAttribute).numValues();
        int tempNumInstances = availableInstances.length;
        double[] tempValueCounts = new double[tempNumValues];
        double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];

        int tempClass, tempValue;
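        for (int availableInstance : availableInstances) {
            tempClass = (int) dataset.instance(availableInstance).classValue();
            tempValue = (int) dataset.instance(availableInstance).value(paraAttribute);
            // Count instances per attribute value and per (value, class) pair.
            tempValueCounts[tempValue]++;
            tempCountMatrix[tempValue][tempClass]++;
        } // Of foreach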

        // Step 2.
        double resultEntropy = 0;
        double tempEntropy, tempFraction;
        for (int i = 0; i < tempNumValues; i++) {
            if (tempValueCounts[i] == 0) {
                // No instance takes this value, so it contributes nothing.
                continue;
            } // Of if
            tempEntropy = 0;
            for (int j = 0; j < tempNumClasses; j++) {
                tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
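                if (tempFraction == 0) {
                    // By convention 0 * log 0 = 0, so skip empty classes.
                    continue;
                } // Of if

                // The smaller the entropy, the purer the block and the less information it carries.
                // H(X) = -sum p(x) * log p(x)
                tempEntropy += -tempFraction * Math.log(tempFraction);
            } // Of for j
            // Weight each value's entropy by its frequency; ID3 minimizes this conditional entropy.
            resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;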
        } // Of for i

        return resultEntropy;
    }// Of conditionalEntropy

    /**
     * *********************************
     * Split the data according to the given attribute.
     * @param paraAttribute The given attribute.
     * @return The blocks.
     * *********************************
     */
    public int[][] splitData(int paraAttribute) {
        int tempNumValues = dataset.attribute(paraAttribute).numValues();

        int[][] resultBlocks = new int[tempNumValues][];
        int[] tempSizes = new int[tempNumValues];

        // First scan to count the size of each block.
        int tempValue;
        for (int availableInstance : availableInstances) {
            tempValue = (int) dataset.instance(availableInstance).value(paraAttribute);
            tempSizes[tempValue]++;
        } // Of foreach

        // Allocate space.
        for (int i = 0; i < tempNumValues; i++) {
            resultBlocks[i] = new int[tempSizes[i]];
        } // Of for i

        // Second scan to fill.
        Arrays.fill(tempSizes, 0);
        for (int availableInstance : availableInstances) {
            tempValue = (int) dataset.instance(availableInstance).value(paraAttribute);
            // Copy data, then advance the write position of this block.
            resultBlocks[tempValue][tempSizes[tempValue]] = availableInstance;
            tempSizes[tempValue]++;
        } // Of foreach

        return resultBlocks;
    }// Of splitData

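    /**
     * *********************************
     * Build the tree recursively.
     * *********************************
     */
    public void buildTree() {
        // A pure block needs no further splitting.
        if (pureJudge(availableInstances)) {
            return;
        } // Of if

        // A block no larger than the threshold is not split further.
        if (availableInstances.length <= smallBlockThreshold) {
            return;
        } // Of if

        // No attribute is left to split on, so keep this node as a leaf.
        if (availableAttributes.length == 0) {
            return;
        } // Of if

        // Choose the attribute with the minimal conditional entropy, then split on it.
        selectBestAttribute();
        int[][] tempSubBlocks = splitData(splitAttribute);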
        children = new ID3[tempSubBlocks.length];

        // Construct the remaining attribute set.
        int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
        for (int i = 0; i < availableAttributes.length; i++) {
            if (availableAttributes[i] < splitAttribute) {
                tempRemainingAttributes[i] = availableAttributes[i];
            } else if (availableAttributes[i] > splitAttribute) {
                tempRemainingAttributes[i - 1] = availableAttributes[i];
            } // Of if
        } // Of for i

        // Construct children.
        for (int i = 0; i < children.length; i++) {
            if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
                children[i] = null;
            } else {
                // System.out.println("Building children #" + i + " with
                // instances " + Arrays.toString(tempSubBlocks[i]));
                children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);

                // Important code: do this recursively
                children[i].buildTree();
            } // Of if
        } // Of for i
    }// Of buildTree

    /**
     * *********************************
     * Classify an instance.
     * @param paraInstance The given instance.
     * @return The prediction.
     * *********************************
     */
    public int classify(Instance paraInstance) {
        if (children == null) {
            return label;
        } // Of if

        ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
        if (tempChild == null) {
            return label;
        } // Of if

        return tempChild.classify(paraInstance);
    }// Of classify

    /**
     * *********************************
     * Test on a testing set.
     * @param paraDataset The given testing data.
     * @return The accuracy.
     * *********************************
     */
    public double test(Instances paraDataset) {
        double tempCorrect = 0;
        for (int i = 0; i < paraDataset.numInstances(); i++) {
            if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        return tempCorrect / paraDataset.numInstances();
    }// Of test

    /**
     * *********************************
     * Test on the training set.
     * @return The accuracy.
     * *********************************
     */
    public double selfTest() {
        return test(dataset);
    }// Of selfTest

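    /**
     * ******************
     * Overrides the method claimed in Object.
     * @return The tree structure.
     * ******************
     */
    @Override
    public String toString() {
        StringBuilder resultString = new StringBuilder();
        if (children == null) {
            resultString.append("class = ").append(label);
        } else {
            String tempAttributeName = dataset.attribute(splitAttribute).name();
            for (int i = 0; i < children.length; i++) {
                // One line per attribute value: "attribute = value:" followed by the subtree.
                resultString.append(tempAttributeName)
                        .append(" = ")
                        .append(dataset.attribute(splitAttribute).value(i))
                        .append(":");
                if (children[i] == null) {
                    // No training instance reached this branch; fall back on this node's label.
                    resultString.append("class = ").append(label);
                } else {
                    resultString.append(children[i]);
                } // Of if
                resultString.append("\r\n");
            } // Of for i
        } // Of if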

        return resultString.toString();
    }// Of toString

    /**
     * ************************
     * Test this class.
     * ************************
     */
    public static void id3Test() {
        ID3 tempID3 = new ID3("D:/Work/sampledata/weather.arff");

        ID3.smallBlockThreshold = 3;
        // Build the tree before printing or testing it.
        tempID3.buildTree();

        System.out.println("The tree is: \r\n" + tempID3);

        double tempAccuracy = tempID3.selfTest();
        System.out.println("The accuracy is: " + tempAccuracy);
    }// Of id3Test

    /**
     * ************************
     * Test this class.
     * @param args Not used now.
     * ************************
     */
    public static void main(String[] args) {
        id3Test();
    }// Of main
} // Of class ID3
3. Screenshot of a Run


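IV. Summary

I now understand the overall process of building a decision tree. Compared with a hand-written if-else chain, the essence of a decision tree lies in the ID3 algorithm. ID3 borrows information entropy from communication theory to decide where to split. Good research really does draw on ideas from many fields to find better solutions.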



