这篇我们分析一下deformable DETR的核心部分 multi-scale deformable attention。
首先看一下其数学形式:

Eq.2
其中M表示attention里面head的个数,L表示多个level的个数,K表示每个level上每个query采样的点。
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
# 每个head的维度,对总的维度进行了均分
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head): # 为了更好的cuda操作
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64 # cuda加速
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
# 每个head为每个level产生n_point个点的偏置, 对应公式里的Delta
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
# 每个位置点的权重,由网络直接生成, 对应公式里的A_{mlqk}
self.value_proj = nn.Linear(d_model, d_model)
# 数据进行变换, 对应W_m'
self.output_proj = nn.Linear(d_model, d_model)
# 总体和进行再变换, 对应W_m
self._reset_parameters()
def _reset_parameters(self):
# 这里初始化不同的权重,采样不同的偏置点时有些特殊,不同的level不同的point初始偏置bias不同
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
# 相当于每个level每个point偏置对应的head进行编码
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
# 对不同的偏置进行编码, 不同点的编码不同但不同level是相同的
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
# 每个query在不同的level的参考位置,即公式2的q
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
# 把不同的level特征flatten一起,所有key的个数,即所有level的像素点个数之和
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
# 每个level的尺寸
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
# 每个level的开始索引, 相当于不同的level进行序列排序后的索引
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
# bool,每个位置是否mask
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape # batch size, query的个数
N, Len_in, _ = input_flatten.shape # Len_in是所有key的个数
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
# 每个query产生对应不同head不同level的偏置
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
# 每个偏置向量的权重
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# 对属于同一个query的来自与不同level的offset后向量权重在每个head分别归一化
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
# sampling_offsets取值非0,1之间,因此这里相当于归一化后,计算$x_q+\Delta$
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
# 偏置是相对于目标框的归一化
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
return output
所以MSDeformAttn模块定义时需指定:
- d_model 隐变量的维度,也是变换后特征的维度
- n_level 多尺度融合的level数,取决于backbone的输入,比如resnet输出4个stage的输出
- n_heads 多头attention中的head的个数
- n_points 每个query在每个level中每个head中采样的点的个数,也就是说每个query其实采样的点数为 n_leveln_headsn_points, 默认128个点。
MSDeformAttn forward的输入:
- query query向量 batch_size x query个数 x 表征维度
- reference_points batch_size x query个数 x level个数 x 2 表示每个query在每个level中的参考位置,也就是公式中的
, 归一化的话其实每个level上的reference_points相同
- input_flatten batch_size x key的个数 x 特征维度, key包括所有level中的像素位置对应的特征向量。对应公式中的x
- input_level_shapes level个数 x 2, 表示每个level的feature map的尺寸,是 H x W
- input_level_start_index 每个level的key在总体的input_flatten中的初始位置, 这个量只在采样时使用
- input_padding_mask bool, 表示每个input_flatten位置的掩码
MSDeformAttn forward的输出:
- output attention融合之后的每个query的特征向量,长度和输入相同
MSDeformAttnFunction函数是调用的cuda编写的采样且求和的过程。我们可以通过其pytorch版本看起实现过程:
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
# for debug and test only,
# need to use cuda version instead
N_, S_, M_, D_ = value.shape # batchsize, key个数, head个数, 维度
_, Lq_, M_, L_, P_, _ = sampling_locations.shape #Lq_: query个数, L_:level数, P_: 采样点个数
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) # 区分每个level的key
sampling_grids = 2 * sampling_locations - 1 # 因为需要使用grid_sample因此需要将采样点映射到-1,1之间
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
return output.transpose(1, 2).contiguous()
可以发现其本质是利用F.grid_sample函数进行采样,该函数使用时需要将采样点归一化到之间。
输入value对应着keys, value_spatial_shapes用于对value进行拆分,拆分成不同的level,在不同的level中进行采样,每个level采样n_head*n_point个向量。这里想到之前《纯pytorch版本的deformable cnn的实现》进行采样的过程,其实循环部分也可以借鉴这里,直接将采样点并在一起进行采样。attention_weights分别加权每一个D维的向量,总共相当于每个query的L_*P_个特征进行加权求和。最终的输出是N\times Lq_\times d_model, 其中d_model是总的特征的维度,Lq_是query的个数。
之后在使用self-attention时,像在检测或者分割任务中,牵涉到局部相关性时可以直接使用该模块进行取代。
