CVPR2022. Compound Domain Generalization via Meta-Knowledge Encoding
论文地址
Arxiv
1 2 3 4 5 6 7 8 @InProceedings{Chen_2022_CVPR, author = {Chen, Chaoqi and Li, Jiongcheng and Han, Xiaoguang and Liu, Xiaoqing and Yu, Yizhou}, title = {Compound Domain Generalization via Meta-Knowledge Encoding}, booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, month = {June}, year = {2022}, pages = {7119-7129} }
主要问题
领域泛化致力于通过来自源域的知识来提升对于未见过的新领域的泛化能力。
主流的领域泛化方法假设源域样本的属于哪个领域是预先知道的,但实际应用中往往难以实现。
现有的方法主要关注于多个领域之间的语义不变性(一对一),但很少关注整体的语义结构(多对多),这种元知识对于学习可以泛化的样本表示是很重要的。
文章提出CO mpound domain generalization via M eta-knowledge EN coding (COMEN)方法来自动发现和建模潜在的领域结构和关系。
文章贡献/创新点
介绍一种实际应用中的复合领域泛化环境,其中源样本所述哪个领域是不知道的。文章提出一种统一的学习框架COMEN,联合发现和建模潜在的领域结构。
提出SDNorm发现和重归一化多模态数据的潜在分布。
提出两个互补模块ProtoGR和ProtoCCL来精确探索通用特征空间中特征原型的关系和交互。
实验表明COMEN无需领域监督,即可超过大部分最新方法。
所提方法
总体结构包括三部分:SDNorm、ProtoGP和ProtoCCL。SDNorm使用无监督模式来解耦多模态数据分布,统计估计每个样本所属的领域。ProtoGR和ProtoCCL基于样本原型表示来建模不同类别的领域内和领域间语义关系,将语义结构编码进特征空间。
潜在领域发现
为提高领域泛化能力,一些领域泛化方法简单地混合所有源样本而没有探索领域内关系。另外,探索潜在的领域分布仍然是领域泛化研究待解决的问题,表示学习方法可以很自然地将样本划分进不同的语义类别,但领域泛化方法需要领域信息来实现他们的框架。此外,使用标准的聚类方法来嵌入表示没有考虑每个样本之间的领域关系,因此效果有限。文章提出Style-induced Domain-specific Normalization(SD-Norm)模块来探索多模态数据分布模式。
给定CNN提取的图片特征图f ∈ R C × H × W \boldsymbol f\in\mathbb R^{C\times H\times W} f ∈ R C × H × W 。先计算按通道的平均和标准差:
μ ( f ) = 1 H W ∑ h = 1 H ∑ w = 1 W f h , w σ ( f ) = 1 H W ∑ h = 1 H ∑ w = 1 W ( f h , w − μ ( f ) ) 2 + ϵ \begin{aligned}
\mu(\boldsymbol f)=&\frac1{HW}\sum_{h=1}^H\sum_{w=1}^Wf_{h,w}\\
\sigma(\boldsymbol f)=&\sqrt{\frac1{HW}\sum_{h=1}^H\sum_{w=1}^W(f_{h,w}-\mu(\boldsymbol f))^2+\epsilon}
\end{aligned}
μ ( f ) = σ ( f ) = H W 1 h = 1 ∑ H w = 1 ∑ W f h , w H W 1 h = 1 ∑ H w = 1 ∑ W ( f h , w − μ ( f ) ) 2 + ϵ
将二者拼接,得到样本的样式表示s t y ( f ) ∈ R 2 C \mathrm{sty}(\boldsymbol f)\in\mathbb R^{2C} sty ( f ) ∈ R 2 C ,然后送入领域标签预测器F d F_d F d 中,输出预测的领域分配结果。优化采用熵最小化:
L d = − 1 N ∑ i = 1 N F d ( s t y ( f ) i ) log F d ( s t y ( f ) i ) \mathcal L_d=-\frac1N\sum_{i=1}^NF_d(\mathrm{sty}(\boldsymbol f)_i)\log F_d(\mathrm{sty}(\boldsymbol f)_i)
L d = − N 1 i = 1 ∑ N F d ( sty ( f ) i ) log F d ( sty ( f ) i )
这里文章写的不清楚,文章中F d F_d F d 为预训练获得,而s t y \mathrm{sty} sty 又没有可学习的参数,那这里的L d \mathcal L_d L d 用来训练什么?
文章没有直接按照概率分配样本所属的领域,而是将属于每个领域的概率大小作为软标签,从而降低错误分配对后续特征重标准化的影响。
然后将计算得到的领域分配结果送入Batch Normlization中进行重归一化:
z ^ m = z − μ m σ m 2 + ϵ \hat{\boldsymbol z}_m=\frac{\boldsymbol z-\mu_m}{\sqrt{\sigma_m^2+\epsilon}}
z ^ m = σ m 2 + ϵ z − μ m
其中μ m = ∑ i = 1 ∣ B ∣ p i , m z i \mu_m=\sum_{i=1}^{|B|}p_{i,m}\boldsymbol z^i μ m = ∑ i = 1 ∣ B ∣ p i , m z i ,σ m 2 = ∑ i = 1 ∣ B ∣ p i , m ( z i − μ m ) 2 \sigma_m^2=\sum_{i=1}^{|B|}p_{i,m}(\boldsymbol z^i-\mu_m)^2 σ m 2 = ∑ i = 1 ∣ B ∣ p i , m ( z i − μ m ) 2 。z i \boldsymbol z^i z i 为样本i i i 的特征图,p i , m p_{i,m} p i , m 为F d F_d F d 预测的样本i i i 属于领域m m m 的概率。
可以看出此处和原始Batch Normalization不一致之处在于,这里的均值和方差根据F d F_d F d 预测的概率进行了加权平均。
文章中Batch Normalization的输入为单个样本特征图,内部会对所属的每个领域都做一次归一化,得到每个领域的z m \boldsymbol z_m z m ,因此输出数目会扩增到领域数。而经典的Batch Normalization输入一个样本时,输出也是一个样本。
最后执行类似Batch Normalization的重映射:
S D N o r m m ( z ; λ m , β m ) = λ m ⋅ z ^ m + β m \mathrm{SDNorm}_m(\boldsymbol z;\lambda_m,\beta_m)=\lambda_m\cdot \hat{\boldsymbol z}_m+\beta_m
SDNorm m ( z ; λ m , β m ) = λ m ⋅ z ^ m + β m
其中λ m \lambda_m λ m 和β m \beta_m β m 是领域m m m 可学习的参数。文章对于每个领域和每个通道都进行独立的归一化,因此有M × C M\times C M × C 个λ \lambda λ 和β \beta β 。
原型关系建模
主流领域泛化方法通过对齐策略,例如时刻匹配、对抗学习,来强迫语义在共享嵌入空间的一致性,但很少考虑不同种类之间的语义关系。因此只寻找一对一的种类对齐不能确保学习到可泛化到新场景的表示。原型表示在小样本学习、领域自适应和无监督学习中显示出了强大的效果,为了解决原型表示的原型结构通过无监督模式泛化到新领域的问题,文章提出了ProtoGR和ProtoCCL。
首先计算出来全局原型:
c m k = 1 ∣ D m k ∣ = ∑ ( x i D m , y i D m ) ∈ D m k f ( x i D m ) c_m^k=\frac1{|\mathcal D_m^k|}=\sum_{(x_i^{\mathcal D_m},y_i^{\mathcal D_m})\in\mathcal D_m^k}f(x_i^{\mathcal D_m})
c m k = ∣ D m k ∣ 1 = ( x i D m , y i D m ) ∈ D m k ∑ f ( x i D m )
其中D m k \mathcal D_m^k D m k 表示属于类别k k k 和领域m m m 的样本集合。
前文说软标签可以降低对于重归一化的影响,但文章此处划分样本归属的类别和领域仍使用的硬标签,有什么意义吗?
每次迭代中采样小批量的样本,然后计算局部原型,并采用滑动平均的方式来更新全局原型:
c m ( I ) k ← ρ c m ( I − 1 ) k + ( 1 − ρ ) c ^ m ( I ) k c_{m(I)}^k\leftarrow \rho c_{m(I-1)}^k+(1-\rho)\hat{c}_{m(I)}^k
c m ( I ) k ← ρ c m ( I − 1 ) k + ( 1 − ρ ) c ^ m ( I ) k
其中指数衰减率ρ = 0.7 \rho=0.7 ρ = 0.7 ,I I I 为迭代次数。
原型图推理(ProtoGR)
全局类别原型的数目为M × K M\times K M × K ,维度为d d d ,将其作为节点来建立图模型,节点特征定义为X = { x 1 , ⋯ , x M K } ∈ R M K × d \mathbf X=\{\boldsymbol x_1,\cdots,\boldsymbol x_{MK}\}\in\mathbb R^{MK\times d} X = { x 1 , ⋯ , x M K } ∈ R M K × d 。邻接矩阵定义为A ∈ R M K × M K \mathbf A\in\mathbb R^{MK\times MK} A ∈ R M K × M K ,其每个元素代表节点的互相关性:
A i j = 1 ( x i ⊤ x j ∥ x i ∥ 2 ⋅ ∥ x j ∥ 2 > δ ) ⋅ x i ⊤ x j ∥ x i ∥ 2 ⋅ ∥ x j ∥ 2 A_{ij}=\mathbb 1\left(\frac{\boldsymbol x_i^\top\boldsymbol x_j}{\|\boldsymbol x_i\|_2\cdot\|\boldsymbol x_j\|_2}>\delta\right)\cdot\frac{\boldsymbol x_i^\top\boldsymbol x_j}{\|\boldsymbol x_i\|_2\cdot\|\boldsymbol x_j\|_2}
A ij = 1 ( ∥ x i ∥ 2 ⋅ ∥ x j ∥ 2 x i ⊤ x j > δ ) ⋅ ∥ x i ∥ 2 ⋅ ∥ x j ∥ 2 x i ⊤ x j
其中x i \boldsymbol x_i x i 和x j \boldsymbol x_j x j 代表第i i i 和第j j j 个原型特征,δ = 0.5 \delta=0.5 δ = 0.5 控制稀疏度。为建模长期依赖,引入图注意力:
x i l + 1 = σ ( ∑ j ∈ N i α i j ( i ) ⋅ W x j ( l ) ) \boldsymbol x_i^{l+1}=\sigma\left(\sum_{j\in\mathcal N_i}\alpha_{ij}^{(i)}\cdot\mathbf{W}\boldsymbol x_j^{(l)}\right)
x i l + 1 = σ j ∈ N i ∑ α ij ( i ) ⋅ W x j ( l )
其中x j ( l ) \boldsymbol x_j^{(l)} x j ( l ) 为第l l l 层的隐藏层特征。图中两节点i i i 和j j j 的边的权重为:
α i j ( l ) = A i j exp ( L R e L U ( a ( l ) ⊤ [ W x i ( l ) ∣ ∣ W x j ( l ) ] ) ) ∑ k ∈ N i A i k exp ( L R e L U ( a ( l ) ⊤ [ W x i ( l ) ∣ ∣ W x k ( l ) ] ) ) \alpha_{ij}^{(l)}=\frac{A_{ij}\exp(\mathrm{LReLU}(\boldsymbol a_{(l)}^\top[\mathbf{W}\boldsymbol x_i^{(l)}||\mathbf W\boldsymbol x_j^{(l)}]))}{\sum_{k\in\mathcal N_i}A_{ik}\exp(\mathrm{LReLU}(\boldsymbol a_{(l)}^\top[\mathbf W\boldsymbol x_i^{(l)}||\mathbf W\boldsymbol x_k^{(l)}]))}
α ij ( l ) = ∑ k ∈ N i A ik exp ( LReLU ( a ( l ) ⊤ [ W x i ( l ) ∣∣ W x k ( l ) ])) A ij exp ( LReLU ( a ( l ) ⊤ [ W x i ( l ) ∣∣ W x j ( l ) ]))
其中a ( l ) \boldsymbol a_{(l)} a ( l ) 是可学习的权重向量,∣ ∣ || ∣∣ 为拼接操作。
文章堆叠了两层图注意力,并使用残差连接来增强判别行:
X ( L ) = X ( L ) + X ( 0 ) \mathbf X^{(L)}=\mathbf X^{(L)}+\mathbf X^{(0)}
X ( L ) = X ( L ) + X ( 0 )
最后增加节点分类网络,利用G \mathcal G G 中的信息预测类别原型类别。
y ^ = s o f t m a x ( F C ( P r o t o G R ( x , G ) ) ) \hat y=\mathrm{softmax}(\mathrm{FC}(\mathrm{ProtoGR}(x,\mathcal G)))
y ^ = softmax ( FC ( ProtoGR ( x , G )))
其中x x x 和y ^ \hat y y ^ 分别表示节点特征和其对应的预测类别标签。
这里文章写的很不清楚,节点特征前文是用x ∈ R 1 × d \boldsymbol x\in\mathbb R^{1\times d} x ∈ R 1 × d 来表示,现在又使用x x x 来表示,这两处是否是同一含义,节点特征是向量还是标量?
上述节点分类损失定义为L ProtoGR \mathcal L_\text{ProtoGR} L ProtoGR
原型类别感知对比学习(Proto CCL)
对比学习方法在自监督表示学习中具有很好的效果,其中InfoNCE是一种代表性的对比学习损失:
L I InfoNCE = − log exp ( v ⋅ v + / τ ) exp ( v ⋅ v + / τ ) + ∑ v − ∈ N I exp ( v ⋅ v − / τ ) \mathcal L_I^\text{InfoNCE}=-\log\frac{\exp(\boldsymbol v\cdot\boldsymbol v^+/\tau)}{\exp(\boldsymbol v\cdot\boldsymbol v^+/\tau)+\sum_{\boldsymbol v^-\in\mathcal N_I}\exp(\boldsymbol v\cdot\boldsymbol v^-/\tau)}
L I InfoNCE = − log exp ( v ⋅ v + / τ ) + ∑ v − ∈ N I exp ( v ⋅ v − / τ ) exp ( v ⋅ v + / τ )
其中v + \boldsymbol v^+ v + 为正嵌入样本,即出当前原型外的其他和I I I 同一类别样本,N I \mathcal N_I N I 表示负样本,即属于其他类别的样本,τ \tau τ 是温度超参数。
ProtoCCL的输入为当前类别k k k 对应的原型c \boldsymbol c c :
L k ProtoCCL = 1 ∣ C k ∣ ∑ c + ∈ C k − log exp ( c ⋅ c + / τ ) exp ( c ⋅ c + / τ ) + ∑ c − ∈ N k exp ( c ⋅ c − / τ ) \mathcal L_k^\text{ProtoCCL}=\frac1{|\mathcal C_k|}\sum_{\boldsymbol c^+\in\mathcal C_k}-\log\frac{\exp(\boldsymbol c\cdot\boldsymbol c^+/\tau)}{\exp(\boldsymbol c\cdot\boldsymbol c^+/\tau)+\sum_{\boldsymbol c^-\in\mathcal N_k}\exp(\boldsymbol c\cdot\boldsymbol c^-/\tau)}
L k ProtoCCL = ∣ C k ∣ 1 c + ∈ C k ∑ − log exp ( c ⋅ c + / τ ) + ∑ c − ∈ N k exp ( c ⋅ c − / τ ) exp ( c ⋅ c + / τ )
其中C k \mathcal C_k C k 和N k \boldsymbol N_k N k 分别为原型c \boldsymbol c c 的正样本和负样本集合。
训练损失
总的训练损失分成两个阶段,第一阶段使用L d \mathcal L_d L d 和L cls \mathcal L_\text{cls} L cls 训练伪领域标签和类别标签,其中类别标签预测的损失函数为:
L cls = − 1 N ∑ i = 1 N ∑ j = 1 K 1 [ y i = j ] log ( F c ∘ G ( x i ) ) \mathcal L_\text{cls}=-\frac1N\sum_{i=1}^N\sum_{j=1}^K\mathbb 1[y_i=j]\log(F_c\circ G(x_i))
L cls = − N 1 i = 1 ∑ N j = 1 ∑ K 1 [ y i = j ] log ( F c ∘ G ( x i ))
其中G G G 表示特征提取器,F c F_c F c 表示类别标签预测器。
这里文章说使用L cls \mathcal L_\text{cls} L cls 来获得伪标签就很迷惑,前文说伪领域标签由预训练的F d F_d F d 输出,和这里的L cls \mathcal L_\text{cls} L cls 有关系吗?L c l s \mathcal L_{cls} L c l s 表达式的计算似乎也没有依赖伪标签。另外也没交代特征提取器G G G 和F c F_c F c 是什么,姑且认为是经典的CNN特征提取器或分类器吧。
第二阶段总损失函数为:
L COMEN = L cls + λ L ProtoRP + γ L ProtoCCL \mathcal L_\text{COMEN}=\mathcal L_\text{cls}+\lambda\mathcal L_\text{ProtoRP}+\gamma\mathcal L_\text{ProtoCCL}
L COMEN = L cls + λ L ProtoRP + γ L ProtoCCL
其中λ \lambda λ 和γ \gamma γ 为权衡参数。
文章没交代L ProtoCCL \mathcal L_\text{ProtoCCL} L ProtoCCL 和L k ProtoCCL \mathcal L_k^\text{ProtoCCL} L k ProtoCCL 的关系。
仿真实验
文章介绍了一些不错的领域泛化基准数据集:PACS、Digits-DG、VLCS和Office-Home。
数据集
样本量
类别
领域
PACS
9991
7
Photo,Art Painting, Cartoon, Sketch
Digits-DG
—
—
MNIST,MNIST-M,SVHN,SYN
VLCS
—
5
PASCAL VOC 2007, LabelMe, Caltech, Sun
Office-Home
15500
65
Artistic, Clipart, Product, Real World
文章采用常用的leave-one-domain-out评价方法,即留出1个作为新领域,剩下认为是源领域。
实验结果:
文章指出,COMEN只适用于分类任务,无法直接扩展到目标检测和语义分割任务。