机器学习入门第二课:决策树的可视化

原视频地址
分类器有很多种,比如神经网络、或支持向量机,决策树则是其中之一。决策树一大特点便是简单易读,便于理解。实际上,决策树是为数不多的可解释的分类器。你可以彻底理解为什么这个分类器做出了这样的选择。

鸢尾花数据

这节课使用一组真实的数据集-鸢尾花数据。“鸢尾花”是一个典型的机器学习问题。在这个问题中,我们使用不同的测量标准,例如花瓣的宽度、长度来辨别是哪种鸢尾花。在数据集里有三种不同的鸢尾花,山鸢尾、变色鸢尾、青龙鸢尾。

数据集中每种花有50个样本,所以共有150个样本。每个样本有四个特征值来描述,分别是萼片和花瓣的长度和宽度。每行数据的前面四列是特征值,最后一列是本行鸢尾花数据的种类,也就是标签。

鸢尾花数据集

本节课的目标就是通过决策器训练这些数据,然后可视化分类器的决策过程。

导入数据

scikit-learn 提供了一系列的样本数据集,我们可以很方便的导入到项目中。

from sklearn.datasets import load_iris
iris = load_iris()

拆分数据

我们需要 从样本数据中抽取部分数据作为验证的测试数据,剩余数据作为训练数据。

test_index = [0, 50, 100]

# traing data
train_target = np.delete(iris.target, test_index)
train_data = np.delete(iris.data, test_index, axis=0)

这里我们用到了 Numpy。Numpy - 是Python语言的一个扩充程序库,支持高级大量的维度数组矩阵运算,此外也针对数组运算提供大量的数学函数
我们先导入 Numpy,这个库包含在 anaconda3-4.4.0 中。

import numpy as np

训练数据

这部分和第一课的代码是一致的。

# testing data
test_target = iris.target[test_index]
test_data = iris.data[test_index]

clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)

print(test_target)

print(clf.predict(test_data))

可视化决策树

本课最关键的部分就是如何将决策树做出判断的过程可视化,需要用到 graphvizpydotplus
可视化代码如下:

# viz code

dot_data = StringIO()
tree.export_graphviz(clf,
                        out_file=dot_data,
                        feature_names=iris.feature_names,
                        class_names=iris.target_names,
                        filled=True, rounded=True,
                        impurity=False)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf('iris.pdf')

我们最终将生成 iris.pdf 的决策图,如下:

决策树

B 站视频网址

我顺便把 youtube 的视频嵌上字幕后上传到了 B 站,网址在 这里

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容