attention机制的热度图

1. 举个🌰

def plot_attention(data, X_label=None, Y_label=None):
  '''
    Plot the attention model heatmap
    Args:
      data: attn_matrix with shape [ty, tx], cutted before 'PAD'
      X_label: list of size tx, encoder tags
      Y_label: list of size ty, decoder tags
  '''
  fig, ax = plt.subplots(figsize=(20, 8)) # set figure size
  heatmap = ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9)
  
  # Set axis labels
  if X_label != None and Y_label != None:
    X_label = [x_label.decode('utf-8') for x_label in X_label]
    Y_label = [y_label.decode('utf-8') for y_label in Y_label]
    
    xticks = range(0,len(X_label))
    ax.set_xticks(xticks, minor=False) # major ticks
    ax.set_xticklabels(X_label, minor = False, rotation=45)   # labels should be 'unicode'
    
    yticks = range(0,len(Y_label))
    ax.set_yticks(yticks, minor=False)
    ax.set_yticklabels(Y_label, minor = False)   # labels should be 'unicode'
    
    ax.grid(True)

2. 参数

X_label: 是encoder的句子一个一个word组成的list;
Y_label: 是decoder的句子一个一个word组成的list

3. 结果展示

热度图.png
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • Spring Cloud为开发人员提供了快速构建分布式系统中一些常见模式的工具(例如配置管理,服务发现,断路器,智...
    卡卡罗2017阅读 135,080评论 19 139
  • 背景 一年多以前我在知乎上答了有关LeetCode的问题, 分享了一些自己做题目的经验。 张土汪:刷leetcod...
    土汪阅读 12,788评论 0 33
  • 要是关注深度学习在自然语言处理方面的研究进展,我相信你一定听说过Attention Model(后文有时会简称AM...
    MiracleJQ阅读 2,792评论 1 6
  • 原文地址 要是关注深度学习在自然语言处理方面的研究进展,我相信你一定听说过Attention Model(后文有时...
    Henrywood阅读 1,737评论 0 5
  • 对比两篇论文 : 其中一篇是A Neural Attention Model for Abstractive Se...
    MiracleJQ阅读 3,570评论 0 1