【NIPS2023】Rank-DETR for High Quality Object Detection

【NIPS2023】Rank-DETR for High Quality Object Detection

机构:清华大学、北京大学、剑桥大学、微软亚洲研究院

论文地址:https://arxiv.org/abs/2310.08854

代码地址:https://github.com/LeapLabTHU/Rank-DETR

作者简介:黄高,清华大学博士学位,康奈尔大学计算机系博士后,清华大学自动化系助理教授、博士生导师,获阿里巴巴“达摩院青橙奖”、2019年吴文俊人工智能优秀青年奖等。代表作DenseNet获得CVPR2017年最佳论文、Stochastic Depth。研究方向包括动态神经网络、高效深度学习。

本文考虑到DETR模型中query的重要性存在差异,致力于改进高IoU情况下(例如AP@75)的检测性能,首次提出基于排序思想的Rank-DETR,在Transformer中引入排序相关的网络层、排序导向的损失函数和匈牙利匹配损失。在COCO数据集上的性能高于DINO、Align-DETR、GroupDETR等baseline,与Stable-DINO、MS-DETR、Salience-DETR相当,弱于DDQ-DETR、Co-DETR、Relation-DETR等SOTA方法。

文章贡献/创新点

  • 在Transformer Decoder中提出了基于rank机制改进的分类头和query排序层。
  • 在损失函数(网络损失和匈牙利匹配损失)中对分类和回归分支进行对齐,使得高置信度的query也具有高IoU。
  • 实验验证了所提方法的有效性,并将rank机制引入到已有DETR方法中验证了有效性。

排序相关的结构设计

排序相关的结构设计

排序自适应的分类头

常规的DETR方法中,backbone提取多尺度特征,transformer将其映射为6层Decoder输出(两阶段方法还会多1层Encoder输出),每层的输出都是nn个query(原始DETR中n=100n=100、DeformableDETR中n=300n=300、DINO中n=900n=900),针对第ll层的每个query,表示为qilq_i^l,head将其映射为分类结果pil\boldsymbol p_i^l+回归结果,其中分类头是单层全连接:

pil=Sigmoid(ril),til=MLPcls(qil)\boldsymbol p_i^l=\mathrm{Sigmoid}(\boldsymbol r_i^l), \boldsymbol t_i^l=\text{MLP}_\text{cls}(\boldsymbol q_i^l)

本文提出的排序自适应分类头其实就是为每个query增加了对应的embedding,两者加起来再进行分类:

pil=Sigmoid(til+sil),til=MLPcls(qil)\boldsymbol p_i^l=\mathrm{Sigmoid}(\boldsymbol t_i^l+\boldsymbol s_i^l), \boldsymbol t_i^l=\text{MLP}_\text{cls}(\boldsymbol q_i^l)

这里的sil\boldsymbol s_i^l表示第ll层第ii个query的embedding。所有的s\boldsymbol s都作为网络参数自适应学习。

Query排序层

Transformer Decoder输入的query包含两部分,分别是content query和position query。其中content query作为网络可以学习的参数进行初始化,position query来自Transformer Encoder输出的候选框。RankDETR在每个Transformer解码层后增加了一个query rank layer来对这两种query进行排序。

对内容query进行排序

排序依据是每层输出的分类结果P^l1=MLPcls(Qcl1)\hat{\mathcal P}^{l-1}=\mathrm{MLP}_\text{cls}(\mathcal Q_c^{l-1}),排序层会按照该置信度对内容queryQc\mathcal Q_c进行降序排序,排序后的Q^cl\hat{\mathcal Q}_c^lCl\mathcal C^l进行拼接,这里的Cl\mathcal C^l同样是作为网络参数自适应学习。拼接后的结果在channel维度就变成了2倍,因此再经过全连接进行降维,得到下一层的query。整个流程:

QclMLPfuse(Q^cl1Cl),Q^cl1=Sort(Qcl1;P^l1)\overline{\mathcal Q}_c^l-\mathrm{MLP}_\text{fuse}(\hat{\mathcal Q}_c^{l-1}||\mathcal C^l), \hat{\mathcal Q}_c^{l-1}=\mathrm{Sort}(\mathcal Q_c^l-1;\hat{\mathcal P}^{l-1})

对位置query进行排序

位置query的排序方式在各个DETR方法中有所不同,针对H-DETR和Deformable DETR,作者也是按照P^l1\hat{\mathcal P}^{l-1}对每一层位置query进行降序排序:

Qpl=Sort(Qpl1;P^l1)\overline{\mathcal Q}_p^l=\mathrm{Sort}(\overline{\mathcal Q}_p^{l-1};\hat{\mathcal P}^{l-1})

由于DINO-DETR的位置query是上一层的位置query经过bounding box微调的结果,作者没有直接对Q^pl\hat{\mathcal Q}_p^l进行排序,而是先对检测框进行排序,再由排序后的检测框生成下一层的query:

Qpl=PE(Bl1),Bl1=Sort(Bl1;P^l1)\overline{\mathcal Q}_p^l=\mathrm{PE}(\overline{\mathcal B}^{l-1}),\overline{B}^{l-1}=\mathrm{Sort}(\mathcal B^{l-1};\hat{\mathcal P}^{l-1})

其中PE\mathrm{PE}表示正余弦编码和多层全连接。

代码实现

作者主要是基于H-DETR实现的Rank-DETR,因此代码没给出DINO-DETR的排序方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 对内容进行排序
output = torch.gather(
output, 1, rank_indices.unsqueeze(-1).repeat(1, 1, output.shape[-1])
)
# 排序后与C进行拼接,然后经过MLP
concat_term = self.pre_racq_trans[layer_idx - 1](
self.rank_aware_content_query[layer_idx - 1].weight[:output.shape[1]].unsqueeze(0).expand(output.shape[0], -1, -1)
)
output = torch.cat((output, concat_term), dim=2)
output = self.post_racq_trans[layer_idx - 1](output)

# 对未知进行排序
query_pos = torch.gather(
query_pos, 1, rank_indices.unsqueeze(-1).repeat(1, 1, query_pos.shape[-1])
)


# 省略中间代码......

# 获得排序依据:训练时有one2one query和one2many query,要分别对排序,推理时只有one2one query
if self.training:
rank_indices_one2one = torch.argsort(rank_basis[:, : self.num_queries_one2one], dim=1, descending=True) # tensor shape: [bs, num_queries_one2one]
rank_indices_one2many = torch.argsort(rank_basis[:, self.num_queries_one2one :], dim=1, descending=True) # tensor shape: [bs, num_queries_one2many]
rank_indices = torch.cat(
(
rank_indices_one2one,
rank_indices_one2many + torch.ones_like(rank_indices_one2many) * self.num_queries_one2one
),
dim=1,
) # tensor shape: [bs, num_queries_one2one+num_queries_one2many]
else:
rank_indices = torch.argsort(rank_basis[:, : self.num_queries_one2one], dim=1, descending=True)

排序相关的损失设计

损失函数设计

一般DETR的损失包括三部分,分类损失、定位损失和GIoU损失:

λ1GIoU(b^,b)+λ21(b^,b)+λ3FL(p^[c])-\lambda_1\mathrm{GIoU}(\hat{\boldsymbol b},\boldsymbol b)+\lambda_2\ell_1(\hat{\boldsymbol b},\boldsymbol b)+\lambda_3\mathrm{FL}(\hat{\boldsymbol p}[c])

作者提出的改进就在于其中的FL\mathrm{FL}分类损失上,作者将分类目标从原始的二分类0-1目标替换为了基于IoU\mathrm{IoU}的分类目标:

FLGIoU(p^[c])=tp^[c]γ[tlog(p^)]+(1t)log(1p^[c])\mathrm{FL}^{\text{GIoU}}(\hat{\boldsymbol p}[c])=-|t-\hat{\boldsymbol p}[c]|^\gamma\cdot[t\cdot\log(\hat{\boldsymbol p})]+(1-t)\cdot\log(1-\hat{\boldsymbol p}[c])

其中t=(GIoU(b^,b)+1)/2t=(\mathrm{GIoU}(\hat{\boldsymbol b}, \boldsymbol b)+1)/2。文章还对比了VFL的损失函数:

VFL(p^[c])=t[tlog(p^[c])+(1t)log(1p^[c])]\mathrm{VFL}(\hat{\boldsymbol p}[c])=-t\cdot[t\cdot\log(\hat{\boldsymbol p}[c])+(1-t)\cdot\log(1-\hat{\boldsymbol p}[c])]

可以看到,两者区别就在于对正样本监督的权值从tt变成了tp^[c]γ|t-\hat{\boldsymbol p}[c]|^\gamma

思考:一般来说DETR正样本监督式很稀缺的,因此在设计损失函数的时候(例如VFL和TOOD损失),正样本通常不进行难例挖掘。这里作者对正样本同样进行了类似GFL的难例挖掘设计,可能是考虑到H-DETR本身通过one2many的匹配设计弥补了正样本监督稀缺的问题,在此基础上进行难例挖掘则可能提升性能。

匹配损失函数设计

常规的匹配损失是对分类、回归和IoU损失进行加权,加权比例通常是2、5、2,文章提出了高阶的匹配损失:

LHungarianhigh-order=p^[c]IoUα\mathcal L_\text{Hungarian}^\text{high-order}=\hat{\boldsymbol p}[c]\cdot\text{IoU}^\alpha

其中p^[c]\hat{\boldsymbol p}[c]是分类头的输出。在代码实现中,前期仍然是使用常规的匹配损失,后期则是换成作者提出的高阶损失。

实验结果

从消融实验中来看,改进损失函数和匹配损失对性能提升的效果最大:

Rank-DETR单个模块消融实验

文章另外对比了自己提出的损失和VFL,发现使用自己提出的排序损失能够达到49.8%的AP,而VFL则只有49.5%AP,说明作者提出的排序损失能够更好地建模boxes之间地距离。

Rank-DETR虽然引入了2个结构上的改进,但都比较轻量化;而损失函数的改进不会对推理性能有影响,因此总体FPS基本没有太大下降:

Rank-DETR推理速度实验

和主流方法的性能对比实验可以看出,Rank-DETR在1×\times下的性能达到了50.2,高于DINO,但相比其他方法(DDQ-DETR、Relation-DETR、Co-DETR),性能还是弱一些,但胜在对推理速度没有太大影响。

Rank-DETR对比实验