树回归 CART算法

线性回归创建的预测模型需要拟合所有的样本点,在数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型太难,而且,生活中很多问题是非线性的,不可能使用全局线性模型来拟合任何数据。

一种可行的方法是把数据集切分成很多分易建模的数据,然后利用线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。这种切分方式下,树结构和回归法就相当有用。

CART算法:分类回归树,既可用于分类也可用于回归。

第三章使用的决策树构建算法是ID3,每次选取当前最佳的特征来分割数据。属于贪心算法,不考虑能否达到全局最优。而且容易造成过拟合、不能直接处理连续型特征,只有事先将连续型特征转换成离散型,才能使用ID3算法。

而使用二元切分法则易于对树构建过程进行调整以处理连续型特征。如果特征值大于给定值就走左子树,小于给定值就走右子树。

CART算法的实现代码:

from numpy import *
def loadDataSet(filename):
    dataMat=[]
    f=open(filename)
    for line in f.readlines():
        curLine=line.strip().split('	')
        floatLine=list(map(float,curLine))
        dataMat.append(floatLine)
    return dataMat
def binSplitDataSet(dataSet,feature,value):
    mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:]
    mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0],:]
    return mat0,mat1
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

chooseBestSplit()函数暂未实现。


 

 将CART算法用于回归:
回归树假设叶子节点是常数值。用平方误差的总值(总方差)来计算连续型数值的混乱程度。总方差等于均方差乘以数据集中样本点的个数。

chooseBestSplit():给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。还要确定什么时候停止切分,一旦停止切分就会生成一个叶子节点。所以:用最佳方式切分数据集和生成相应的叶节点。

伪代码:

对每个特征:
    对每个特征值:
        将数据集切分成两份
        计算切分后的误差
        如果当前误差小于当前最小误差,将当前切分设定为最佳切分并更新最小误差
    返回最佳切分的特征和阈值

 切分函数的实现:

def regLeaf(dataSet):   #负责生成叶节点,当chooseBestSplit函数确定不再对数据进行切分时,将调用regLeaf函数得到叶节点的模型
    return mean(dataSet[:,-1])  #在回归树中,此模型就是目标变量的均值

def regErr(dataSet):    # 误差估计函数,计算目标变量的平方误差,需要返回总误差,即为均方误差乘以数据集中样本个数
    return var(dataSet[:, -1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): #ops为用户指定的参数,用于控制函数的停止时机
    tolS = ops[0]  # 容许的误差下降值
    tolN = ops[1]  # 切分的最少样本数
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # 统计不同剩余特征值得数目,如果数目为一,就不需要再切分而直接返回
        return None, leafType(dataSet)
    else:
        m, n = shape(dataSet)
        S = errType(dataSet)    #误差
        bestS = inf     #最小误差
        bestIndex = 0
        bestValue = 0
        for featIndex in range(n - 1):  # 对所有特征进行遍历,找到最佳切分方式。最佳切分就是使得切分后能达到最低误差的切分
            # for splitVal in set(dataSet[:, featIndex]):  # 遍历某个特征的所有特征值
            for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
                mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)  # 按照某个特征的某个值将数据切分成两个数据子集
                if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 如果某个子集行数不大于tolN,也不应该切分
                    continue
                newS = errType(mat0) + errType(mat1)  # 新误差由切分后的两个数据子集组成的误差
                if newS < bestS:  # 判断新切分能否降低误差
                    bestIndex = featIndex
                    bestValue = splitVal
                    bestS = newS
        if (S - bestS) < tolS:  # 如果误差降低不大则退出
            return None, leafType(dataSet)
        mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
        if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 如果切分出的数据集很小则退出
            return None, leafType(dataSet)
        return bestIndex, bestValue

 regLeaf():负责生成叶节点,即求当前数据集目标值的平均值作为回归预测值。当chooseBestSplit()确定不再对数据进行切分时,将调用regLeaf()函数来得到叶节点的模型。回归树中,该模型是目标变量的均值。

regErr():误差估计函数。在给定数据集上计算目标变量的平方误差。

chooseBestSplit():构建回归树的核心函数。目的是找到数据的最佳二元切分方式。如果找不到好的二元切分,就返回None并同时调用regLeaf()方法来产生叶节点。

运行代码:

if __name__=='__main__':
    myMat=loadDataSet('ex00.txt')
    myMat=mat(myMat)
    result=createTree(myMat)
    print(result)

 输出为:

{'spInd': 0, 'spVal': 0.48813, 'right': -0.04465028571428572, 'left': 1.0180967672413792}

只有两个叶节点,对照下面的散点图可以看出,在数据0.48813左侧的数据,回归预测值为-0.04465,右侧预测值为1.018。

数据集散点图:

树回归 CART算法

因为数据集简单,所以得到的回归树也简单。

更换数据集测试:

if __name__=='__main__':
    myMat2=loadDataSet('ex2.txt')
    myMat2=mat(myMat2)
    myTree = createTree(myMat2, ops=(0, 1))
    print(myTree)

 输出:

{'spInd': 0, 'spVal': 0.499171, 'right': {'spInd': 0, 'spVal': 0.457563, 'right': {'spInd': 0, 'spVal': 0.455761, 'right': {'spInd': 0, 'spVal': 0.126833, 'right': {'spInd': 0, 'spVal': 0.124723, 'right': {'spInd': 0, 'spVal': 0.085111, 'right': {'spInd': 0, 'spVal': 0.084661, 'right': {'spInd': 0, 'spVal': 0.080061, 'right': {'spInd': 0, 'spVal': 0.068373, 'right': {'spInd': 0, 'spVal': 0.061219, 'right': {'spInd': 0, 'spVal': 0.044737, 'right': {'spInd': 0, 'spVal': 0.028546, 'right': {'spInd': 0, 'spVal': 0.000256, 'right': 9.668106, 'left': -8.377094}, 'left': {'spInd': 0, 'spVal': 0.039914, 'right': 11.220099, 'left': 3.855393}}, 'left': {'spInd': 0, 'spVal': 0.053764, 'right': -13.731698, 'left': {'spInd': 0, 'spVal': 0.055862, 'right': -3.131497, 'left': 6.695567}}}, 'left': -15.160836}, 'left': {'spInd': 0, 'spVal': 0.079632, 'right': 29.420068, 'left': 2.229873}}, 'left': -24.132226}, 'left': 37.820659}, 'left': {'spInd': 0, 'spVal': 0.108801, 'right': {'spInd': 0, 'spVal': 0.10796, 'right': {'spInd': 0, 'spVal': 0.085873, 'right': -10.137104, 'left': -1.293195}, 'left': -16.106164}, 'left': {'spInd': 0, 'spVal': 0.11515, 'right': 13.795828, 'left': -1.402796}}}, 'left': 22.891675}, 'left': {'spInd': 0, 'spVal': 0.130626, 'right': -39.524461, 'left': {'spInd': 0, 'spVal': 0.382037, 'right': {'spInd': 0, 'spVal': 0.335182, 'right': {'spInd': 0, 'spVal': 0.324274, 'right': {'spInd': 0, 'spVal': 0.309133, 'right': {'spInd': 0, 'spVal': 0.131833, 'right': 22.478291, 'left': {'spInd': 0, 'spVal': 0.138619, 'right': -29.087463, 'left': {'spInd': 0, 'spVal': 0.156067, 'right': {'spInd': 0, 'spVal': 0.13988, 'right': 7.336784, 'left': 7.557349}, 'left': {'spInd': 0, 'spVal': 0.166765, 'right': {'spInd': 0, 'spVal': 0.156273, 'right': 0.225886, 'left': {'spInd': 0, 'spVal': 0.164134, 'right': -27.405211, 'left': {'spInd': 0, 'spVal': 0.166431, 'right': -6.512506, 'left': -14.740059}}}, 'left': {'spInd': 0, 'spVal': 0.193282, 'right': {'spInd': 0, 'spVal': 0.176523, 'right': 0.946348, 'left': 18.208423}, 'left': {'spInd': 0, 'spVal': 0.211633, 'right': {'spInd': 0, 'spVal': 0.202161, 'right': {'spInd': 0, 'spVal': 0.199903, 'right': -3.372472, 'left': -1.983889}, 'left': {'spInd': 0, 'spVal': 0.203993, 'right': -22.379119, 'left': {'spInd': 0, 'spVal': 0.206207, 'right': -12.619036, 'left': -8.332207}}}, 'left': {'spInd': 0, 'spVal': 0.228473, 'right': {'spInd': 0, 'spVal': 0.222271, 'right': {'spInd': 0, 'spVal': 0.218321, 'right': {'spInd': 0, 'spVal': 0.217214, 'right': -3.958752, 'left': 1.410768}, 'left': -9.255852}, 'left': {'spInd': 0, 'spVal': 0.2232, 'right': 15.501642, 'left': 19.425158}}, 'left': {'spInd': 0, 'spVal': 0.25807, 'right': {'spInd': 0, 'spVal': 0.228628, 'right': -2.266273, 'left': {'spInd': 0, 'spVal': 0.228751, 'right': -30.812912, 'left': {'spInd': 0, 'spVal': 0.232802, 'right': 1.222318, 'left': -20.425137}}}, 'left': {'spInd': 0, 'spVal': 0.284794, 'right': {'spInd': 0, 'spVal': 0.273863, 'right': {'spInd': 0, 'spVal': 0.264926, 'right': {'spInd': 0, 'spVal': 0.264639, 'right': 2.557923, 'left': 5.280579}, 'left': -9.457556}, 'left': 35.623746}, 'left': {'spInd': 0, 'spVal': 0.300318, 'right': {'spInd': 0, 'spVal': 0.297107, 'right': {'spInd': 0, 'spVal': 0.295993, 'right': {'spInd': 0, 'spVal': 0.290749, 'right': -14.391613, 'left': -14.988279}, 'left': -1.798377}, 'left': -18.051318}, 'left': 8.814725}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.310956, 'right': -49.939516, 'left': {'spInd': 0, 'spVal': 0.318309, 'right': -27.605424, 'left': -13.189243}}}, 'left': {'spInd': 0, 'spVal': 0.32889, 'right': 39.783113, 'left': {'spInd': 0, 'spVal': 0.331364, 'right': -1.290825, 'left': {'spInd': 0, 'spVal': 0.3349, 'right': 18.97665, 'left': 2.768225}}}}, 'left': {'spInd': 0, 'spVal': 0.370042, 'right': {'spInd': 0, 'spVal': 0.35679, 'right': {'spInd': 0, 'spVal': 0.350725, 'right': {'spInd': 0, 'spVal': 0.350065, 'right': {'spInd': 0, 'spVal': 0.342761, 'right': {'spInd': 0, 'spVal': 0.342155, 'right': {'spInd': 0, 'spVal': 0.3417, 'right': -23.547711, 'left': -16.930416}, 'left': -31.584855}, 'left': -1.319852}, 'left': -40.086564}, 'left': {'spInd': 0, 'spVal': 0.351478, 'right': -0.461116, 'left': -19.526539}}, 'left': -32.124495}, 'left': {'spInd': 0, 'spVal': 0.378965, 'right': {'spInd': 0, 'spVal': 0.373501, 'right': -8.228297, 'left': {'spInd': 0, 'spVal': 0.377383, 'right': 5.241196, 'left': 13.583555}}, 'left': -29.007783}}}, 'left': {'spInd': 0, 'spVal': 0.388789, 'right': {'spInd': 0, 'spVal': 0.385021, 'right': 24.816941, 'left': 21.578007}, 'left': {'spInd': 0, 'spVal': 0.437652, 'right': {'spInd': 0, 'spVal': 0.412516, 'right': {'spInd': 0, 'spVal': 0.403228, 'right': {'spInd': 0, 'spVal': 0.391609, 'right': 3.001104, 'left': -1.729244}, 'left': -26.419289}, 'left': {'spInd': 0, 'spVal': 0.418943, 'right': 44.161493, 'left': {'spInd': 0, 'spVal': 0.426711, 'right': -21.594268, 'left': {'spInd': 0, 'spVal': 0.428582, 'right': 15.224266, 'left': 19.745224}}}}, 'left': {'spInd': 0, 'spVal': 0.454312, 'right': {'spInd': 0, 'spVal': 0.446196, 'right': -5.108172, 'left': {'spInd': 0, 'spVal': 0.451087, 'right': -28.724685, 'left': -20.360067}}, 'left': {'spInd': 0, 'spVal': 0.454375, 'right': 3.043912, 'left': 9.841938}}}}}}}, 'left': -34.044555}, 'left': {'spInd': 0, 'spVal': 0.465561, 'right': {'spInd': 0, 'spVal': 0.463241, 'right': 17.171057, 'left': 30.051931}, 'left': {'spInd': 0, 'spVal': 0.467383, 'right': {'spInd': 0, 'spVal': 0.46568, 'right': -23.777531, 'left': -9.712925}, 'left': {'spInd': 0, 'spVal': 0.483803, 'right': 5.224234, 'left': {'spInd': 0, 'spVal': 0.487381, 'right': 27.729263, 'left': {'spInd': 0, 'spVal': 0.487537, 'right': 5.149336, 'left': 11.924204}}}}}}, 'left': {'spInd': 0, 'spVal': 0.729397, 'right': {'spInd': 0, 'spVal': 0.640515, 'right': {'spInd': 0, 'spVal': 0.613004, 'right': {'spInd': 0, 'spVal': 0.606417, 'right': {'spInd': 0, 'spVal': 0.513332, 'right': {'spInd': 0, 'spVal': 0.508548, 'right': {'spInd': 0, 'spVal': 0.508542, 'right': 96.403373, 'left': 93.292829}, 'left': 101.075609}, 'left': {'spInd': 0, 'spVal': 0.533511, 'right': {'spInd': 0, 'spVal': 0.51915, 'right': 116.176162, 'left': {'spInd': 0, 'spVal': 0.531944, 'right': 124.795495, 'left': 129.766743}}, 'left': {'spInd': 0, 'spVal': 0.548539, 'right': {'spInd': 0, 'spVal': 0.546601, 'right': {'spInd': 0, 'spVal': 0.537834, 'right': 90.995536, 'left': {'spInd': 0, 'spVal': 0.543843, 'right': 98.36201, 'left': 96.319043}}, 'left': 83.114502}, 'left': {'spInd': 0, 'spVal': 0.553797, 'right': {'spInd': 0, 'spVal': 0.549814, 'right': 137.267576, 'left': 120.857321}, 'left': {'spInd': 0, 'spVal': 0.560301, 'right': 82.903945, 'left': {'spInd': 0, 'spVal': 0.599142, 'right': {'spInd': 0, 'spVal': 0.589806, 'right': {'spInd': 0, 'spVal': 0.582311, 'right': {'spInd': 0, 'spVal': 0.571214, 'right': {'spInd': 0, 'spVal': 0.569327, 'right': 108.435392, 'left': 114.872056}, 'left': 82.589328}, 'left': {'spInd': 0, 'spVal': 0.585413, 'right': 125.295113, 'left': 98.674874}}, 'left': 130.378529}, 'left': 93.521396}}}}}}, 'left': 168.180746}, 'left': {'spInd': 0, 'spVal': 0.623909, 'right': {'spInd': 0, 'spVal': 0.618868, 'right': 76.917665, 'left': 87.181863}, 'left': {'spInd': 0, 'spVal': 0.628061, 'right': {'spInd': 0, 'spVal': 0.624827, 'right': 105.970743, 'left': 117.628346}, 'left': {'spInd': 0, 'spVal': 0.637999, 'right': {'spInd': 0, 'spVal': 0.632691, 'right': 93.645293, 'left': 91.656617}, 'left': 82.713621}}}}, 'left': {'spInd': 0, 'spVal': 0.642373, 'right': 140.613941, 'left': {'spInd': 0, 'spVal': 0.642707, 'right': 82.500766, 'left': {'spInd': 0, 'spVal': 0.665329, 'right': {'spInd': 0, 'spVal': 0.661073, 'right': {'spInd': 0, 'spVal': 0.652462, 'right': 112.715799, 'left': 115.687524}, 'left': 121.980607}, 'left': {'spInd': 0, 'spVal': 0.706961, 'right': {'spInd': 0, 'spVal': 0.698472, 'right': {'spInd': 0, 'spVal': 0.689099, 'right': {'spInd': 0, 'spVal': 0.666452, 'right': {'spInd': 0, 'spVal': 0.665652, 'right': 105.547997, 'left': 120.014736}, 'left': {'spInd': 0, 'spVal': 0.667851, 'right': 92.449664, 'left': {'spInd': 0, 'spVal': 0.680486, 'right': 110.367074, 'left': 112.378209}}}, 'left': 120.521925}, 'left': {'spInd': 0, 'spVal': 0.69892, 'right': 92.470636, 'left': {'spInd': 0, 'spVal': 0.699873, 'right': 115.586605, 'left': {'spInd': 0, 'spVal': 0.70639, 'right': 105.062147, 'left': 106.180427}}}}, 'left': {'spInd': 0, 'spVal': 0.70889, 'right': 135.416767, 'left': {'spInd': 0, 'spVal': 0.716211, 'right': {'spInd': 0, 'spVal': 0.710234, 'right': 108.553919, 'left': 103.345308}, 'left': 110.90283}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.952833, 'right': {'spInd': 0, 'spVal': 0.759504, 'right': {'spInd': 0, 'spVal': 0.740859, 'right': {'spInd': 0, 'spVal': 0.731636, 'right': 73.912028, 'left': 93.773929}, 'left': {'spInd': 0, 'spVal': 0.757527, 'right': 63.549854, 'left': 81.106762}}, 'left': {'spInd': 0, 'spVal': 0.763328, 'right': 115.199195, 'left': {'spInd': 0, 'spVal': 0.769043, 'right': 64.041941, 'left': {'spInd': 0, 'spVal': 0.790312, 'right': {'spInd': 0, 'spVal': 0.786865, 'right': {'spInd': 0, 'spVal': 0.785574, 'right': {'spInd': 0, 'spVal': 0.777582, 'right': 100.838446, 'left': 107.024467}, 'left': 100.598825}, 'left': {'spInd': 0, 'spVal': 0.787755, 'right': 118.642009, 'left': 110.15973}}, 'left': {'spInd': 0, 'spVal': 0.806158, 'right': {'spInd': 0, 'spVal': 0.799873, 'right': {'spInd': 0, 'spVal': 0.798198, 'right': 76.853728, 'left': 91.368473}, 'left': 62.877698}, 'left': {'spInd': 0, 'spVal': 0.815215, 'right': {'spInd': 0, 'spVal': 0.811602, 'right': {'spInd': 0, 'spVal': 0.811363, 'right': 112.981216, 'left': 99.841379}, 'left': 118.319942}, 'left': {'spInd': 0, 'spVal': 0.833026, 'right': {'spInd': 0, 'spVal': 0.823848, 'right': {'spInd': 0, 'spVal': 0.819722, 'right': 70.054508, 'left': 59.342323}, 'left': 76.723835}, 'left': {'spInd': 0, 'spVal': 0.841547, 'right': {'spInd': 0, 'spVal': 0.838587, 'right': 134.089674, 'left': 115.669032}, 'left': {'spInd': 0, 'spVal': 0.841625, 'right': 60.552308, 'left': {'spInd': 0, 'spVal': 0.944221, 'right': {'spInd': 0, 'spVal': 0.85497, 'right': {'spInd': 0, 'spVal': 0.84294, 'right': 95.893131, 'left': {'spInd': 0, 'spVal': 0.847219, 'right': 76.240984, 'left': 89.20993}}, 'left': {'spInd': 0, 'spVal': 0.936524, 'right': {'spInd': 0, 'spVal': 0.934853, 'right': {'spInd': 0, 'spVal': 0.925782, 'right': {'spInd': 0, 'spVal': 0.910975, 'right': {'spInd': 0, 'spVal': 0.901444, 'right': {'spInd': 0, 'spVal': 0.901421, 'right': {'spInd': 0, 'spVal': 0.892999, 'right': {'spInd': 0, 'spVal': 0.888426, 'right': {'spInd': 0, 'spVal': 0.872199, 'right': {'spInd': 0, 'spVal': 0.866451, 'right': {'spInd': 0, 'spVal': 0.856421, 'right': 107.166848, 'left': 94.402102}, 'left': 111.552716}, 'left': {'spInd': 0, 'spVal': 0.883615, 'right': {'spInd': 0, 'spVal': 0.872883, 'right': 95.887712, 'left': 95.348184}, 'left': {'spInd': 0, 'spVal': 0.885676, 'right': 108.045948, 'left': 94.896354}}}, 'left': 82.436686}, 'left': {'spInd': 0, 'spVal': 0.900699, 'right': {'spInd': 0, 'spVal': 0.896683, 'right': 107.00162, 'left': 109.188248}, 'left': 100.133819}}, 'left': 87.300625}, 'left': {'spInd': 0, 'spVal': 0.908629, 'right': 118.513475, 'left': 106.814667}}, 'left': {'spInd': 0, 'spVal': 0.912161, 'right': 85.005351, 'left': {'spInd': 0, 'spVal': 0.915263, 'right': 96.71761, 'left': 92.074619}}}, 'left': 115.753994}, 'left': 65.548418}, 'left': {'spInd': 0, 'spVal': 0.937766, 'right': 119.949824, 'left': 100.120253}}}, 'left': {'spInd': 0, 'spVal': 0.948822, 'right': 69.318649, 'left': {'spInd': 0, 'spVal': 0.949198, 'right': 105.752508, 'left': {'spInd': 0, 'spVal': 0.952377, 'right': 73.520802, 'left': 100.649591}}}}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.965969, 'right': {'spInd': 0, 'spVal': 0.956951, 'right': {'spInd': 0, 'spVal': 0.953902, 'right': 130.92648, 'left': {'spInd': 0, 'spVal': 0.954711, 'right': 100.935789, 'left': 82.016541}}, 'left': {'spInd': 0, 'spVal': 0.958512, 'right': 135.837013, 'left': {'spInd': 0, 'spVal': 0.960398, 'right': 123.559747, 'left': 112.386764}}}, 'left': {'spInd': 0, 'spVal': 0.968621, 'right': 98.648346, 'left': 86.399637}}}}}

 散点图:

树回归 CART算法

得到的树很复杂,改变ops元组的值:

if __name__=='__main__':
    myMat2 = loadDataSet('ex2.txt')
    myMat2 = mat(myMat2)
    myTree = createTree(myMat2, ops=(10000, 4))
    print(myTree)

 输出:

{'spInd': 0, 'spVal': 0.499171, 'left': 101.35815937735848, 'right': -2.637719329787234}

 也可以得到仅有两个叶节点的树。


树剪枝:
一棵树如果节点过多,表明该模型可能对数据进行了过拟合。通过降低决策树的复杂度来避免过拟合的过程称为“剪枝”。

在函数chooseBestSplit()中的提前终止条件,实际上是“预剪枝”操作,预剪枝操作对于参数ops元组非常敏感,难以获得有效的回归树。

后剪枝:利用测试集对数进行剪枝。由于不需要用户指定参数,后剪枝是一种更理想化的剪枝方法。

首先将数据集划分为训练集和测试集。先使用训练集构建出一棵足够复杂的树便于剪枝。然后从上到下找到叶节点,用测试集来判断这些叶节点合并能不能降低测试误差,如果可以的话就合并。

伪代码如下:

基于已有的树切分测试数据:
    如果存在任一子集是一棵树,则在该子集递归剪枝过程
    计算将当前两个叶子节点合并后的误差
    计算不合并的误差
    如果合并会降低误差则合并

 回归树剪枝函数prune():

def isTree(obj):  # 测试输入变量是否是一棵树,返回布尔型的结果,用于判断当前处理的节点是否是叶节点
    return (type(obj).__name__ == "dict")
def getMean(tree):  # 递归函数,从上到下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理
    if isTree(tree["right"]):
        tree["right"] = getMean(tree["right"])
    if isTree(tree["left"]):
        tree["left"] = getMean(tree["left"])
    return (tree["left"] + tree["right"]) / 2.0

def prune(tree, testData):  #参数:待剪枝的树与剪枝所需的测试数据
    if shape(testData)[0] == 0:     #没有测试数据则对树进行塌陷处理
        return getMean(tree)
    if (isTree(tree['right']) or isTree(tree['left'])):  #
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("融合")
            return treeMean
        else:
            return tree
    else:
        return tree

 isTree():测试输入变量是否是一棵树,返回布尔值的结果。用于判断当前处理的节点是不是叶子节点。

getMean():递归函数,从上到下遍历树直到叶节点。如果找到两个叶节点就返回其平均值。该函数对树进行塌陷处理。

prune():参数为待剪枝的树和剪枝所需的测试数据集。

测试:

if __name__=='__main__':
    myMat2=loadDataSet('ex2.txt')
    myMat2=mat(myMat2)
    myTree = createTree(myMat2, ops=(0, 1))
    myDat2Test = loadDataSet("ex2test.txt")
    myMat2Test = mat(myDat2Test)
    result=prune(myTree, myMat2Test)
    print(result)

 输出:

融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
融合
{'left': {'left': {'left': {'left': 92.5239915, 'spInd': 0, 'spVal': 0.965969, 'right': {'left': {'left': {'left': 112.386764, 'spInd': 0, 'spVal': 0.960398, 'right': 123.559747}, 'spInd': 0, 'spVal': 0.958512, 'right': 135.837013}, 'spInd': 0, 'spVal': 0.956951, 'right': 111.2013225}}, 'spInd': 0, 'spVal': 0.952833, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225, 'spInd': 0, 'spVal': 0.948822, 'right': 69.318649}, 'spInd': 0, 'spVal': 0.944221, 'right': {'left': {'left': 110.03503850000001, 'spInd': 0, 'spVal': 0.936524, 'right': {'left': 65.548418, 'spInd': 0, 'spVal': 0.934853, 'right': {'left': 115.753994, 'spInd': 0, 'spVal': 0.925782, 'right': {'left': {'left': 94.3961145, 'spInd': 0, 'spVal': 0.912161, 'right': 85.005351}, 'spInd': 0, 'spVal': 0.910975, 'right': {'left': {'left': 106.814667, 'spInd': 0, 'spVal': 0.908629, 'right': 118.513475}, 'spInd': 0, 'spVal': 0.901444, 'right': {'left': 87.300625, 'spInd': 0, 'spVal': 0.901421, 'right': {'left': {'left': 100.133819, 'spInd': 0, 'spVal': 0.900699, 'right': 108.094934}, 'spInd': 0, 'spVal': 0.892999, 'right': {'left': 82.436686, 'spInd': 0, 'spVal': 0.888426, 'right': {'left': 98.54454949999999, 'spInd': 0, 'spVal': 0.872199, 'right': 106.16859550000001}}}}}}}}}, 'spInd': 0, 'spVal': 0.85497, 'right': {'left': {'left': 89.20993, 'spInd': 0, 'spVal': 0.847219, 'right': 76.240984}, 'spInd': 0, 'spVal': 0.84294, 'right': 95.893131}}}, 'spInd': 0, 'spVal': 0.841625, 'right': 60.552308}, 'spInd': 0, 'spVal': 0.841547, 'right': 124.87935300000001}, 'spInd': 0, 'spVal': 0.833026, 'right': {'left': 76.723835, 'spInd': 0, 'spVal': 0.823848, 'right': {'left': 59.342323, 'spInd': 0, 'spVal': 0.819722, 'right': 70.054508}}}, 'spInd': 0, 'spVal': 0.815215, 'right': {'left': 118.319942, 'spInd': 0, 'spVal': 0.811602, 'right': {'left': 99.841379, 'spInd': 0, 'spVal': 0.811363, 'right': 112.981216}}}, 'spInd': 0, 'spVal': 0.806158, 'right': 73.49439925}, 'spInd': 0, 'spVal': 0.790312, 'right': {'left': 114.4008695, 'spInd': 0, 'spVal': 0.786865, 'right': 102.26514075}}, 'spInd': 0, 'spVal': 0.769043, 'right': 64.041941}, 'spInd': 0, 'spVal': 0.763328, 'right': 115.199195}, 'spInd': 0, 'spVal': 0.759504, 'right': 78.08564325}}, 'spInd': 0, 'spVal': 0.729397, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 110.90283, 'spInd': 0, 'spVal': 0.716211, 'right': {'left': 103.345308, 'spInd': 0, 'spVal': 0.710234, 'right': 108.553919}}, 'spInd': 0, 'spVal': 0.70889, 'right': 135.416767}, 'spInd': 0, 'spVal': 0.706961, 'right': {'left': {'left': {'left': {'left': 106.180427, 'spInd': 0, 'spVal': 0.70639, 'right': 105.062147}, 'spInd': 0, 'spVal': 0.699873, 'right': 115.586605}, 'spInd': 0, 'spVal': 0.69892, 'right': 92.470636}, 'spInd': 0, 'spVal': 0.698472, 'right': {'left': 120.521925, 'spInd': 0, 'spVal': 0.689099, 'right': {'left': 101.91115275, 'spInd': 0, 'spVal': 0.666452, 'right': 112.78136649999999}}}}, 'spInd': 0, 'spVal': 0.665329, 'right': {'left': 121.980607, 'spInd': 0, 'spVal': 0.661073, 'right': {'left': 115.687524, 'spInd': 0, 'spVal': 0.652462, 'right': 112.715799}}}, 'spInd': 0, 'spVal': 0.642707, 'right': 82.500766}, 'spInd': 0, 'spVal': 0.642373, 'right': 140.613941}, 'spInd': 0, 'spVal': 0.640515, 'right': {'left': {'left': {'left': {'left': 82.713621, 'spInd': 0, 'spVal': 0.637999, 'right': {'left': 91.656617, 'spInd': 0, 'spVal': 0.632691, 'right': 93.645293}}, 'spInd': 0, 'spVal': 0.628061, 'right': {'left': 117.628346, 'spInd': 0, 'spVal': 0.624827, 'right': 105.970743}}, 'spInd': 0, 'spVal': 0.623909, 'right': 82.04976400000001}, 'spInd': 0, 'spVal': 0.613004, 'right': {'left': 168.180746, 'spInd': 0, 'spVal': 0.606417, 'right': {'left': {'left': {'left': {'left': {'left': {'left': 93.521396, 'spInd': 0, 'spVal': 0.599142, 'right': {'left': 130.378529, 'spInd': 0, 'spVal': 0.589806, 'right': {'left': 111.9849935, 'spInd': 0, 'spVal': 0.582311, 'right': {'left': 82.589328, 'spInd': 0, 'spVal': 0.571214, 'right': {'left': 114.872056, 'spInd': 0, 'spVal': 0.569327, 'right': 108.435392}}}}}, 'spInd': 0, 'spVal': 0.560301, 'right': 82.903945}, 'spInd': 0, 'spVal': 0.553797, 'right': 129.0624485}, 'spInd': 0, 'spVal': 0.548539, 'right': {'left': 83.114502, 'spInd': 0, 'spVal': 0.546601, 'right': {'left': 97.3405265, 'spInd': 0, 'spVal': 0.537834, 'right': 90.995536}}}, 'spInd': 0, 'spVal': 0.533511, 'right': {'left': {'left': 129.766743, 'spInd': 0, 'spVal': 0.531944, 'right': 124.795495}, 'spInd': 0, 'spVal': 0.51915, 'right': 116.176162}}, 'spInd': 0, 'spVal': 0.513332, 'right': {'left': 101.075609, 'spInd': 0, 'spVal': 0.508548, 'right': {'left': 93.292829, 'spInd': 0, 'spVal': 0.508542, 'right': 96.403373}}}}}}}, 'spInd': 0, 'spVal': 0.499171, 'right': {'left': {'left': {'left': {'left': {'left': 8.53677, 'spInd': 0, 'spVal': 0.487381, 'right': 27.729263}, 'spInd': 0, 'spVal': 0.483803, 'right': 5.224234}, 'spInd': 0, 'spVal': 0.467383, 'right': {'left': -9.712925, 'spInd': 0, 'spVal': 0.46568, 'right': -23.777531}}, 'spInd': 0, 'spVal': 0.465561, 'right': {'left': 30.051931, 'spInd': 0, 'spVal': 0.463241, 'right': 17.171057}}, 'spInd': 0, 'spVal': 0.457563, 'right': {'left': -34.044555, 'spInd': 0, 'spVal': 0.455761, 'right': {'left': {'left': {'left': {'left': {'left': -4.1911745, 'spInd': 0, 'spVal': 0.437652, 'right': {'left': {'left': {'left': {'left': 19.745224, 'spInd': 0, 'spVal': 0.428582, 'right': 15.224266}, 'spInd': 0, 'spVal': 0.426711, 'right': -21.594268}, 'spInd': 0, 'spVal': 0.418943, 'right': 44.161493}, 'spInd': 0, 'spVal': 0.412516, 'right': {'left': -26.419289, 'spInd': 0, 'spVal': 0.403228, 'right': 0.6359300000000001}}}, 'spInd': 0, 'spVal': 0.388789, 'right': 23.197474}, 'spInd': 0, 'spVal': 0.382037, 'right': {'left': {'left': {'left': -29.007783, 'spInd': 0, 'spVal': 0.378965, 'right': {'left': {'left': 13.583555, 'spInd': 0, 'spVal': 0.377383, 'right': 5.241196}, 'spInd': 0, 'spVal': 0.373501, 'right': -8.228297}}, 'spInd': 0, 'spVal': 0.370042, 'right': {'left': -32.124495, 'spInd': 0, 'spVal': 0.35679, 'right': {'left': -9.9938275, 'spInd': 0, 'spVal': 0.350725, 'right': -26.851234812500003}}}, 'spInd': 0, 'spVal': 0.335182, 'right': {'left': 22.286959625, 'spInd': 0, 'spVal': 0.324274, 'right': {'left': {'left': -20.3973335, 'spInd': 0, 'spVal': 0.310956, 'right': -49.939516}, 'spInd': 0, 'spVal': 0.309133, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 8.814725, 'spInd': 0, 'spVal': 0.300318, 'right': {'left': -18.051318, 'spInd': 0, 'spVal': 0.297107, 'right': {'left': -1.798377, 'spInd': 0, 'spVal': 0.295993, 'right': {'left': -14.988279, 'spInd': 0, 'spVal': 0.290749, 'right': -14.391613}}}}, 'spInd': 0, 'spVal': 0.284794, 'right': {'left': 35.623746, 'spInd': 0, 'spVal': 0.273863, 'right': {'left': -9.457556, 'spInd': 0, 'spVal': 0.264926, 'right': {'left': 5.280579, 'spInd': 0, 'spVal': 0.264639, 'right': 2.557923}}}}, 'spInd': 0, 'spVal': 0.25807, 'right': {'left': {'left': -9.601409499999999, 'spInd': 0, 'spVal': 0.228751, 'right': -30.812912}, 'spInd': 0, 'spVal': 0.228628, 'right': -2.266273}}, 'spInd': 0, 'spVal': 0.228473, 'right': 6.099239}, 'spInd': 0, 'spVal': 0.211633, 'right': {'left': -16.42737025, 'spInd': 0, 'spVal': 0.202161, 'right': -2.6781805}}, 'spInd': 0, 'spVal': 0.193282, 'right': 9.5773855}, 'spInd': 0, 'spVal': 0.166765, 'right': {'left': {'left': {'left': -14.740059, 'spInd': 0, 'spVal': 0.166431, 'right': -6.512506}, 'spInd': 0, 'spVal': 0.164134, 'right': -27.405211}, 'spInd': 0, 'spVal': 0.156273, 'right': 0.225886}}, 'spInd': 0, 'spVal': 0.156067, 'right': {'left': 7.557349, 'spInd': 0, 'spVal': 0.13988, 'right': 7.336784}}, 'spInd': 0, 'spVal': 0.138619, 'right': -29.087463}, 'spInd': 0, 'spVal': 0.131833, 'right': 22.478291}}}}}, 'spInd': 0, 'spVal': 0.130626, 'right': -39.524461}, 'spInd': 0, 'spVal': 0.126833, 'right': {'left': 22.891675, 'spInd': 0, 'spVal': 0.124723, 'right': {'left': {'left': 6.196516, 'spInd': 0, 'spVal': 0.108801, 'right': {'left': -16.106164, 'spInd': 0, 'spVal': 0.10796, 'right': {'left': -1.293195, 'spInd': 0, 'spVal': 0.085873, 'right': -10.137104}}}, 'spInd': 0, 'spVal': 0.085111, 'right': {'left': 37.820659, 'spInd': 0, 'spVal': 0.084661, 'right': {'left': -24.132226, 'spInd': 0, 'spVal': 0.080061, 'right': {'left': 15.824970500000001, 'spInd': 0, 'spVal': 0.068373, 'right': {'left': -15.160836, 'spInd': 0, 'spVal': 0.061219, 'right': {'left': {'left': {'left': 6.695567, 'spInd': 0, 'spVal': 0.055862, 'right': -3.131497}, 'spInd': 0, 'spVal': 0.053764, 'right': -13.731698}, 'spInd': 0, 'spVal': 0.044737, 'right': 4.091626}}}}}}}}}}}
View Code

 虽然合并了很多叶节点,但剪枝后的树没有像预期的那样剪枝成两部分。说明后剪枝可能不如预剪枝有效。可以同时使用两种剪枝方式。


模型树:把叶子节点设定为分段线性函数。利用数生成算法对数据切分,且每份切分数据容易被线性模型表示。该算法的关键在于误差的计算。

对于给定的数据集,应该先用线性的模型对它拟合,然后计算真是的目标值与模型预测值之间的差值,再将这些差值的平方求和就得到了所需要的误差。

模型树的叶节点生成函数:

def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m, n)))  #第一列仍为1
    Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n - 1]
    # print('X:',X)
    Y = dataSet[:, -1]  # 将X,Y中的数据格式化
    # print('Y:',Y)
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError("此矩阵不可逆。")
        # ws = linalg.pinv(xTx) * (X.T * Y)
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

def modelLeaf(dataSet):  # 当数据不再需要切分的时候它负责生成叶节点模型
    ws, X, Y = linearSolve(dataSet)
    return ws
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

 数据集散点图如下:

树回归 CART算法

 测试:

myMat=mat(loadDataSet('exp2.txt'))
    plotPoint(myMat)
    myTree=createTree(myMat,modelLeaf,modelErr,(1,10))
    print(myTree)

 输出结果:

{'spInd': 0, 'spVal': 0.285477, 'right': matrix([[3.46877936],
        [1.18521743]]), 'left': matrix([[1.69855694e-03],
        [1.19647739e+01]])}

 将数据集从x=0.285477分开,分别用两段线性模型来拟合。


 树回归与标准回归的比较:相关系数

用树回归进行预测的代码:包括回归树和模型树两种树

def regTreeEval(model, inDat):  #回归树效果评估
    return float(model)

def modelTreeEval(model, inDat):    #模型树效果评估
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    return float(X * model)

def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree):
        return modelEval(tree, inData)  # 如果输入单个数据或行向量,返回一个浮点值
    else:
        if inData[tree["spInd"]] > tree["spVal"]:
            if isTree(tree["left"]):
                return treeForeCast(tree["left"], inData, modelEval)
            else:
                return modelEval(tree["left"], inData)
        else:
            if isTree(tree["right"]):
                return treeForeCast(tree["right"], inData, modelEval)
            else:
                return modelEval(tree["right"], inData)
def createForeCast(tree, testData, modelEval=regTreeEval):  #测试不同回归树的效果
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)  # 多次调用treeForeCast函数,将结果以列的形式放到yHat变量中
    return yHat

 因为代码中已经含有标准线性回归函数(linearSolve),所以不必重新写其生成代码。

测试:

if __name__=='__main__':
    trainMat = mat(loadDataSet("bikeSpeedVsIq_train.txt"))
    testMat = mat(loadDataSet("bikeSpeedVsIq_test.txt"))
    myTree = createTree(trainMat, ops=(1, 20))
    yHat = createForeCast(myTree, testMat[:, 0])
    print("回归树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])

    myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20))
    yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval)
    print("模型树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])

    ws, X, Y = linearSolve(trainMat)
    print("线性回归系数:", ws)
    for i in range(shape(testMat)[0]):
        yHat[i] = testMat[i, 0] * ws[1, 0] + ws[0, 0]
    print("线性回归模型的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])

 输出:

回归树的相关系数: 0.964085231822215
模型树的相关系数: 0.9760412191380629
线性回归系数: [[37.58916794]
 [ 6.18978355]]
线性回归模型的相关系数: 0.9434684235674766

 相关系数越接近1越好,所以,模型树>回归树>标准线性回归。