1. 概述
生成对抗网络GAN(Generative adversarial nets)[1]是由Goodfellow等人于2014年提出的基于深度学习模型的生成框架,可用于多种生成任务。从名称也不难看出,在GAN中包括了两个部分,分别为”生成”和“对抗”,整两个部分也分别对应了两个网络,即生成网络(Generator)G和判别网络(Discriminator)D,为描述简单,以图像生成为例:
- 生成网络(Generator)G用于生成图片,其输入是一个随机的噪声z,通过这个噪声生成图片,记作G(z)
- 判别网络(Discriminator)D用于判别一张图片是否是真实的,对应的,其输入是一整图片x,输出D(x)表示的是图片x为真实图片的概率
在GAN框架的训练过程中,希望生成网络G生成的图片尽量真实,能够欺骗过判别网络D;而希望判别网络D能够把G生成的图片从真实图片中区分开。这样的一个过程就构成了一个动态的“博弈”。最终,GAN希望能够使得训练好的生成网络G生成的图片能够以假乱真,即对于判别网络D来说,无法判断G生成的网络是不是真实的。
综上,训练好的生成网络G便可以用于生成“以假乱真”的图片。
2. 算法原理
2.1. GAN的框架结构
GAN的框架是由生成网络G和判别网络D这两种网络结构组成,通过两种网络的“对抗”过程完成两个网络的训练,GAN框架由下图所示:
由生成网络G生成一张“Fake image”,判别网络D判断这张图片是否来自真实图片。
2.2. GAN框架的训练过程
在GAN的训练过程中,其最终的目标是使得训练出来的生成模型G生成的图片与真实图片具有相同的分布,其过程可通过下图描述[2]:
假设有一个先验分布pz(z),如上图中的unit gaussian,通过采样得到其中的一个样本点z。对于真实的图片,事先对于其分布是未知的,即上图中的p(x)未知。为了使得能与真实图片具有相同的分布,通过一个生成模型将先验分布映射到另一个分布,生成模型记为G(z;θg),其中θg为生成模型的参数,这里的生成模型可以是一个前馈神经网络MLP,θg便为该神经网络的参数。通过多次的采样,便可以刻画出生成的分布p^(x),此时需要计算其与真实的分布p(x)之间的相关性,即需要一个判别模型来定量表示两个分布之间的相关性,这里可以通过另一个前馈神经网络MLP,判别模型记为D(x;θd),其中D(x;θd)的输出是一个标量,表示的是x来自真实的分布,而不是来自于生成模型构造出的分布的概率。
对于这样的一个过程中,有两个模型,分别为生成模型G(z;θg)和判别模型D(x;θd),在GAN中,生成模型和判断模型分别对应了一个神经网络,以下都称为生成网络和判别模型。GAN希望的是对于判别网络,其能够正确判定数据是否来自真实的分布,对于生成网络,其能够尽可能使得生成的数据能够“以假乱真”,使得判别网络分辨不了。这样的训练过程是一个动态的“博弈”过程,通过交替训练,最终使得生成网络G生成的图片能够“以假乱真”,其具体过程如下图所示:
如上图(a)中,黑色的虚线表示的是从真实的分布px,绿色的实线表示的是需要训练的生成网络的生成的分布pg(G),蓝色的虚线表示的是判别网络,最下面的横线z表示的是从一个先验分布(如图中是一个均匀分布)采样得到的数据点,中间的横线x表示的真实分布,两条横线之间的对应关系表示的是生成网络将先验分布映射成一个生成分布pg(G)。从图(a)到图(d)表示了一个完整的交替训练过程,首先,如图(a)所示,当通过先验分布采样后的数据经过生成网络G映射后得到了图上绿色的实线代表的分布,此时判别网络D并不能区分数据是否来自真实数据,通过对判别网络的训练,其能够正确地判断生成的数据是否来自真实数据,如图(b)所示;此时更新生成网络G,通过对先验分布重新映射到新的生成分布上,如图(c)中的绿色实线所示。依次交替完成上述步骤,当达到一定迭代的代数后,达到一个平衡状态,此时pg=pdata,判别网络D将不能区分图片是否来自真实分布,且D(x)=21。
2.3. 价值函数
对于GAN框架,其价值函数V(G,D)为:
GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
其中,Ex∼pdata(x)[logD(x)]表示的是logD(x)的期望,同理,Ez∼pz(z)[log(1−D(G(z)))]表示的是log(1−D(G(z)))的期望。
假设从真实数据中采样m个样本{x(1),x(2),⋯,x(m)},从噪音分布pg(z)中同样采样m个样本,记为{z(1),z(2),⋯,z(m)},此时,上述价值函数可以近似表示为:
GminDmaxV(D,G)≈i=1∑m[logD(x(i))]+i=1∑m[log(1−D(G(z(i))))]
简化后为:
GminDmaxV(D,G)≈i=1∑m[logD(x(i))+log(1−D(G(z(i))))]
上述的交替训练过程如下流程所示:
3. GAN背后的数学原理
为了能够从数学的角度对上述过程做分析,首先对问题进行数学的描述:假设真实的数据分布为pdata(x),生成网络得到的分布为pg(x;θ),其中θ为生成网络的参数,现在需要找到一个θ∗使得pg(x;θ)≈pdata(x)。
3.1. 为什么会有这样的价值函数
由上可知,当生成网络G确定后,GAN的价值函数可以近似为:
DmaxV(D)≈i=1∑m[logD(x(i))]+i=1∑m[log(1−D(G(z(i))))]
其来源可以追溯到二分类的损失函数,对于一个二分类来说,通常选择交叉墒作为其损失函数,交叉墒的一般形式为:
J(θ)=−m1i=1∑m[y(i)logy^(i)+(1−y(i))log(1−y^(i))]
其中,y(i)表示的是真实的样本标签,y^(i)表示的是模型的预测值。对于GAN来说,样本分为两个部分,一个是来自真实的样本{(x(1),1),(x(2),1),⋯,(x(m),1)},将其带入到交叉墒的公式中(去除交叉墒的负号)为:
=m1i=1∑m[1⋅logD(x(i))+(1−1)log(1−D(x(i)))]=m1i=1∑m[logD(x(i))]
另一个则是来时生成模型{(G(z(1)),0),(G(z(2)),0),⋯,(G(z(m)),0)},将其带入到交叉墒的公式中(去除交叉墒的负号)为:
=m1i=1∑m[0⋅logD(G(z(i)))+(1−0)⋅log(1−D(G(z(i))))]=m1i=1∑m[log(1−D(G(z(i))))]
将两部分合并在一起,便是上述的价值函数。
3.2. KL散度
需要刻画两个分布是否相似,需要用到KL散度(KL divergence)。KL散度是统计学中的一个基本概念,用于衡量两个分布的相似程度,数值越小,表示两种概率分布越接近。对于离散的概率分布,定义如下:
DKL(P∥Q)=i∑P(i)logQ(i)P(i)
对于连续的概率分布,定义如下:
DKL(P∥Q)=∫−∞+∞p(x)logq(x)p(x)dx
3.3. 极大似然估计
极大似然估计(Maximum Likelihood Estimation),是一种概率论在统计学的应用,它是参数估计的方法之一。上述需要求解生成分布pg(x;θ)中的参数θ,需要用到极大似然估计。根据极大似然估计的方式,由于最终是希望生成的分布pg(x;θ)与原始的真实分布pdata(x),首先从真实分布pdata(x)采样m个数据点,记为{x(1),x(2),⋯,x(m)},根据生成的分布,得到似然函数为:
L=i=1∏mpg(x(i);θ)
取log后,得到等价的log似然:
L=i=1∑mlogpg(x(i);θ)
此时,θ∗为:
θ∗=θargmaxi=1∑mlogpg(x(i);θ)≈θargmaxEx∼pdata[logpg(x;θ)]=θargmax∫xpdata(x)logpg(x;θ)dx
对上述的公式做一些修改,增加一个与θ无关的项∫xpdata(x)logpdata(x)dx,这样并不改变对θ∗的求解,此时,公式变为:
θargmax∫xpdata(x)logpg(x;θ)dx−∫xpdata(x)logpdata(x)dx
将最大值求解变成最小值为:
θargmin∫xpdata(x)logpdata(x)dx−∫xpdata(x)logpg(x;θ)dx
通过积分公式的合并,得到:
θargmin∫xpdata(x)loglogpg(x;θ)pdata(x)dx
由KL散度可知,上述可以表示为:
θargminKL(pdata(x)∥pg(x;θ))
由此可以看出最小化KL散度等价于最大化似然函数。
3.4. 收敛性分析
当生成网络G确定后,价值函数可以表示为:
DmaxV(D)=∫xpdata(x)log(D(x))dx+∫zpz(z)log(1−D(G(z)))dz=∫xpdata(x)log(D(x))+pg(x)log(1−D(x))dx
由于上述的积分与D无关,上述可以简化成求解:
Dmax[pdata(x)log(D(x))+pg(x)log(1−D(x))]
求导数并令其为0,便可以得到最大的D:
D∗=pdata(x)+pg(x)pdata(x)
且D(x)∈[0,1],将其带入到价值函数中,可得
V(D∗,G)=Ex∼pdata(x)[logpdata(x)+pg(x)pdata(x)]+Ex∼pg(x)[log(1−pdata(x)+pg(x)pdata(x))]
对上式简化,可得:
V(D∗,G)=Ex∼pdata(x)[logpdata(x)+pg(x)pdata(x)]+Ex∼pg(x)[logpdata(x)+pg(x)pg(x)]=∫xpdata(x)logpdata(x)+pg(x)pdata(x)dx+∫xpg(x)logpdata(x)+pg(x)pg(x)dx
通过对分子和分母分别除2,可得:
V(D∗,G)=∫xpdata(x)log2pdata(x)+pg(x)21pdata(x)dx+∫xpg(x)log2pdata(x)+pg(x)21pg(x)dx=−2log2+KL(pdata(x)∥2pdata(x)+pg(x))+KL(pg(x)∥2pdata(x)+pg(x))
这里引入另一个符号:JS散度(Jensen-Shannon Divergence)
JSD(P∥Q)=21[KL(P∥M)+KL(Q∥M)]
其中,M=2P+Q。因此V(D∗,G)可以表示为:
V(D∗,G)=−2log2+2JSD(pdata(x)∥pg(x))
已知JS散度是一个非负值,且值域为[0,1],当两个分布相同时取0,不同时取1。对于V(D∗,G)的最小值为当JS=0时,即最小值是−2log2。此时pdata(x)=pg(x),此时求得的生成网络G生成的数据分布与真实的数据分布差异性最小,即GAN所要求的目标:pg=pdata。
3. 总结
生成对抗网络GAN中通过生成网络G和判别网络D之间的“生成”和“对抗”过程,通过多次的迭代,最终达到平衡,使得训练出来的生成网络G能够生成“以假乱真”的数据,判别网络D不能将其从真实数据中区分开。
参考文献
[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[J]. Advances in neural information processing systems, 2014, 27.
[2] Generative Models
[2] PyTorch 学习笔记(十):初识生成对抗网络(GANs)
[3] 通俗理解生成对抗网络GAN