使用Go 机器学习库来进行数据分析 2 (决策树)

鸟窝 2017-12-07 19:07

这篇文章, 继续使用golearn库分析鸢尾花的数据集。 这一次,我们会使用决策树和随机森林来分析。

决策树和随机森林

决策树是机器学习中最接近人类思考问题的过程的一种算法,通过若干个节点,对特征进行提问并分类(可以是二分类也可以使多分类),直至最后生成叶节点(也就是只剩下一种属性)。

每个决策树都表述了一种树型结构,它由它的分支来对该类型的对象依靠属性进行分类。每个决策树可以依靠对源数据库的分割进行数据测试。这个过程可以递归式的对树进行修剪。 当不能再进行分割或一个单独的类可以被应用于某一分支时,递归过程就完成了。另外,随机森林分类器将许多决策树结合起来以提升分类的正确率。

golearn支持两种决策树算法。ID3和RandomTree。

  • ID3 : 以信息增益为准则选择信息增益最大的属性。

    ID3 is a decision tree induction algorithm which splits on the Attribute which gives the greatest Information Gain (entropy gradient). It performs well on categorical data. Numeric datasets will need to be discretised before using ID3

  • RandomTree : 与ID3类似,但是选择的属性的时候随机选择。

    Random Trees are structurally identical to those generated by ID3, but the split Attribute is chosen randomly. Golearn's implementation allows you to choose up to k nodes for consideration at each split.

可以参考 ChongmingLiu的介绍:决策树(ID3 & C4.5 & CART)

维基百科中对随机森林的介绍:

在机器学习中,随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。 Leo Breiman和Adele Cutler发展出推论出随机森林的算法。而"Random Forests"是他们的商标。这个术语是1995年由贝尔实验室的Tin Kam Ho所提出的随机决策森林(random decision forests)而来的。这个方法则是结合Breimans的"Bootstrap aggregating"想法和Ho的"random subspace method" 以建造决策树的集合。

在机器学习中,随机森林由许多的决策树组成,因为这些决策树的形成采用了随机的方法,因此也叫做随机决策树。随机森林中的树之间是没有关联的。当测试数据进入随机森林时,其实就是让每一颗决策树进行分类,最后取所有决策树中分类结果最多的那类为最终的结果。因此随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。

代码

下面是使用决策树和随机森林预测鸢尾花分类的代码,来自golearn:

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
// Demonstrates decision tree classificationpackage mainimport (	"fmt"	"github.com/sjwhitworth/golearn/base"	"github.com/sjwhitworth/golearn/ensemble"	"github.com/sjwhitworth/golearn/evaluation"	"github.com/sjwhitworth/golearn/filters"	"github.com/sjwhitworth/golearn/trees"	"math/rand")func main() {	var tree base.Classifier	rand.Seed(44111342)	// Load in the iris dataset	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)	if err != nil {		panic(err)	}	// Discretise the iris dataset with Chi-Merge	filt := filters.NewChiMergeFilter(iris, 0.999)	for _, a := range base.NonClassFloatAttributes(iris) {		filt.AddAttribute(a)	}	filt.Train()	irisf := base.NewLazilyFilteredInstances(iris, filt)	// Create a 60-40 training-test split	trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)	//	// First up, use ID3	//	tree = trees.NewID3DecisionTree(0.6)	// (Parameter controls train-prune split.)	// Train the ID3 tree	err = tree.Fit(trainData)	if err != nil {		panic(err)	}	// Generate predictions	predictions, err := tree.Predict(testData)	if err != nil {		panic(err)	}	// Evaluate	fmt.Println("ID3 Performance (information gain)")	cf, err := evaluation.GetConfusionMatrix(testData, predictions)	if err != nil {		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))	}	fmt.Println(evaluation.GetSummary(cf))	tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))	// (Parameter controls train-prune split.)	// Train the ID3 tree	err = tree.Fit(trainData)	if err != nil {		panic(err)	}	// Generate predictions	predictions, err = tree.Predict(testData)	if err != nil {		panic(err)	}	// Evaluate	fmt.Println("ID3 Performance (information gain ratio)")	cf, err = evaluation.GetConfusionMatrix(testData, predictions)	if err != nil {		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))	}	fmt.Println(evaluation.GetSummary(cf))	tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))	// (Parameter controls train-prune split.)	// Train the ID3 tree	err = tree.Fit(trainData)	if err != nil {		panic(err)	}	// Generate predictions	predictions, err = tree.Predict(testData)	if err != nil {		panic(err)	}	// Evaluate	fmt.Println("ID3 Performance (gini index generator)")	cf, err = evaluation.GetConfusionMatrix(testData, predictions)	if err != nil {		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))	}	fmt.Println(evaluation.GetSummary(cf))	//	// Next up, Random Trees	//	// Consider two randomly-chosen attributes	tree = trees.NewRandomTree(2)	err = tree.Fit(trainData)	if err != nil {		panic(err)	}	predictions, err = tree.Predict(testData)	if err != nil {		panic(err)	}	fmt.Println("RandomTree Performance")	cf, err = evaluation.GetConfusionMatrix(testData, predictions)	if err != nil {		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))	}	fmt.Println(evaluation.GetSummary(cf))	//	// Finally, Random Forests	//	tree = ensemble.NewRandomForest(70, 3)	err = tree.Fit(trainData)	if err != nil {		panic(err)	}	predictions, err = tree.Predict(testData)	if err != nil {		panic(err)	}	fmt.Println("RandomForest Performance")	cf, err = evaluation.GetConfusionMatrix(testData, predictions)	if err != nil {		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))	}	fmt.Println(evaluation.GetSummary(cf))}

首选使用ChiMerge方法进行数据离散,ChiMerge 是监督的、自底向上的(即基于合并的)数据离散化方法。它依赖于卡方分析:具有最小卡方值的相邻区间合并在一起,直到满足确定的停止准则。
基本思想:对于精确的离散化,相对类频率在一个区间内应当完全一致。因此,如果两个相邻的区间具有非常类似的类分布,则这两个区间可以合并;否则,它们应当保持分开。而低卡方值表明它们具有相似的类分布。可以参考"Principles of Data Mining"第二版中的第105页--第115页。

接着调用base.NewLazilyFilteredInstances应用filter得到FixedDataGrid。

之后将数据集分成训练数据和测试数据两部分。

接下来就是训练数据、预测与评估了。

分别使用ID3、ID3 with InformationGainRatioRuleGenerator、ID3 with GiniCoefficientRuleGenerator、RandomTree、RandomForest算法进行处理。

评估结果

以下是各种算法的评估结果,可以和 kNN进行比较,看起来比不过kNN的预测。

123456789101112131415161718192021222324252627282930313233343536373839
ID3 Performance (information gain)Reference Class	True Positives	False Positives	True Negatives	Precision	Recall	F1 Score---------------	--------------	---------------	--------------	---------	------	--------Iris-virginica	32		5		46		0.8649		0.9697	0.9143Iris-versicolor	4		1		61		0.8000		0.1818	0.2963Iris-setosa	29		13		42		0.6905		1.0000	0.8169Overall accuracy: 0.7738ID3 Performance (information gain ratio)Reference Class	True Positives	False Positives	True Negatives	Precision	Recall	F1 Score---------------	--------------	---------------	--------------	---------	------	--------Iris-virginica	29		3		48		0.9062		0.8788	0.8923Iris-versicolor	5		3		59		0.6250		0.2273	0.3333Iris-setosa	29		15		40		0.6591		1.0000	0.7945Overall accuracy: 0.7500ID3 Performance (gini index generator)Reference Class	True Positives	False Positives	True Negatives	Precision	Recall	F1 Score---------------	--------------	---------------	--------------	---------	------	--------Iris-virginica	26		5		46		0.8387		0.7879	0.8125Iris-versicolor	17		36		26		0.3208		0.7727	0.4533Iris-setosa	0		0		55		NaN		0.0000	NaNOverall accuracy: 0.5119RandomTree PerformanceReference Class	True Positives	False Positives	True Negatives	Precision	Recall	F1 Score---------------	--------------	---------------	--------------	---------	------	--------Iris-virginica	30		3		48		0.9091		0.9091	0.9091Iris-versicolor	9		3		59		0.7500		0.4091	0.5294Iris-setosa	29		10		45		0.7436		1.0000	0.8529Overall accuracy: 0.8095RandomForest PerformanceReference Class	True Positives	False Positives	True Negatives	Precision	Recall	F1 Score---------------	--------------	---------------	--------------	---------	------	--------Iris-virginica	31		8		43		0.7949		0.9394	0.8611Iris-versicolor	0		0		62		NaN		0.0000	NaNIris-setosa	29		16		39		0.6444		1.0000	0.7838Overall accuracy: 0.7143

[返回] [原文链接]