Code Reading: Deformable DETR (Part 2)

In this post we analyze the core of Deformable DETR: multi-scale deformable attention.
Let's start with its mathematical form:


$$\mathrm{MSDeformAttn}\big(z_q, \hat{p}_q, \{x^l\}_{l=1}^{L}\big) = \sum_{m=1}^{M} W_m \Big[ \sum_{l=1}^{L} \sum_{k=1}^{K} A_{mlqk} \cdot W'_m \, x^l\big(\phi_l(\hat{p}_q) + \Delta p_{mlqk}\big) \Big] \tag{2}$$

Here M is the number of attention heads, L is the number of feature levels, and K is the number of points sampled per query at each level.
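In code, the sampling positions are computed in normalized coordinates. Restating what the forward pass below implements (this derivation is mine, matched to the code rather than quoted from the paper): for 2-D reference points,

$$p_{mlqk} = \hat{p}_q + \Delta p_{mlqk} / (W_l, H_l) \quad \text{(element-wise)}$$

and for 4-D reference boxes $(\hat{p}_q, w_q, h_q)$, the offsets are interpreted relative to the box size (with $K$ = n_points):

$$p_{mlqk} = \hat{p}_q + \frac{\Delta p_{mlqk}}{2K} \cdot (w_q, h_q)$$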

class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension  
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # dimension per head: the total dimension is split evenly across the heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):  # for better CUDA performance
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        self.im2col_step = 64  # chunk size used by the CUDA kernel when iterating over the batch

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        # each head predicts n_points 2-D offsets for every level: the \Delta p_{mlqk} in the formula
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        # a weight for every sampled point, predicted directly by the network: the A_{mlqk} in the formula
        self.value_proj = nn.Linear(d_model, d_model)
        # projects the input features: the W'_m in the formula
        self.output_proj = nn.Linear(d_model, d_model)
        # projects the aggregated result: the W_m in the formula

        self._reset_parameters()

    def _reset_parameters(self):
        # custom initialization; the sampling-offset bias is the special part: each head gets
        # a different initial direction and each point a different magnitude
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        # gives each head its own initial direction, replicated across all levels and points
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
            # scale the i-th point by (i + 1): points get distinct magnitudes, while the same
            # bias is shared across levels
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        """
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        # reference position of each query on each level, i.e. \hat{p}_q in Eq. 2
        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        # features of all levels flattened and concatenated; the number of keys is the total pixel count over all levels
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        # spatial size of each level's feature map
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        # start index of each level after flattening the levels into one sequence
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
        # boolean mask marking which positions are padding

        :return output                     (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape       # batch size, number of queries
        N, Len_in, _ = input_flatten.shape      # Len_in is the total number of keys
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        # each query predicts offsets for every head, level and point
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        # a weight for each sampled vector
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # per head, jointly normalize the weights of a query's n_levels * n_points samples
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]  
        # sampling_offsets is not restricted to [0, 1], so it is normalized by (W_l, H_l)
        # here, computing \hat{p}_q + \Delta p as in Eq. 2
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5  
        # here the offsets are normalized relative to the reference box: divided by n_points
        # and scaled by half the box width/height
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        output = MSDeformAttnFunction.apply(
            value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        output = self.output_proj(output)
        return output
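The offset-bias initialization in _reset_parameters is worth a closer look. Since the weights of sampling_offsets start at zero, the initial predicted offsets equal the bias: each head points in a fixed direction (spread evenly around a circle, then stretched so the larger coordinate is ±1), and the i-th point steps i + 1 units out along that direction. A standalone sketch of just that computation (variable names are mine):

import math
import torch

n_heads, n_levels, n_points = 8, 4, 4
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)         # one unit direction per head
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]  # stretch onto the unit square
grid_init = grid_init.view(n_heads, 1, 1, 2).repeat(1, n_levels, n_points, 1)
for i in range(n_points):
    grid_init[:, :, i, :] *= i + 1  # the i-th point starts (i + 1) steps out
print(grid_init[0, 0])  # head 0: [[1., 0.], [2., 0.], [3., 0.], [4., 0.]]

Because the offsets are later divided by the level size (or scaled by the box size), these integer steps turn into small, evenly spread initial sampling patterns around each reference point.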

So defining an MSDeformAttn module requires specifying:

  1. d_model: the hidden dimension, which is also the dimension of the projected features
  2. n_levels: the number of feature levels to fuse, determined by the backbone, e.g. a ResNet exposing the outputs of 4 stages
  3. n_heads: the number of heads in the multi-head attention
  4. n_points: the number of points sampled per query per level per head; in total each query samples n_levels * n_heads * n_points points, 128 by default (a quick check follows below).
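A quick sanity check of those sizes under the default arguments (my snippet, just restating the arithmetic):

n_heads, n_levels, n_points = 8, 4, 4
print(n_heads * n_levels * n_points)      # 128 sampled values per query
print(n_heads * n_levels * n_points * 2)  # 256 = out_features of sampling_offsets (an x and a y per point)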

Inputs of MSDeformAttn.forward:

  1. query: the query vectors, batch_size x num_queries x feature dim
  2. reference_points: batch_size x num_queries x n_levels x 2, the reference position of each query on each level, i.e. \phi_l(\hat{p}_q) in the formula; since the coordinates are normalized, the reference points are in fact identical across levels
  3. input_flatten: batch_size x num_keys x feature dim; the keys are the feature vectors at every pixel position of every level, i.e. the x in the formula
  4. input_spatial_shapes: n_levels x 2, the H x W size of each level's feature map
  5. input_level_start_index: the start position of each level's keys within input_flatten; this quantity is only used during sampling
  6. input_padding_mask: boolean, the mask for each position of input_flatten

Output of MSDeformAttn.forward:

  1. output: the feature vector of each query after attention aggregation, same length as the input (a usage sketch follows below)
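To make the shapes concrete, here is a minimal smoke test (my own sketch; it assumes MSDeformAttn has been imported from the Deformable DETR codebase and that its CUDA ops extension is compiled):

import torch

N, L, C = 2, 4, 256
spatial_shapes = torch.as_tensor([[64, 64], [32, 32], [16, 16], [8, 8]], dtype=torch.long)
level_start_index = torch.cat((spatial_shapes.new_zeros(1),
                               (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1]))
Len_in = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # 5440 keys in total
Len_q = 300

query = torch.rand(N, Len_q, C)
reference_points = torch.rand(N, Len_q, 1, 2).repeat(1, 1, L, 1)   # normalized, identical across levels
input_flatten = torch.rand(N, Len_in, C)

attn = MSDeformAttn(d_model=C, n_levels=L, n_heads=8, n_points=4)
output = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
assert output.shape == (N, Len_q, C)

With the CUDA kernel, the module and all tensors must live on the GPU; the PyTorch reference ms_deform_attn_core_pytorch below can be used for a CPU check instead.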

MSDeformAttnFunction calls a CUDA kernel that performs the sampling and weighted summation. We can understand what it does through its PyTorch reference implementation:

def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape        # batch size, number of keys, number of heads, dim per head
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape    # Lq_: number of queries, L_: number of levels, P_: points per level
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)  # split the keys by level
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1], so map from [0, 1]
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()

Evidently the essence is sampling via F.grid_sample, which requires the sampling points to be normalized to [-1, 1].
The input value corresponds to the keys; value_spatial_shapes is used to split value into the individual levels so that sampling happens within each level, with n_heads * n_points vectors sampled per level for each query. This recalls the sampling procedure in the earlier post 《纯pytorch版本的deformable cnn的实现》 (a pure-PyTorch deformable CNN): its sampling loop could likewise borrow this trick of batching all sampling points into a single grid_sample call. attention_weights then weights each D-dimensional sampled vector, so per head each query is a weighted sum of its L_ * P_ sampled features. The final output has shape N x Lq_ x d_model, where d_model is the full feature dimension and Lq_ is the number of queries.
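The coordinate convention is the one subtle point. A tiny self-contained check (my own example) of how a location normalized to [0, 1] is remapped and bilinearly sampled:

import torch
import torch.nn.functional as F

feat = torch.arange(16.).view(1, 1, 4, 4)  # a 1x1x4x4 feature map holding 0..15
loc = torch.tensor([[[[0.5, 0.5]]]])       # center of the map in [0, 1], (x, y) order
grid = 2 * loc - 1                         # map into grid_sample's [-1, 1] range
out = F.grid_sample(feat, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
print(out)  # tensor([[[[7.5]]]]): bilinear average of the four central pixels, (5+6+9+10)/4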

Afterwards, wherever self-attention is used, e.g. in detection or segmentation tasks where local correlations are involved, this module can be dropped in as a direct replacement.
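For example, a much-simplified encoder layer in the spirit of Deformable DETR's encoder could look like the sketch below (mine; the real implementation also adds dropout and positional embeddings):

import torch.nn as nn

class DeformableEncoderLayer(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, d_ffn=1024):
        super().__init__()
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ffn), nn.ReLU(inplace=True),
                                 nn.Linear(d_ffn, d_model))
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        # self-attention: queries and values are the same flattened multi-scale feature map
        src = self.norm1(src + self.self_attn(src, reference_points, src,
                                              spatial_shapes, level_start_index, padding_mask))
        src = self.norm2(src + self.ffn(src))
        return src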
