香农熵-计算最好的数据集划分方式

对于以下数据

data = [
# 不浮出水面是否可以生存, 有脚蹼, 属于鱼类
        1, 1, 'yes',  # 究竟什么鱼满足这个条件我很好奇
        1, 1, 'yes',  # 
        1, 0, 'no',  # 海胆
        0, 1, 'no',  # 鸭子
        0, 1, 'no' ] # 企鹅

计算出 “不浮出水面是否可以生存”与“有脚蹼”这两个特征值,哪个与是否属于鱼类更相关。
以下算法的原理是:拿出指定的特征值,计算剩下的数据的熵,熵越大,也就是数据越混乱,说明被拿出的数据越重要
代码如下

# -*- encoding:utf-8 -*-
import math
import numpy as np
__author__ = 'Butters'


def get_gain(p):
    """
    信息增益值
    """
    return -math.log(p, 2)


def get_ent(*p):
    """
    熵
    """
    return sum([i * get_gain(i) for i in p])


def test():
    data = [
        # 不浮出水面是否可以生存, 有脚蹼, 属于鱼类
        1, 1, 'yes',  # 究竟什么鱼满足这个条件我很好奇
        1, 1, 'yes',  # 
        1, 0, 'no',  # 海胆
        0, 1, 'no',  # 鸭子
        0, 1, 'no' ] # 企鹅
    dataset = np.reshape(data, (5, 3))
    chooseBestFeatureToSplit(dataset)


def calcShannonEnt(dataset):
    print '=====>data set is'
    print dataset
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        feat = featVec[-1]
        if feat not in labelCounts:
            labelCounts[feat] = 0
        labelCounts[feat] += 1
    shannonEnt = get_ent(*[float(labelCounts[key]) / float(numEntries) for key in labelCounts])
    print 'shannon ent is ', shannonEnt
    return shannonEnt


def splitDataset(dataset, axis, value):
    """
    获取dataset里的第axis轴值等于 value的
    :param dataset:
    :param axis:第axis列特征值
    :param value:
    :return:
    """
    m = np.array([row for row in dataset if row[axis] == value])
    return np.delete(m, axis, axis=1)


def chooseBestFeatureToSplit(dataset):
    numFeatures = len(dataset[0]) - 1  # 2
    baseEntries = calcShannonEnt(dataset)
    print 'base entries is', baseEntries
    bestInfoGain = 0.0
    bestFeature = -1
    for i in xrange(numFeatures):
        featList = [example[i] for example in dataset]
        featSet = set(featList)  # 获取取值范围
        print 'feat set is', featSet
        tempEntry = 0.0
        for value in featSet:
            subDataset = splitDataset(dataset, i, value)
            prob = float(len(subDataset)) / float(len(dataset))
            tempEntry += (prob * calcShannonEnt(subDataset))
        infoGain = baseEntries - tempEntry
        print '-----------after calculate ----------'
        print tempEntry
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    print 'best feature is', bestFeature
    print 'best info gain is', bestInfoGain


if __name__ == '__main__':
    test()

运行结果:

best feature is 0
best info gain is 0.419973094022

所以第0个特征值是我们要的。
我们可以通过简单的逻辑来验证一下,5个例子中:

  • 满足 “不浮出水面也能生存”的是鱼, “不浮出水面不能生存”的不是鱼条件的有4个。
  • 满足 "有脚蹼"的是鱼, “没脚蹼“的不是鱼条件的有3个。
    所以第0个特征值确实是要重要一些.

以上部分代码来源于《机器学习实战》,略有简化与修改

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 越来越多的人感慨花钱如流水,钱越来越不经用。可比花钱更恐怖的事是时间在悄无声息中已经过了100天了,感觉过年的画面...
    芬芬vstar阅读 1,781评论 2 4
  • 其实,我不坚强 其实,我不懂事 其实,我不成熟 其实,我不独立 其实,我不优秀 其实,我害怕 其实,我怯懦 其实,...
    上YE阅读 1,417评论 0 1
  • 文/六天笑笑 亲爱的写作老友: 你好!见字如面! 请允许我选择这种方式来表达我对你的感激之情,谢谢你对我的不离不弃...
    六天笑笑阅读 3,021评论 4 10
  • 小时侯的我, 总爱玩沙子, 抓一把在手中, 然后看着沙子, 慢慢从缝隙里逃走... 长大了的我, 虽然不再玩, 但...
    shevian阅读 1,144评论 0 0