admin管理员组

文章数量:1658762

GAN生成对抗网络:由两个子网络组成,generator和discriminator,在训练过程中,这两个子网络进行着最小最大值机制,generator用随机向量输出一个目标数据分布的样例,discriminator从目标样例中区分出生成器生成的样例。generator通过后向传播混淆discriminator,依此generator生成与目标样例相似的样例。

这篇论文中,将generator换成一个分割网络(可以是任意形式的分割网络,如:FCN,DeepLab,DilatedNet……,输入是H*W*3,依次是长宽,通道数,输出概率图为H*W*C,其中C是语义种类数),这个网络对输入的图片分割输出一个概率图,使得输出的概率图尽可能的接近ground truth。其中discriminator采用了全卷积网络(输入为generator或ground truth得到的概率图,输出位空间概率图H*W*1,其中其中像素点p代表这个来自gournd truth(p=1)还是generator(p=0)。

代码

训练中,用半监督机制,一部分是注解数据,一部分是无注解数据。
当用有注解数据时,分割网络由基于ground truth的标准交叉熵损失和基于鉴别器的对抗损失共同监督。注意,训练discriminator只用标记数据。

当用无注解数据时,用半监督方法训练分割网络,在从分割网络中获取未标记图像的初始分割预测后,通过判别网络对分割预测进行传递,得到一个置信图。我们反过来将这个置信图作为监督信号,使用一个自学机制来训练带masked交叉熵损失的分割网络。置信图表示了预测分割的质量。

对抗网络的半监督训练

输入图像 xn x n 大小为H*W*3, 分割网络表示为 s() s ( · ) ,预测概率图为 s(xn) s ( x n ) 大小为H*W*C。全卷积discriminator表示为 D() D ( · ) ,其输入有两种形式:分割预测 s(xn) s ( x n ) 和one-hot编码的gournd truth Yn Y n .

训练discriminator网络:

最小化空间交叉熵损失 LD L D ,其表示为:

LD=h,w(1yn)log(1(s(xn))(h,w))+ynlog(D(Yn)(h,w)) L D = − ∑ h , w ( 1 − y n ) l o g ( 1 − ( s ( x n ) ) ( h , w ) ) + y n l o g ( D ( Y n ) ( h , w ) )
当输入来自分割网络时 yn=0 y n = 0 ,若来自ground truth则为 yn=1 y n = 1 .
为了将ground truth转换为C通道的概率图,我们用one-hot机制进行编码,即如果像素 x(h,w)n x n ( h , w ) 输入类C,则取1,否则为0.

训练分割网络:

这里使用的损失是多任务损失:

Lseg=Lce+λadvLadv+λsemiLsemi L s e g = L c e + λ a d v L a d v + λ s e m i L s e m i
其中 Lce L c e Ladv L a d v Lsemi L s e m i 分别代表 multi-class cross entropy loss, the adversarial loss,和the semi-supervised loss,这里的 λadv λ a d v λsemi λ s e m i .
这里先考虑用有注解的数据,则:
Lce=h,wcϵCY(h,w,c)nlog(s(xn)(h,w,c)) L c e = − ∑ h , w ∑ c ϵ C Y n ( h , w , c ) l o g ( s ( x n ) ( h , w , c ) )
Ladv L a d v 表示为:
Ladv=h,wlog(D(S(XN))(h,w)) L a d v = − ∑ h , w l o g ( D ( S ( X N ) ) ( h , w ) )

用无标签数据训练

由于没有ground truth,因此这里不能使用 Lce L c e ,这里提出了用自学机制在无注解数据中利用被训练的discriminator,大意是被训练的discriminator可以生成一个置信图,即 D(S(Xn))(h,w) D ( S ( X n ) ) ( h , w ) ,这个公式用来推断预测结构足够接近gournd truth的区域。这里用一个阈值来二值化置信图, Y^=argmax(s(xn)) Y ^ = a r g m a x ( s ( x n ) ) ,使用二值化置信图,半监督损失可以定义为:

Lsemi=h,wcϵCI(D(S(xn))(h,w)>Tsemi)Y^(h,w,c)nlog(s(xn)(h,w,c)) L s e m i = − ∑ h , w ∑ c ϵ C I ( D ( S ( x n ) ) ( h , w ) > T s e m i ) ∙ Y ^ n ( h , w , c ) l o g ( s ( x n ) ( h , w , c ) )
其中 I() I ( ∙ ) 是指示函数, Tsemi T s e m i 是阈值,注意在训练期间,自学目标值 Y^n Y ^ n 和指示函数的值为常量,因此上式可以简单看做空间交叉熵损失。

本文标签: SemiLearningADVERSARIALSegmentationSemantic