CVPR2022. Pin the Memory Learning to Generalize Semantic Segmentation

Pin the Memory: Learning to Generalize Semantic Segmentation

论文地址


深度学习模型在语义分割领域取得了一些突破,但在源域训练的模型通常在新的挑战域中不能正常发挥效果,因而会影响泛化性能。文章基于元学习框架提出了记忆引导的领域泛化方法,该方法抽取出领域不变的语义类别概念知识,融入到类别记忆中。根据元学习的概念,文章反复训练记忆引导的网络,并模拟虚拟测试来:1)学习如何记忆领域无关的和独特的类信息;2)提供外部设置的记忆作为类别指导,以减少在任意新领域测试时数据表达的模糊性。文章提出了记忆发散和特征凝聚力损失,以指导面向类别感知领域泛化过程的记忆读取和更新过程。在多种基准数据集上的大量实验,表明了模型相比目前最新方法具有更好的泛化性能。

现有问题

语义分割近期的许多进展主要来自于在大批量稠密标注数据集上的深度神经网络,但在给定数据集(源域)上训练的模型不能很好地迁移到模型训练过程中没有见过的新领域(目标域)。克服两个领域分布的差异对于处理意外和未见过的新数据非常重要,尤其是在医疗诊断、自动驾驶等一些代替人工的任务上。

为了解决领域迁移导致的性能下降,目前有两种方法:无监督领域自适应方法领域泛化方法。

无监督领域自适应方法(UDA)致力于通过来自目标域的无标签数据来弥补领域之间的差异。它们采用的策略包括学习出领域不变的特征,或者将源域和目标域对齐到统一空间。但是目标域的数据收集经常难以实现,此外该方法需要在目标域上微调或重训练,因此模型尺度受到较大限制,因而无法泛化到“任何”没见过的领域。

领域泛化方法(DG)致力于学习出能够应对各种未见过的数据分布的泛化模型,由于训练过程中没有目标域的数据,因此实现起来相比UDA更困难。有些方法启发式地定义领域偏置信息定义为风格(纹理、颜色)信息,或明确地增强它们,或通过实例标准化和通道协方差白化来消除风格,但应用到实际领域中效果有限。

文章认为人与机器不同之处在于,人具有概念知识(语义记忆),时从具体地经历中以一种可重用形式抽象出来地,并且能够推广到多种认知活动,例如事件重构、目标识别。因此文章认为人类地知识概念可以通过记住每个类别的共享信息来有效支持领域泛化。例如不同领域中汽车的形状可能变化,但轮胎、门、车头灯等基本组件是不变的,因此这种并行特征的先验知识指引能够提高模型泛化能力。

文章主要目的是将每个类别的共享信息放到额外的内存部件中,通过重用这种类别概念来构建适用于任何未见过领域的鲁棒语义分割,实现类别感知泛化的语义分割,而非以往方法的全局性推理表示。

文章贡献/创新点

  • 文章通过使用内存模块来利用语义类别知识信息来实现领域泛化。
  • 引入记忆指导的元学习算法,通过将模型暴露在不相匹配的数据中来提升记忆引导特征的表示能力。
  • 提出了两个互补损失:记忆分散损失和特征凝聚损失,以促进嵌入特征寻找恰当的类别记忆的能力。
  • 实验证明了类别感知泛化在单源设置和多源设置中的有效性。

相关工作

领域自适应和泛化

许多研究致力于通过减小领域分布的差异来提高深度网络的泛化能力,无监督领域自适应通过利用无标签目标域的信息用于训练,来矫正这种不匹配,近期训练数据来自多个合成数据集的多源UDA方法已经引入到更多实际场景中。单尽管如此,深度网络仍总会遭受未见过的新领域,产生领域泛化问题。领域泛化研究大致分为两类:学习领域不变特征,增强训练样本。然而领域泛化仍然集中在整张图片的分类任务,文章则能够解决城市场景中的语义分割问题。

此处领域泛化的研究表述,在Introduction中是学习不变性特征、对齐源域和目标域到统一空间;而在Related work则变为了学习不变性特征、增强训练样本。

面向语义分割的领域泛化

为解决领域泛化问题,领域随机性方法使用数据增强来生成新样本,但会增加训练消耗并且难以真正覆盖实际场景中的数据分布。受标准化操作的直觉启发,一些方法尝试标准化全局特征,消除每个领域的特有的样式信息,但仅关注到了全局特征表示。近期论文指出,从合成数据中学习到的特征的多样性对于防止语义分割任务中的源域过拟合问题具有关键作用。因此文章采用元学习框架,虚拟地测试不同数据分布下的存储记忆,以使得类别通用信息用于模型泛化。

记忆引导的领域泛化元学习方法

元学习

模型可知的元学习方法,采用一种偶发式的训练策略来实现面向小样本学习的多阶梯度下降方法,通过分离的元训练和元测试来模拟训练和评估步骤。Zhen等人提出长期记忆来存储语义信息,其中来自更新记忆中的梯度并不反馈给网络。Zhao等人简单地将记忆认为是非参数化模块,来解决子网之间的异步梯度更新会使元优化不稳定的问题。本文方法和这些方法正交,目的是通过元学习学习网络来推广类别记忆的更新和读取过程。

记忆网络

记忆网络通过稳定地记忆信息来增强网络的容量。文章使用的记忆模块在整个元学习训练步骤中均保持长期记忆,因此能够鲁棒地读取和写入记忆来帮助领域泛化。相比计算量较高的记忆读取和数据集级别的多类别情景信息记忆,文章所提方法仅需要一次估计即可读取信息,并且能够实现语义类别级的通用特征存储。

所提方法

领域泛化旨在一组可观测的源域来学习通用的语义分割网络,其中网络包含编码器和解码器。直观想法是将所有已有的源域数据放到一起来训练语义分割网络,但会导致网络过于适用于源域,面向新的目标域时产生巨大的精度衰减。

文章提出记忆引导的元学习框架来防止语义分割模型在面向测试时未见过的领域时导致的精度下降,总体结构:

记忆引导的领域泛化元学习训练流程

文章使用数据增强或者领域拆分来人工实现领域切换,使网络能够在特定的领域更新和读取记忆,从而网络学习如何记住领域切换时的概念知识。

记忆模块

记忆模块包含在语义分割网络中的backbone中,将每一类通用的特征信息存入矩阵MRN×C\mathcal M\in\mathbb R^{N\times C}中,其中CC时编码特征的通道数,NN是类别数。

初始化

首先通过在ImageNet上预训练的编码器EE,其参数为ΘE\Theta_E,提取2\ell_2标准化的特征图。然后按类别对图片相应区域进行特征图平均,得到记忆矩阵M\mathcal M

更新

记忆更新网络包含1×11\times 1带有残差连接的卷积层,其参数为ΘU\Theta_U2\ell_2标准化的特征图FRC×H×W\mathcal F\in\mathbb R^{C\times H'\times W'}转化为Z=U(F)\mathcal Z=U(\mathcal F),为了更新类别nn对应的记忆向量M[n]\mathcal M[n],文章对图像中第nn类语义掩码后的区域执行平均池化操作:

Z^[n]=(Y[n]Z)/Kn \hat{\mathcal Z}[n]=(\mathcal Y[n]\mathcal Z^\top)/K_n

其中KnK_n为第nn类真值的类别对应的掩码后区域像素的数量。Z^RN×C\hat{\mathcal Z}\in\mathbb R^{N\times C}为掩码后特征图。YRN×HW\mathcal Y\in\mathbb R^{N\times H'W'}ZRC×HW\mathcal Z\in\mathbb R^{C\times H'W'}分别为one-hot语义真值和掩码后的特征向量。

上述操作实际就是对特征图中每个类别对应的区域做平均池化操作,将每个类别对应的特征图从C×HWC\times H'W'池化为C×1C\times 1,即长度为CC的特征向量,所有类别组成N×CN\times C维度。

然后采用滑动平均的方式更新记忆矩阵:

M^[n]=mM[n]+(1m)Z^[n] \hat{\mathcal M}[n]=m\cdot\mathcal M[n] + (1-m)\cdot\hat{\mathcal Z}[n]

文中将经验性地设置动量m=0.8m=0.8。整体更新过程表达为:

M^=update(M,X;{Θ}E,U) \hat{\mathcal M}=\mathrm{update}(\mathcal M,\mathcal X;\{\Theta\}_{E,U})

其中参数组ΘE\Theta_EΘU\Theta_U表示为{Θ}E,U\{\Theta\}_{E,U}。整个更新过程为:

记忆矩阵更新过程

读取

为了首先沿着每个空间位置维度聚集记忆项,首先计算得到记忆权重矩阵WRN×HW\mathcal W\in\mathbb R^{N\times H'W'}

W[n]=exp(M[n]F)n=1Nexp(M[n]F) \mathcal W[n]=\frac{\exp(\mathcal M[n]\mathcal F)}{\sum_{n'=1}^N\exp(\mathcal M[n']\mathcal F)}

其中MRN×C\mathcal M\in\mathbb R^{N\times C}是记忆矩阵,FRC×HW\mathcal F\in\mathbb R^{C\times H'W'}是输入图片的特征图。用其来指导特征图,得到权重记忆特征MWRC×HW\mathcal M^\top\mathcal W\in\mathbb R^{C\times H'W'}。将其与输入图片特征拼接,经过卷积和激活操作得到记忆指导的特征图:

R=ReLU(Conv1×1(Π(F,MW))) \mathcal R=\mathrm{ReLU}(\mathrm{Conv}_{1\times 1}(\mathrm{\Pi(\mathcal F,\mathcal M^\top\mathcal W)}))

其中Π\Pi代表拼接操作,卷积操作目的在于将融合后的特征图维度从R2C×H×W\mathbb R^{2C\times H'\times W'}将为RC×H×W\mathbb R^{C\times H'\times W'},获得记忆指导的特征图。整个读取过程为:

记忆矩阵读取过程

学习泛化更新和读取

以往基于元学习的领域泛化方法并没有使用额外的先验知识,本文使用元学习实现两个目的:将领域不变的类别知识存储在外部记忆中来为鲁棒的语义分割提供类别指引;强化网络以鲁棒地将每个新场景的像素分类到针对类内和跨域变化的类别标签。文章随机将源域S\mathbb S划分为元训练域Smtr\mathbb S_\text{mtr}Smte\mathbb S_\text{mte},然后重复地从源域记忆类别信息,测试网络在保持记忆的情况下能否在目标域上正常工作。

元训练

给定XmtrSmtr\mathcal X_\text{mtr}\in\mathbb S_\text{mtr},编码器计算特征图Fmtr\mathcal F_\text{mtr}并通过读取操作来使用记忆M\mathcal M增强特征图。然后使用解码器输出分割结果,并将其和真值Ymtr\mathcal Y_\text{mtr}计算交叉熵损失Lseg\mathcal L_\text{seg}。但交叉熵并不能保证编码结果中相同类别的特征在特征嵌入空间中接近,因此文章提出了特征凝聚损失,基于记忆项来促使语义特征产生局部嵌入的效果:

Lcoh=1HWj=1HWYmtr[j]log(Wmtr[j]) \mathcal L_\text{coh}=\frac1{H'W'}\sum_{j=1}^{H'W'}-\mathcal Y_\text{mtr}^\top[j]\log(\mathcal W_\text{mtr}[j])

此处形式上类似信息熵,但是将信息熵中的log\log对象替换为对记忆权重矩阵Wmtr\mathcal W_\text{mtr}。最小化信息熵作用是保持样本之间的结构,直观理解熵增大意味着状态更混乱。
信息熵最小化常用于半监督学习,基本假设是分类器的决策边界不应穿过数据中的高密度区域,因此使用正则化来降低信息熵,保持数据中的结构。
有关最小化信息熵和熵正则化的作用可以参考论文:

  1. Grandvalet Y, Bengio Y. Semi-supervised learning by entropy minimization[J]. Advances in neural information processing systems, 2004, 17.
  2. Grandvalet Y, Bengio Y. Entropy Regularization[J]. 2006.

另外记忆项之间的特征应该足够远,以产生判别性的效果。因此文章提出了记忆分散损失来增强记忆项之间的距离,增大决策边界:

Ldiv=n=1N(I[n]log(G(M^[n]))+2nnNmax(M^[n]M^[n],0)N(N1)) \mathcal L_\text{div}=\sum_{n=1}^N(-\mathcal I[n]\log(G(\hat{\mathcal M}[n]^\top))+2\cdot\sum_{n'\ne n}^N\frac{\max(\hat{\mathcal M}[n]\hat{\mathcal M}[n']^\top,0)}{N(N-1)})

其中第一项用于记忆分类,GG为参数为ΘG\Theta_G的全连接分类器,IRN×N\mathcal I\in\mathbb R^{N\times N}为单位矩阵。第二项和余弦嵌入损失相似。

上述两损失作用是提高类别记忆特征的内聚程度和类间离散程度。但分类器GG是怎么得到的?单独训练吗?

定义记忆读取和更新损失分别为:

Lread(M,Xmtr;{Θ}E,D)=Lseg+λ1LcohLupdate(M,Xmtr;{Θ}E,U,G)=λ2Ldiv \mathcal L_\text{read}(\mathcal M,\mathcal X_\text{mtr};\{\Theta\}_{E,D})=\mathcal L_\text{seg}+\lambda_1\mathcal L_\text{coh}\\ \mathcal L_\text{update}(\mathcal M,\mathcal X_\text{mtr};\{\Theta\}_{E,U,G})=\lambda_2\mathcal L_\text{div}

网络参数更新:

{Θ}E,U,D,ΘG{Θ}E,U,D,GαΘLread(M,Xmtr;{Θ}E,D)αΘLupdate(M,Xmtr;{Θ}E,U,G)\begin{aligned} \{\Theta\}_{E,U,D}',\Theta_G^*&\leftarrow\{\Theta\}_{E,U,D,G}\\&-\alpha\nabla_\Theta\mathcal L_\text{read}(\mathcal M,\mathcal X_\text{mtr};\{\Theta\}_{E,D})\\ &-\alpha\nabla_\Theta\mathcal L_\text{update}(\mathcal M,\mathcal X_\text{mtr};\{\Theta\}_{E,U,G}) \end{aligned}

其中α\alpha是元训练步骤的学习率。

元测试

元测试目的是虚拟仿真测试网络在新数据上的性能,同时评价是否更新类别记忆的操作在跨领域中很好地工作。首先使用元训练图片的编码结果对记忆进行更新:

M=update(M,Xmtr;copy(ΘE),ΘU) \mathcal M'=\textbf{update}(\mathcal M,\mathcal X_\text{mtr};\mathrm{copy}(\Theta_E'),\Theta_U')

其中Xmtr\mathcal X_\text{mtr}是输入的元训练集图片,copy\mathrm{copy}表示冻结编码器,以防止异步梯度更新,得到的编码后特征图经过参数为ΘU\Theta_U'的更新网络来对记忆M\mathcal M进行更新。网络参数更新步骤为:

{Θ}E,U,D{Θ}E,U,DβΘLread(M,Xmte;{Θ}E,U,D) \{\Theta\}_\text{E,U,D}^*\leftarrow\{\Theta\}_{E,U,D}-\beta\nabla_\Theta\mathcal L_\text{read}(\mathcal M',\mathcal X_\text{mte};\{\Theta\}_\text{E,U,D}')

其中β\beta是元测试步骤的学习率。二阶梯度由式中第二项得到。下一次训练迭代步骤中的记忆初始化为:

M=update(M,Xmtr;copy({Θ}E,U)) \mathcal M^*=\textbf{update}(\mathcal M,\mathcal X_\text{mtr};\mathrm{copy}(\{\Theta\}_\text{E,U}^*))

元测试步骤的优化意在将来自元训练图片的领域可知特征写入到现有记忆单元中,并确保记忆指导的特征对于元测试图片的泛化性能。

记忆指导元学习训练伪代码

整个训练过程中,每步会随机划分为元训练集元测试集,并执行元训练元测试。元训练和普通的机器学习训练过程基本一致,读取训练集样本计算损失更新网络参数。首先采样一批样本,读取记忆矩阵M\mathcal M指导这批样本的语义分割,并计算读取损失Lread\mathcal L_\text{read},然后使用更新网络Up-Net更新记忆矩阵,并使用分类器GG计算更新损失Lupdate\mathcal L_\text{update},依据这两部分损失更新网络参数。元测试从模拟的新领域中获得对网络参数和记忆矩阵的反馈。因此首先选择的记忆矩阵应是领域可知的记忆矩阵,即基于现有的元训练集更新得到M\mathcal M',然后看其对于元测试集样本的指导性能。读取更新后的记忆矩阵对元测试集样本进行指导,并计算损失,更新网络参数,由更新后的网络参数再对进行更新,作为下一步元训练的初始记忆矩阵M\mathcal M^*

仿真实验

性能对比

使用的数据集包括真实数据集(Cityscapes、BDD100K、Mapillary、IDD)和合成数据集(GTAV、Synthia),评价指标采用所有类别上的mIoU

记忆引导元学习性能对比

消融实验

文章测试了损失函数、记忆更新策略和记忆学习框架对于性能的影响。结果为:

记忆指导元学习loss消融实验

记忆指导元学习记忆更新策略消融实验

记忆指导元学习记忆学习框架消融实验