首页 » 机器学习实战 » 机器学习实战全文在线阅读

《机器学习实战》9.4 树剪枝

关灯直达底部

一棵树如果节点过多,表明该模型可能对数据进行了“过拟合”。那么,如何判断是否发生了过拟合?前面章节中使用了测试集上某种交叉验证技术来发现过拟合,决策树亦是如此。本节将对此进行讨论,并分析如何避免过拟合。

通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。其实本章前面已经进行过剪枝处理。在函数chooseBestSplit中的提前终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。本节将分析后剪枝的有效性,但首先来看一下预剪枝的不足之处。

9.4.1 预剪枝

上节两个简单实验的结果还是令人满意的,但背后存在一些问题。树构建算法其实对输入的参数tolStolN非常敏感,如果使用其他值将不太容易达到这么好的效果。为了说明这一点,在Python提示符下输入如下命令:

>>> regTrees.createTree(myMat,ops=(0,1))  

与上节中只包含两个节点的树相比,这里构建的树过于臃肿,它甚至为数据集中每个样本都分配了一个叶节点。

图9-3中的散点图,看上去与图9-1非常相似。但如果仔细地观察y轴就会发现,前者的数量级是后者的100倍。这将不是问题,对吧?现在用该数据来构建一棵新的树(数据存放在ex2.txt中),在Python提示符下输入以下命令:

>>> myDat2=regTrees.loadDataSet(/'ex2.txt/')>>> myMat2=mat(myDat2)>>> regTrees.createTree(myMat2){/'spInd/': 0, /'spVal/': matrix([[ 0.499171]]), /'right/': {/'spInd/': 0, /'spVal/': matrix([[ 0.457563]]), /'right/': -3.6244789069767438, /'left/': 7.9699461249999999}, /'l..0, /'spVal/': matrix([[ 0.958512]]), /'right/': 112.42895575000001,/'left/': 105.2482350000001}}}}  

图9-3 将图9-1的数据的y轴放大100倍后的新数据集

不知你注意到没有,从图9-1数据集构建出来的树只有两个叶节点,而这里构建的新树则有很多叶节点。产生这个现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平方值,或许也能得到仅有两个叶节点组成的树:

>>> regTrees.createTree(myMat2,ops=(10000,4)){/'spInd/': 0, /'spVal/': matrix([[ 0.499171]]), /'right/': -2.6377193297872341, /'left/': 101.35815937735855}  

然而,通过不断修改停止条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要寻找什么样的结果。这正是机器学习所关注的内容,计算机应该可以给出总体的概貌。

下节将讨论后剪枝,即利用测试集来对树进行剪枝。由于不需要用户指定参数,后剪枝是一个更理想化的剪枝方法。

9.4.2 后剪枝

使用后剪枝方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差。如果是的话就合并。

函数prune的伪代码如下:

基于已有的树切分测试数据:    如果存在任一子集是一棵树,则在该子集递归剪枝过程    计算将当前两个叶节点合并后的误差    计算不合并的误差    如果合并会降低误差的话,就将叶节点合并  

为了解实际效果,打开regTrees.py并输入程序清单9-3的代码。

程序清单9-3 回归树剪枝函数

def isTree(obj):    return (type(obj).__name__==/'dict/')def getMean(tree):    if isTree(tree[/'right/']): tree[/'right/'] = getMean(tree[/'right/'])    if isTree(tree[/'left/']): tree[/'left/'] = getMean(tree[/'left/'])    return (tree[/'left/']+tree[/'right/'])/2.0def prune(tree, testData):   #❶  没有测试数据则对树进行塌陷处理    if shape(testData)[0] == 0: return getMean(tree)    if (isTree(tree[/'right/']) or isTree(tree[/'left/'])):        lSet, rSet = binSplitDataSet(testData, tree[/'spInd/'],tree[/'spVal/'])    if isTree(tree[/'left/']): tree[/'left/'] = prune(tree[/'left/'], lSet)    if isTree(tree[/'right/']): tree[/'right/'] = prune(tree[/'right/'], rSet)    if not isTree(tree[/'left/']) and not isTree(tree[/'right/']):        lSet, rSet = binSplitDataSet(testData, tree[/'spInd/'],tree[/'spVal/'])        errorNoMerge = sum(power(lSet[:,-1] - tree[/'left/'],2)) +sum(power(rSet[:,-1] - tree[/'right/'],2))        treeMean = (tree[/'left/']+tree[/'right/'])/2.0        errorMerge = sum(power(testData[:,-1] - treeMean,2))        if errorMerge < errorNoMerge:            print /"merging/"            return treeMean        else: return tree    else: return tree    

程序清单9-3中包含三个函数:isTreegetMeanprune。其中isTree用于测试输入变量是否是一棵树,返回布尔类型的结果。换句话说,该函数用于判断当前处理的节点是否是叶节点。

函数getMean是一个递归函数,它从上往下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值),在prune函数中调用该函数时应明确这一点。

程序清单9-3的主函数是prune,它有两个参数:待剪枝的树与剪枝所需的测试数据testDataprune函数首先需要确认测试集是否为空❶。一旦非空,则反复递归调用函数prune对测试数据进行切分。因为树是由其他数据集(训练集)生成的,所以测试集上会有一些样本与原数据集样本的取值范围不同。一旦出现这种情况应当怎么办?数据发生过拟合应该进行剪枝吗?或者模型正确不需要任何剪枝?这里假设发生了过拟合,从而对树进行剪枝。

接下来要检查某个分支到底是子树还是节点。如果是子树,就调用函数prune来对该子树进行剪枝。在对左右两个分支完成剪枝之后,还需要检查它们是否仍然还是子树。如果两个分支已经不再是子树,那么就可以进行合并。具体做法是对合并前后的误差进行比较。如果合并后的误差比不合并的误差小就进行合并操作,反之则不合并直接返回。

接下来看看实际效果,将程序清单9-3的代码添加到regTrees.py文件并保存,在Python提示符下输入下面的命令:

>>> reload(regTrees)<module /'regTrees/' from /'regTrees.pyc/'>   

为了创建所有可能中最大的树,输入如下命令:

>>> myTree=regTrees.createTree(myMat2, ops=(0,1))  

输入以下命令导入测试数据:

>>> myDatTest=regTrees.loadDataSet(/'ex2test.txt/')>>> myMat2Test=mat(myDatTest)   

输入以下命令,执行剪枝过程:

>>> regTrees.prune(myTree, myMat2Test)mergingmergingmerging..merging{/'spInd/': 0, /'spVal/': matrix([[ 0.499171]]), /'right/': {/'spInd/': 0, /'spVal/':..01, /'left/': {/'spInd/': 0, /'spVal/': matrix([[ 0.960398]]), /'right/': 123.559747,    /'left/': 112.386764}}}, /'left/': 92.523991499999994}}}}   

可以看到,大量的节点已经被剪枝掉了,但没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般地,为了寻求最佳模型可以同时使用两种剪枝技术。

下节将重用部分已有的树构建代码来创建一种新的树。该树仍采用二元切分,但叶节点不再是简单的数值,取而代之的是一些线性模型。