matlab实现cart(回归分类树)

作为机器学习的小白和matlab的小白自己参照 python的 《机器学习实战》 写了一下分类回归树,这里记录一下。

关于决策树的基础概念就不过多介绍了,至于是分类还是回归。。我说不清楚。。我用的数据集是这个http://archive.ics.uci.edu/ml/datasets/Abalone 就是通过一些属性来预测鲍鱼有多少头,下面看一下

Length / continuous / mm / Longest shell measurement 
Diameter / continuous / mm / perpendicular to length 
Height / continuous / mm / with meat in shell 
Whole weight / continuous / grams / whole abalone 
Shucked weight / continuous / grams / weight of meat 
Viscera weight / continuous / grams / gut weight (after bleeding) 
Shell weight / continuous / grams / after being dried 
Rings / integer / -- / +1.5 gives the age in years

这些属性除了最后的Rings是整数,可以看做是离散的,其他都是浮点数,是连续的。所以还是用cart中二分的思想,就是小于等于分一边,大于分一边。但是没有用gini指数,因为熵还是好一点。

代码在github:https://github.com/jokermask/matlab_cart

参照《机器学习实战》代码有5个部分:getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集),chooseBestFeatureToSplit(寻找最佳分割点和阈值),createTree(建树),predict(预测)。

我按流程梳理一下,首先函数脚本来将数据集划分成,训练集和测试集,然后用训练集建树,用测试集测试,(更改后变成bootstrap sampleing)

dataset = importdata('abalone.data.txt') ;
origin_data = dataset.data ;
labels = {'Length';'Diam';'Height';    'Whole';'Shucked';'Viscera';'Shell';'Rings'} ;
test_runtimes = 50 ;
ae = 0 ;
rr = 0 ;
for i=1:test_runtimes
    data = sampleWithReplace(origin_data) ;%bootstrap sampling
    len = floor(length(data)/4*3) ;
    train_data = data(1:len,:) ;
    test_data = data(len:end,:) ;
    test_y_truth = test_data(:,end) ;
%     tree = createTree(train_data,labels,0) ;
%     predict_y = predict(tree,test_data,labels) ;
%     com_matrix = [predict_y,test_y_truth] ;
%     count = sum(predict_y==test_y_truth) ;
%     disp(com_matrix) ;
%     disp(mae) ;
%     disp(rr) ;

%plot single runtime
%     x = 1:1:size(test_y_truth,1) ;
%     plot(x,predict_y,'-b',x,test_y_truth,'-r') ;

    ae = ae+sum(abs(predict_y-test_y_truth))/size(test_y_truth,1) ;
    rr = rr+count/size(test_y_truth,1) ;
    
    %trian with office tools fitctree
    
    std_tree = fitctree(train_data(:,1:7),train_data(:,end)) ;
    % view(std_tree) ;
    std_y = predict(std_tree,test_data(:,1:7)) ;
    % disp([std_y,y]) ;
    ae = ae+sum(abs(std_y-test_y_truth))/size(test_y_truth,1) ;
    rr = rr+sum(std_y==test_y_truth)/size(test_y_truth,1) ;
end
mae = mae / test_runtimes ;
mrr = rr / test_runtimes ;
disp('mae') ;
disp(mae) ;
disp('mrr') ;
disp(mrr) ;

createTree函数:由于matlab没有指针,所以只能写成嵌套结构,就像tree{tree{tree}}这样。我们是递归实现的,但怎么样才会停止建树?条件是当前节点所有标签的类别一样,比如rings都为10,那说明这一个子集已经纯了,或者是这颗树的高度已经超出了我们设的阈值,就停止,第二种情况很可能当前节点下的数据集不纯,我们就找一个出现频率最高的类别代表该节点

function [ tree ] = createTree( dataset,labels,heightcount )
    len = size(dataset,1) ;
    templabel = dataset(1,end) ;
    tree = templabel ;
    max_depth = 5 ;%最大树高
    flag = 1 ; %判断是否数据集中所有标签都一致了(纯的),是则返回
    for i=1:len
        if templabel~=dataset(i,end) ;
            flag = 0 ;
        end
    end
    if flag==1
        return ;
    end
    if heightcount>max_depth
        labelVec = dataset(:,end) ;
        disp(labelVec) ;
        element = 1:max(labelVec) ;
        counts = histc(labelVec,element) ;
        [~,max_idx] = max(counts) ;
        tree = element(max_idx) ;
        return ;
    end
    [bestFeat,bestT] = chooseBestFeatureToSplit(dataset) ;
    bestFeatLabel = labels{bestFeat} ;
    tree = struct ;%struct储存树结构
    tree.bestFeatLabel = bestFeatLabel ;
    tree.bestT = bestT ;
    tree.greaterthan = createTree(splitDataset(dataset,bestFeat,bestT,1),labels,heightcount+1) ;%大于阈值部分的子树
    tree.lessthan = createTree(splitDataset(dataset,bestFeat,bestT,2),labels,heightcount+1) ;%小于阈值部分的子树
end

chooseBestFeatureToSplit函数:在createTree时,每次递归都要找那个当前最佳的特征和阈值,也就是调用chooseBestFeatureToSplit函数,所以两层循环,第一层遍历每个属性,第二层本应该遍历每个属性下的值,但是那样计算量太大了,所以我就将值排序之后分成10端取中位数遍历,在里面找阈值,如果当前节点的数据子集已经不足10个里,那就把所有属性都遍历一哈

function [ bestFeat,bestT ] = chooseBestFeatureToSplit( dataset )
    [~,numFeats] = size(dataset) ;
    numFeats = numFeats-1 ;%除去标签那一列
    baseEnt = getEnt(dataset) ;
    baseInfoGain = 0 ;
    bestFeat = -1 ;
    for i=1:numFeats
        featVec = dataset(:,i) ;
        %由于值是连续的,所以对于特征向量组排序分成n段取中位数
        sortedFeatVec = sort(featVec,'ascend') ;
        lengthofT = floor(sqrt(length(sortedFeatVec))) ; %取向量长度开根号来确定阈值的个数
        if lengthofT<10
            lengthofT = length(sortedFeatVec) ;
            selectedFeat = sortedFeatVec ;
        else
            step = floor(length(sortedFeatVec)/lengthofT) ;
            selectedFeat = zeros(lengthofT,1) ;
            for j=1:lengthofT
                head = (j-1)*step+1 ;
                tail = j*step ;
                subSortedFeatVec = sortedFeatVec(head:tail) ;
                selectedFeat(j) = median(subSortedFeatVec) ;
            end
        end
        for k=1:lengthofT
            newEnt = 0 ;
            for l=1:2
                subDataset = splitDataset(dataset,i,selectedFeat(k),l) ;
                prob = size(subDataset,1)/size(dataset,1) ;
                newEnt = newEnt + prob*getEnt(subDataset) ;
            end
            infoGain = baseEnt - newEnt ;
%             disp('infoGain') ;
%             disp(infoGain) ;
            if(infoGain>baseInfoGain)
                baseInfoGain = infoGain ;
                bestFeat= i ;
                bestT = selectedFeat(k) ;
            end
        end
    end
end

计算信息增益(infoGain)的时候需要用到getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集)函数

splitDataset:

function [ retDataset ] = splitDataset(dataset,axis,value,arg )
%axis 代表键值的位置 value表示阈值 返回划分后的dataset,arg表示取大于的部分(1)还是小于等于的部分
    if arg==1
        retDataset = dataset(dataset(:,axis)>value,:) ;
    else
        retDataset = dataset(dataset(:,axis)<=value,:) ;
    end
end
View Code

getEnt:

function [ ent ] = getEnt( data )
%index present the label
[datalen,~] = size(data) ;
maxLabel = max(data(:,end)) ;
labelCountsMap = zeros(maxLabel,1) ;%rings are all numbers
    for i=1:datalen
        label =  data(i,end) ;
        if labelCountsMap(label)~=0
            labelCountsMap(label) = labelCountsMap(label) + 1 ;
        else
            labelCountsMap(label) = 1 ; 
        end
    end
    ent = 0 ;
%     disp('labelMap') ;
%     disp(labelCountsMap) ;
    for i=1:maxLabel
        if labelCountsMap(i)~=0
            prob = labelCountsMap(i)/datalen ;
            ent = ent - prob*log2(prob) ;
        end
    end
end
View Code

最后预测函数:

function [ classVec ] = predict( tree , dataset , labels)
%tree应由createTree函数生成
    len = size(dataset,1) ;
    classVec = zeros(len,1) ;
    for i=1:len
        dataVec = dataset(i,1:end-1) ;
        tempnode = tree ;
        while(isstruct(tempnode))
            [~,tempFeatIdx] = ismember(tempnode.bestFeatLabel,labels) ;
            if(dataVec(tempFeatIdx)>tempnode.bestT)
                tempnode = tempnode.greaterthan ;
            else
                tempnode = tempnode.lessthan ;
            end
        end
        classVec(i) = tempnode ;
    end
end
View Code

更新了一下代码,加入了boostrap采样,就是有放回的采样,我是这样采用的,有多少个样本就进行多少次有放回采样,然后这个过程进行50次求均值。用了之后,官方的库正确率道理44%,而我的还在30%。。差距一下突显,还需继续学习。。

补充一下那个sampleWithReplace函数

function [ sample_data ] = sampleWithReplace( dataset )
    len = size(dataset,1) ;
    randidx = randsample(len,len,true) ;
    sample_data = dataset(randidx,:) ;
end

相关推荐