CVPR2022. Compound Domain Generalization via Meta-Knowledge Encoding

CVPR2022. Compound Domain Generalization via Meta-Knowledge Encoding

论文地址
Arxiv

1
2
3
4
5
6
7
8
@InProceedings{Chen_2022_CVPR,
author = {Chen, Chaoqi and Li, Jiongcheng and Han, Xiaoguang and Liu, Xiaoqing and Yu, Yizhou},
title = {Compound Domain Generalization via Meta-Knowledge Encoding},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022},
pages = {7119-7129}
}

主要问题

领域泛化致力于通过来自源域的知识来提升对于未见过的新领域的泛化能力。

  • 主流的领域泛化方法假设源域样本的属于哪个领域是预先知道的,但实际应用中往往难以实现。
  • 现有的方法主要关注于多个领域之间的语义不变性(一对一),但很少关注整体的语义结构(多对多),这种元知识对于学习可以泛化的样本表示是很重要的。

文章提出COmpound domain generalization via Meta-knowledge ENcoding (COMEN)方法来自动发现和建模潜在的领域结构和关系。

文章贡献/创新点

  • 介绍一种实际应用中的复合领域泛化环境,其中源样本所述哪个领域是不知道的。文章提出一种统一的学习框架COMEN,联合发现和建模潜在的领域结构。
  • 提出SDNorm发现和重归一化多模态数据的潜在分布。
  • 提出两个互补模块ProtoGR和ProtoCCL来精确探索通用特征空间中特征原型的关系和交互。
  • 实验表明COMEN无需领域监督,即可超过大部分最新方法。

所提方法

总体结构包括三部分:SDNorm、ProtoGP和ProtoCCL。SDNorm使用无监督模式来解耦多模态数据分布,统计估计每个样本所属的领域。ProtoGR和ProtoCCL基于样本原型表示来建模不同类别的领域内和领域间语义关系,将语义结构编码进特征空间。

COMEN总体结构

潜在领域发现

为提高领域泛化能力,一些领域泛化方法简单地混合所有源样本而没有探索领域内关系。另外,探索潜在的领域分布仍然是领域泛化研究待解决的问题,表示学习方法可以很自然地将样本划分进不同的语义类别,但领域泛化方法需要领域信息来实现他们的框架。此外,使用标准的聚类方法来嵌入表示没有考虑每个样本之间的领域关系,因此效果有限。文章提出Style-induced Domain-specific Normalization(SD-Norm)模块来探索多模态数据分布模式。

给定CNN提取的图片特征图fRC×H×W\boldsymbol f\in\mathbb R^{C\times H\times W}。先计算按通道的平均和标准差:

μ(f)=1HWh=1Hw=1Wfh,wσ(f)=1HWh=1Hw=1W(fh,wμ(f))2+ϵ\begin{aligned} \mu(\boldsymbol f)=&\frac1{HW}\sum_{h=1}^H\sum_{w=1}^Wf_{h,w}\\ \sigma(\boldsymbol f)=&\sqrt{\frac1{HW}\sum_{h=1}^H\sum_{w=1}^W(f_{h,w}-\mu(\boldsymbol f))^2+\epsilon} \end{aligned}

将二者拼接,得到样本的样式表示sty(f)R2C\mathrm{sty}(\boldsymbol f)\in\mathbb R^{2C},然后送入领域标签预测器FdF_d中,输出预测的领域分配结果。优化采用熵最小化:

Ld=1Ni=1NFd(sty(f)i)logFd(sty(f)i)\mathcal L_d=-\frac1N\sum_{i=1}^NF_d(\mathrm{sty}(\boldsymbol f)_i)\log F_d(\mathrm{sty}(\boldsymbol f)_i)

这里文章写的不清楚,文章中FdF_d为预训练获得,而sty\mathrm{sty}又没有可学习的参数,那这里的Ld\mathcal L_d用来训练什么?

文章没有直接按照概率分配样本所属的领域,而是将属于每个领域的概率大小作为软标签,从而降低错误分配对后续特征重标准化的影响。

然后将计算得到的领域分配结果送入Batch Normlization中进行重归一化:

z^m=zμmσm2+ϵ\hat{\boldsymbol z}_m=\frac{\boldsymbol z-\mu_m}{\sqrt{\sigma_m^2+\epsilon}}

其中μm=i=1Bpi,mzi\mu_m=\sum_{i=1}^{|B|}p_{i,m}\boldsymbol z^iσm2=i=1Bpi,m(ziμm)2\sigma_m^2=\sum_{i=1}^{|B|}p_{i,m}(\boldsymbol z^i-\mu_m)^2zi\boldsymbol z^i为样本ii的特征图,pi,mp_{i,m}FdF_d预测的样本ii属于领域mm的概率。

可以看出此处和原始Batch Normalization不一致之处在于,这里的均值和方差根据FdF_d预测的概率进行了加权平均。
文章中Batch Normalization的输入为单个样本特征图,内部会对所属的每个领域都做一次归一化,得到每个领域的zm\boldsymbol z_m,因此输出数目会扩增到领域数。而经典的Batch Normalization输入一个样本时,输出也是一个样本。

最后执行类似Batch Normalization的重映射:

SDNormm(z;λm,βm)=λmz^m+βm\mathrm{SDNorm}_m(\boldsymbol z;\lambda_m,\beta_m)=\lambda_m\cdot \hat{\boldsymbol z}_m+\beta_m

其中λm\lambda_mβm\beta_m是领域mm可学习的参数。文章对于每个领域和每个通道都进行独立的归一化,因此有M×CM\times Cλ\lambdaβ\beta

原型关系建模

主流领域泛化方法通过对齐策略,例如时刻匹配、对抗学习,来强迫语义在共享嵌入空间的一致性,但很少考虑不同种类之间的语义关系。因此只寻找一对一的种类对齐不能确保学习到可泛化到新场景的表示。原型表示在小样本学习、领域自适应和无监督学习中显示出了强大的效果,为了解决原型表示的原型结构通过无监督模式泛化到新领域的问题,文章提出了ProtoGR和ProtoCCL。

首先计算出来全局原型:

cmk=1Dmk=(xiDm,yiDm)Dmkf(xiDm)c_m^k=\frac1{|\mathcal D_m^k|}=\sum_{(x_i^{\mathcal D_m},y_i^{\mathcal D_m})\in\mathcal D_m^k}f(x_i^{\mathcal D_m})

其中Dmk\mathcal D_m^k表示属于类别kk和领域mm的样本集合。

前文说软标签可以降低对于重归一化的影响,但文章此处划分样本归属的类别和领域仍使用的硬标签,有什么意义吗?

每次迭代中采样小批量的样本,然后计算局部原型,并采用滑动平均的方式来更新全局原型:

cm(I)kρcm(I1)k+(1ρ)c^m(I)kc_{m(I)}^k\leftarrow \rho c_{m(I-1)}^k+(1-\rho)\hat{c}_{m(I)}^k

其中指数衰减率ρ=0.7\rho=0.7II为迭代次数。

原型图推理(ProtoGR)

全局类别原型的数目为M×KM\times K,维度为dd,将其作为节点来建立图模型,节点特征定义为X={x1,,xMK}RMK×d\mathbf X=\{\boldsymbol x_1,\cdots,\boldsymbol x_{MK}\}\in\mathbb R^{MK\times d}。邻接矩阵定义为ARMK×MK\mathbf A\in\mathbb R^{MK\times MK},其每个元素代表节点的互相关性:

Aij=1(xixjxi2xj2>δ)xixjxi2xj2A_{ij}=\mathbb 1\left(\frac{\boldsymbol x_i^\top\boldsymbol x_j}{\|\boldsymbol x_i\|_2\cdot\|\boldsymbol x_j\|_2}>\delta\right)\cdot\frac{\boldsymbol x_i^\top\boldsymbol x_j}{\|\boldsymbol x_i\|_2\cdot\|\boldsymbol x_j\|_2}

其中xi\boldsymbol x_ixj\boldsymbol x_j代表第ii和第jj个原型特征,δ=0.5\delta=0.5控制稀疏度。为建模长期依赖,引入图注意力:

xil+1=σ(jNiαij(i)Wxj(l))\boldsymbol x_i^{l+1}=\sigma\left(\sum_{j\in\mathcal N_i}\alpha_{ij}^{(i)}\cdot\mathbf{W}\boldsymbol x_j^{(l)}\right)

其中xj(l)\boldsymbol x_j^{(l)}为第ll层的隐藏层特征。图中两节点iijj的边的权重为:

αij(l)=Aijexp(LReLU(a(l)[Wxi(l)Wxj(l)]))kNiAikexp(LReLU(a(l)[Wxi(l)Wxk(l)]))\alpha_{ij}^{(l)}=\frac{A_{ij}\exp(\mathrm{LReLU}(\boldsymbol a_{(l)}^\top[\mathbf{W}\boldsymbol x_i^{(l)}||\mathbf W\boldsymbol x_j^{(l)}]))}{\sum_{k\in\mathcal N_i}A_{ik}\exp(\mathrm{LReLU}(\boldsymbol a_{(l)}^\top[\mathbf W\boldsymbol x_i^{(l)}||\mathbf W\boldsymbol x_k^{(l)}]))}

其中a(l)\boldsymbol a_{(l)}是可学习的权重向量,||为拼接操作。

文章堆叠了两层图注意力,并使用残差连接来增强判别行:

X(L)=X(L)+X(0)\mathbf X^{(L)}=\mathbf X^{(L)}+\mathbf X^{(0)}

最后增加节点分类网络,利用G\mathcal G中的信息预测类别原型类别。

y^=softmax(FC(ProtoGR(x,G)))\hat y=\mathrm{softmax}(\mathrm{FC}(\mathrm{ProtoGR}(x,\mathcal G)))

其中xxy^\hat y分别表示节点特征和其对应的预测类别标签。

这里文章写的很不清楚,节点特征前文是用xR1×d\boldsymbol x\in\mathbb R^{1\times d}来表示,现在又使用xx来表示,这两处是否是同一含义,节点特征是向量还是标量?

上述节点分类损失定义为LProtoGR\mathcal L_\text{ProtoGR}

原型类别感知对比学习(Proto CCL)

对比学习方法在自监督表示学习中具有很好的效果,其中InfoNCE是一种代表性的对比学习损失:

LIInfoNCE=logexp(vv+/τ)exp(vv+/τ)+vNIexp(vv/τ)\mathcal L_I^\text{InfoNCE}=-\log\frac{\exp(\boldsymbol v\cdot\boldsymbol v^+/\tau)}{\exp(\boldsymbol v\cdot\boldsymbol v^+/\tau)+\sum_{\boldsymbol v^-\in\mathcal N_I}\exp(\boldsymbol v\cdot\boldsymbol v^-/\tau)}

其中v+\boldsymbol v^+为正嵌入样本,即出当前原型外的其他和II同一类别样本,NI\mathcal N_I表示负样本,即属于其他类别的样本,τ\tau是温度超参数。

ProtoCCL的输入为当前类别kk对应的原型c\boldsymbol c

LkProtoCCL=1Ckc+Cklogexp(cc+/τ)exp(cc+/τ)+cNkexp(cc/τ)\mathcal L_k^\text{ProtoCCL}=\frac1{|\mathcal C_k|}\sum_{\boldsymbol c^+\in\mathcal C_k}-\log\frac{\exp(\boldsymbol c\cdot\boldsymbol c^+/\tau)}{\exp(\boldsymbol c\cdot\boldsymbol c^+/\tau)+\sum_{\boldsymbol c^-\in\mathcal N_k}\exp(\boldsymbol c\cdot\boldsymbol c^-/\tau)}

其中Ck\mathcal C_kNk\boldsymbol N_k分别为原型c\boldsymbol c的正样本和负样本集合。

训练损失

总的训练损失分成两个阶段,第一阶段使用Ld\mathcal L_dLcls\mathcal L_\text{cls}训练伪领域标签和类别标签,其中类别标签预测的损失函数为:

Lcls=1Ni=1Nj=1K1[yi=j]log(FcG(xi))\mathcal L_\text{cls}=-\frac1N\sum_{i=1}^N\sum_{j=1}^K\mathbb 1[y_i=j]\log(F_c\circ G(x_i))

其中GG表示特征提取器,FcF_c表示类别标签预测器。

这里文章说使用Lcls\mathcal L_\text{cls}来获得伪标签就很迷惑,前文说伪领域标签由预训练的FdF_d输出,和这里的Lcls\mathcal L_\text{cls}有关系吗?Lcls\mathcal L_{cls}表达式的计算似乎也没有依赖伪标签。另外也没交代特征提取器GGFcF_c是什么,姑且认为是经典的CNN特征提取器或分类器吧。

第二阶段总损失函数为:

LCOMEN=Lcls+λLProtoRP+γLProtoCCL\mathcal L_\text{COMEN}=\mathcal L_\text{cls}+\lambda\mathcal L_\text{ProtoRP}+\gamma\mathcal L_\text{ProtoCCL}

其中λ\lambdaγ\gamma为权衡参数。

文章没交代LProtoCCL\mathcal L_\text{ProtoCCL}LkProtoCCL\mathcal L_k^\text{ProtoCCL}的关系。

仿真实验

文章介绍了一些不错的领域泛化基准数据集:PACS、Digits-DG、VLCS和Office-Home。

数据集 样本量 类别 领域
PACS 9991 7 Photo,Art Painting, Cartoon, Sketch
Digits-DG MNIST,MNIST-M,SVHN,SYN
VLCS 5 PASCAL VOC 2007, LabelMe, Caltech, Sun
Office-Home 15500 65 Artistic, Clipart, Product, Real World

文章采用常用的leave-one-domain-out评价方法,即留出1个作为新领域,剩下认为是源领域。

实验结果:

PACS中元知识编码领域泛化结果对比

Digits-DG中元知识编码领域泛化结果对比

VLCS中元知识编码领域泛化结果对比

Office-Home中元知识编码领域泛化结果对比

元知识编码消融实验和结果可视化

文章指出,COMEN只适用于分类任务,无法直接扩展到目标检测和语义分割任务。