当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法就显得太难了。实际生活中很多问题都是非线性的,不可能使用全局性模型来拟合任何数据。
一种可行的方法就是将数据集切分成很多易建模的数据,然后再利用线性回归技术来建模。本章介绍一种新的叫做CART的树构建算法,该算法可用于分类也可用于回归。
复杂数据的局部性建模
树构建算法:
- ID3 每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。如果一个特征有4个取值,那么数据将被切分成4份。一旦按某特征切分后,该特征在之后的算法执行中将不会再起作用。
- 二分法 每次把数据切成两份。如果特征值大于给定值就走左子树,否则就走右子树。
CART使用二元切分来处理连续型变量。对CART稍作修改就可以处理回归问题。
在树的构建过程中,需要解决多种类型数据的存储问题。这里使用字典来存储树的数据结构,字典包括1、待切分的特征 2、待切分的特征值。 3、右子树。当不再需要切分的时候也可以是单个值。4、左子树。 与右子树类似。
'''
创建树节点
'''
class treeNode():
def __init__(self,feat,val,right,left):
self.featureToSplitOn = feat
self.valueOfSpit = val
self.rightBranch = right
self.leftBranch = left
'''
加载数据
'''
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float,curLine))
dataMat.append(fltLine)
return dataMat
'''
构建树
dataSet 数据集
leafType 建立叶子节点的函数
errType 代表误差计算函数
ops 代表树构建所需其他参数的元组
'''
def createTree(dataSet,leafType = regLeaf,errType = regErr,ops=(1,4)):
feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
if feat == None:return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet,rSet = bindSplitDataSet(dataSet,feat,val)
retTree['left'] = createTree(lSet,leafType,errType,ops)
retTree['right'] = createTree(rSet,leafType,errType,ops)
return retTree
'''
负责生成叶子节点 当chooseBestSplit确定不再对数据进行切分时,调用子方法得到叶子节点
该函数就是目标变量的均值
'''
def regLeaf(dataSet):
return mean(dataSet[:,-1])
'''
在给定数据上计算目标变量的平均误差
'''
def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]
'''
三个参数 分别为数据集合,待切分的特征和该特诊的某个值
'''
def bindSplitDataSet(dataSet,feature,value):
mats0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
mats1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
mat0 = [];mat1 = []
if len(mats0) > 1:mat0 = mats0
if len(mats1) > 1:mat1 = mats1
return mat0,mat1
'''
找到最佳二元切分
dataSet 数据集
leafType 建立叶子节点的函数
errType 代表误差计算函数
ops 代表树构建所需其他参数的元组
'''
def chooseBestSplit(dataSet,leafType=regLeaf,errType = regErr,ops=(1,4)):
#ops的两个值 用于控制函数的退出机制
#tolS容许的误差下降值
#tolN切分的最好样本
tolS = ops[0];tolN = ops[1]
#如果所有值相等则退出 如果该数目为1 就不需要再切分了
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None,leafType(dataSet)
m,n = shape(dataSet)
S = errType(dataSet)
bestS = inf; bestIndx = 0; bestValue = 0
# 在所有可能的特征和可能取值上遍历
#最佳的切分就是切分后能达到最低误差的切分
for featIndex in range(n-1):
#书中源代码有错误。
for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):
mat0,mat1 = bindSplitDataSet(dataSet,featIndex,splitVal)
if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndx = featIndex
bestValue = splitVal
bestS = newS
#如果误差减少不大则退出
if(S - bestS) < tolS:
return None,leafType(dataSet)
mat0,mat1 = bindSplitDataSet(dataSet,bestIndx,bestValue)
#如果切分的数据小则退出
if(shape(mat0)[0] < tolN) or (shape(mat1)[0]<tolN):
return None,leafType(dataSet)
return bestIndx,bestValue
if __name__ == '__main__':
myDat = loadDataSet('eex00.txt')
myMat = mat(myDat)
retTree = createTree(myMat)
print(retTree)
>>> {'left': 1.0180967672413792, 'spVal': 0.48813, 'spInd': 0, 'right': -0.044650285714285719}
if __name__ == '__main__':
myDat = loadDataSet('regTreesex0.txt')
myMat = mat(myDat)
retTree = createTree(myMat)
print(retTree)
>>> {'right': {'right': -0.023838155555555553, 'left': 1.0289583666666666, 'spInd': 1, 'spVal': 0.197834},
'left': {'right': 1.980035071428571, 'left': {'right': 2.9836209534883724, 'left': 3.9871631999999999, 'spInd': 1, 'spVal': 0.797583}, 'spInd': 1, 'spVal': 0.582002}, 'spInd': 1, 'spVal': 0.39435}