阅读笔记-CoaT: Co-Scale Conv-Attentional Image Transformers

来源: arXiv:2104.06399v1
代码:https://github.com/mlpc-ucsd/CoaT

title


Introduction

  1. Transformer 和CNN的区别。
    CNN通过多层的卷积网络逐步扩大感受野,实现content和context的信息融合。而transformer中每一层的每一个位置所能感受的区域都是整个featmap,且计算attention时使用的时两个位置特征的内积本身就刻画了特征向量的二阶相关性。另外,对于cnn而言,在推理时其模型参数是固定的,但是对于transformer而言,推理时使用的attention是和不同的content相关的,显然适应性更强。但transformer相对于CNN的区别在于其模型更加复杂,计算量更大。
  2. 本文贡献点:
  • conv-attention,其实主要是指计算相对位置编码时采用的类卷积方式,另外为了
    进一步降低计算量,还简化了attention的方式,即factorized attention。 两个模
    块是相辅相成的。
  • 提出了一种co-scale机制,其实类似于U-net的作用,将transformer的不同stage
    的输出特征进行跨尺度的融合。

方法

原始transformer中attention的计算方式:
Attn(X) = softmax(\frac{QK^T}{\sqrt{C}})V

  1. Factorized Attention Mechanism
    在原始的计算attention的过程中,空间复杂度是O(N^2), 时间复杂度是O(N^2C), 为了降低复杂度,类似于LambdaNet中的做法,将attention的方法改为如下形式:
    FactorAttn(X) = \frac{Q}{\sqrt{C}}(softmax(K)^TV)
    我们来看一下这个改变,首先空间复杂度变为O(NC), 时间复杂度变为O(NC^2), 因为N>>C所以复杂度都降为原来的C/N倍。另一方面在计算原始的attention时可以明确解释attention是当前位置与其他位置的相似度,但在factor attn的计算过程中并不是很好解释,而且丢失了内积过程。虽然FactorAttn不是对attn的直接近似,但是也是一种泛化的注意力机制有query,key和value
  2. Convolutional Relative Position Encoding
    本章中认为对于相同的q,其输出的注意力是相同的,但原始的transformer中相同的q获得的attention也是相同的。所以position encoding的使用并不是FactorAttn导致的,而是transformer本质上带来的序列无序问题。因此transformer都是需要加入pos encoding的。
    在原始的transformer中,pos embeding的计算方式是在输入encoder前就把pos embedding 加到了x上,除此之外pos的处理方式还包括如下形式:
  • DETR的形式, 加在每个attention中
    softmax(\frac{(Q+P)(K+P)^T}{\sqrt{C}})V
  • Swin transformer, 此时每个位置只有一个bias,位置编码是position bias
    softmax(\frac{QK^T}{\sqrt{C}}+E)V
    本文方法的pos emb加入方式和swin transformer很类似
    \frac{Q}{\sqrt{C}} softmax(K)^TV + EV
    不同的是这里的E是局部的,而swin transformer中在每个window内其实是全局的。
    E_{ij} = 1(i,j)q_i\cdot p_j, 1\e i, j\e N
    1(i,j)表示以(i,j)为中心的window 指示函数。可以进一步表示
    EV_{ij} = Q_{ij}\cdot \sum_{|k-i|<M/2, |l-j|<M/2} P_{ij}\cdot V_{ij}
    这个操作无法直接用卷积完成,EV的计算时空复杂度分别为O(N^2C), O(N^2), 为了降低复杂度,作者将EV进一步简化成每个通道作为一个head进行操作,于是
    \hat{EV}_{ij}^l = \sum_{|k-i|<M/2, |l-j|<M/2} Q_{ij}^lP_{ij}^lV_{ij}^l
    可以写成2d depth-wise的卷积形式
    \hat{EV}_{ij} = Q_{ij}\odot \text{ DepthwiseConv2d}(P, V)
    此时时空复杂度分辨降为O(NCM^2), O(NC), M是window的size,上述步骤可以通过下图直观理解
    conovolutional relative position encoding

考虑到ViT处理分类任务时,还有一个class token,该token无法嵌入到2D featuremap中,所以class token此时的relative positive encoding为0。

  1. 卷积位置编码
    在每个attention模块之前,作者对输入也引入了类似的卷积位置编码,和原始transformer中直接加上一组pos emb不同,这里pos emb进行的时相对的位置编码,通过depthwise 卷积获得。即
    x = x + \text{DepthwiseConv2D}(P, x)
    可以发现与卷积相对位置编码还是不同的,没有hadmard积的过程,于是整个Conv-Attentional Mechanism如下图所示,这个图其实绘制的存在问题,otimes的意义不同,左侧俩时矩阵乘法,右侧的时哈达马积。
    conv-attentional module

Co-Scale Conv-Attention Transformers

作者提出的CoaT Serial Block和金字塔时的transformer的Block结构类似,只是内部attention的计算方式不同,真正实现cross-scale attention的方式时在每个stage的输出使用Parallel Block进行融合,方式有一下两种:


CoaT parallel block
  1. direct cross-layer attention 是将不同scale的特征分别作为query, key和value进行融合,这时候通过上下采样同意不同scale。
  2. Attention with feature interpolation则是把不同scale经过自己的attention module之后的特征进行相加然后经过FFN融合送入下一个block。

实验

实验对比了两个模型CoaT和CoaT-Lite,分别对应着是否使用ParallelBlock。

  1. 消融实验


    image.png

    image.png

首先分析了relative position encoding 和conv pos encoding的作用,可以发现使用pos encoding的作用都会提升,但是conv pos encoding的作用明显优于relative pos encoding,这也说明局部信息融合后送入attention会一定程度提升性能。存在一个疑问:没有对比使用其他pos encoding的性能,比如直接使用EV会怎么样?

.
image.png

证实了使用co-scale融合的性能,且特征融合的方式优于直接cross attention的方式。

  1. 对比试验
  • 分类任务上目前超过了所有相似大小的基于transformer模型性能,包括swin transformer
  • 检测任务上没有和swin transformer直接对比

结论

本文的主要目标其实还是降低模型计算和空间复杂度,主要是factorized attention和conv pos encoding值得借鉴。细节还等着代码放出。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容