GAN网络

基本概念

GAN(Generative Adversarial Network)生成对抗网络,由生成器G(Generator)和判别器(Discriminator)组成。生成器努力生成数据以通过判别器,判别器努力判别出生成数据的真假,最后它俩收敛。

这个模型很有想象力,像是结对进化,像是二人关系,像是阴阳,达到收敛就是达到了平衡

判别模型和生成模型

  • 有监督学习

    2种:

    决策函数:Y = f(x)

    条件概率分布:P(Y | X)

  • 判别方法与生成方法

    判别模型:由数据直接学习决策函数或条件概率的模型。(其他)

    生成模型:由数据学习联合概率P(X,Y),然后由P(Y|X) = P(X,Y)/ P(X)求出条件概率分布的模型。(朴素贝叶斯与隐马尔科夫)

    除此,CNN、RNN都属于生成模型,有结构的都是生成模型。

    对比:

    1. 生成模型可以做预测与分类,判别模型只能分类
    2. 存在隐变量时,仍可以用生成模型,不能用判别模型
    3. 生成模型从大量的数据中找规律,而判别模型只关心不同类型数据的差别,利用差别来分类。

结构

GAN由一对生成器与判别器组成,生成器G捕捉样本数据的分布,追求效果越像真实样本越好;判别器D是一个分类器,估计样本来来自于训练数据的概率,来自训练数据则输出大概率,反之小概率。

论文:Generative Adversarial Nets

流程

参考

  1. 初始化判别器D的参数 θd 和生成器G的参数 θg 。
  2. 从真实样本中采样 m 个样本 { x1, x2, ..., xm } ,从先验分布噪声中采样 m 个噪声样本 { z1, z2, ..., zm } 并通过生成器获取 m 个生成样本 { x~1, x~2, ..., x~m } 。固定生成器G,训练判别器D尽可能好地准确判别真实样本和生成样本,尽可能大地区分正确样本和生成的样本。
  3. 循环k次更新判别器之后,使用较小的学习率来更新一次生成器的参数,训练生成器使其尽可能能够减小生成样本与真实样本之间的差距,也相当于尽量使得判别器判别错误。
  4. 多次更新迭代之后,最终理想情况是使得判别器判别不出样本来自于生成器的输出还是真实的输出。亦即最终样本判别概率均为0.5。

之所以要训练k次判别器,再训练生成器,是因为要先拥有一个好的判别器,使得能够教好地区分出真实样本和生成样本之后,才好更为准确地对生成器进行更新。更直观的理解可以参考下图:

图中

黑色虚线表示真实的样本的分布情况,

蓝色虚线表示判别器判别概率的分布情况,

绿色实线表示生成样本的分布。

Z 表示噪声, Z 到 x 表示通过生成器之后的分布的映射情况。目标是使用生成样本分布(绿色实线)去拟合真实的样本分布(黑色虚线),来达到生成以假乱真样本的目的。

目标函数

目标函数如下所示:

  • 里的 V(G,D)相当于表示真实样本和生成样本的差异程度。

  • 先看 。这里的意思是固定生成器G,尽可能地让判别器能够最大化地判别出样本来自于真实数据还是生成的数据。

  • 然后min,这里是在固定判别器D的条件下得到生成器G,这个G要求能够最小化真实样本与生成样本的差异。

  • 通过上述min max的博弈过程,理想情况下会收敛于生成分布拟合于真实分布。

  1. 这个目标函数要从交叉熵开始:

    pi和qi为真实的样本分布和生成器的生成分布

  2. 判别器是一个二分类问题(是不是来自训练样本),交叉熵可以写成:

    y1为正确样本分布,(1-y1)就是生成样本分布。D代表判别器,D(x1)代表判别样本为正确的概率,(1-D(x1))代表判别为错误样本的概率。

  3. 推广到N个样本

  4. 写成期望形式

    image-20220812181146015

    这一步的1/2没看懂,原文:

    我觉得反过来想吧,交叉熵的p*logq形式就是为了让q的分布去接近p的分布。

    我们最终想要达到的目的是是判别器无法正确判别,也就是判别概率为1/2。所以其实这一整个过程就相当于让D(x)和D(G(z))判别的概率去接近1/2,也就是上面yi取1/2那个式子。这里的1/2应该看成是判别的概率吧。所以其实是想说明下面的minmaxV(G, D)就相当于上述的拟合到1/2这个过程。

    而上面说的标签为0,1应该算是V(G, D)。原始论文中也有说明最后这个博弈过程有一个最优点,也就是收敛于Pg = Pdata也就是等于1/2了。我想应该是这样

https://zhuanlan.zhihu.com/p/33752313

https://developers.google.com/machine-learning/gan/loss

缺点

  1. GAN采用对抗学习的准则,理论上还不能判断模型的收敛性和均衡点的存在。训练过程需要保证两个对抗网络的平衡和同步,否则难以得到很好的训练效果。而实际过程汇总两个对抗网络的同步不易把控,训练过程可能不稳定。
  2. GAN生成样本容易具有多样性,但是多模态样本容易导致崩溃模式现象。

发展

GAN基础上发展出来一些新的GAN模型,如WGAN、WGAN-GP、CGAN等

WGAN

GAN中,最开始样本数据与生成数据之间有一定的重叠,如果不重叠JS散度会趋近一个常数,而常数的梯度接近于0,这就导致梯度小时。

WGAN重新定义了一种Wasserstein-1距离代替原来的JS散度。

Wasserstein-1

使得即使Pr、Pg补不重叠,依然可以清楚的反应两个分布的距离。详细

目标函数变为

image-20220811174658489

更进一步是WGAN-GP,解决了训练不稳定和崩溃问题。

https://neptune.ai/blog/gan-loss-functions

有必要再增进一下损失函数了,JS散度、Wasserstein-1距离等内容

CGAN

GAN网络只能随机产生一个类别,这对于需要指定类别的场景不候好,于是就有了CGAN,能够指定类别来生成。它的模型如下:

加入了一个条件y。相应的目标函数变为:

image-20220811175151282

与GAN相比,增加了控制的类别

应用

Pix2Pix

把一张图片变成另外一种需要的图片,图片映射。

不再输入高斯噪声,而是输入一张图片。

图像作为一种信息媒介,可以有很多种表达方式,比如灰度图、彩色图、素描图、梯度图。图像翻译就是指这些图像之间的变换。比如已知灰度图生成一张彩色图。这个任务使用基于CGAN的变体-Pix2Pix来完成。

Pix2Pix将生成器看做一种映射,即将图片映射成另一张需要的图片,表示Map pixels to pixels的意思。

生成器将原始图片x作为条件输入,输出的是转换之后的图片。判别器输入的是转换后的图片或真实图片。判别器也将原图输入,会极大提高实验的记过。

CAAE

针对人脸变年轻或者衰老的条件GAN

https://github.com/zzutk/Face-Aging-CAAE

其他

Keras的使用:

  1. Build:构建模型
  2. Compile:定义学习过程
  3. Fit:准备数据+训练
  4. Evaluate/Predict:查看结果

源码:1:04:53

​ 1:19:02

  • 一强一弱最终会导致不收敛
  • 多模型参与如何
  • 能否将判别算法用其他形式替代

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×