载入天数...载入时分秒... 总访问量次 🎉
生成对抗网络
1 基本概念介绍
在之前的网络中,输入 X 是已知的,期望它可以生成一个 Y,对于生成网络而言,输入 Z 的是一个随机采样,是从一个分布中来的(simple distribution)
Z 每次都不一样,但是需要够简单,需要知道满足的分布,每次输入时可以通过对分布采样得到 Z,每次输入的 Z 不同,输出的 Y 也不同,因此对于输出而言,也会得到一个复杂分布。此时 Network 就成为了一个 Generator
为什么需要输出是一个分布
- 由于数据集会有多种情况,因此如果不是分布(概率),因此会产生同时包含多种情况的图像(因为 NN 会生成一个同时接近多个原始图像的图像,其中每个都是合理的,但同时包含就是不合理的),因此对于输出而言,需要是一个概率,选择概率大的那个情况进行输出
- 可以生成一些比较有创造性的能力
2 GAN 的基本概念
Generative Adversarial Network 生成对抗网络
2-1 unconditional GAN
Generator
输入是 Z(normal distribution ),输出是 Y(distribution)
- Z 是从正态分布中 sample 出来的,低维向量。distribution 够简单即可
- Y 是高维向量
Discriminator
输入一张图像,输出是数值(scalar)越大表示越像真实图像
基本想法:
首先训练 Discriminator,使得它可以分辨出真实图像和 Generator 产生的图像,然后再训练 Generator 使他能够再次骗过 Discriminator……
算法:
初始化 generator 和 discriminator
在每个训练 epoch:
- 固定 G,训练 D。此时 G 是乱产生。将真实图像和 G 产生的图像输入到 D 中,使得 D 可以鉴别所有 G 产生的图像(分类和回归都可以)
- 固定 D,训练 G。使得 G 可以骗过 D,让 D 看不出是真实的还是 G 生成的。他的目标是让 D 的输出(真实值)越大越好
3 GAN 背后的理论
最大似然估计和最小 KL 距离之间的等价性
4 评估 GAN
评估 GAN 图像生成的质量(generator)的好坏:
- 人类评估,贵且不客观
- 将生成的图像输入到一个图像分类系统中,看输出的类别分布是否集中,如果集中,则效果比较好。但会遇到:
- Mode collapse (模型崩溃)问题,输出的图像质量不错,但是类别很少(找到了一个 Discriminator 的盲点)
- Mode Dropping 问题,输出的图像种类数是数据集的一个子集
解决办法:
种类多样性评估:
将每张图片都输入到一个 image classification 中,得到每个图像的分布情况,最后对所有的分布取平均值,得到最终的分布,如果最终的分布比较分散,则说明种类多。
对于每张图像,如果他的分布越平均,则说明 Quality 越差
对于所有图像分布的平均,如果越平均,则说明 diversity 越好
1 Inception Score:good quality,large diversity -> 大的 IS
2 Fretch Inception Distance
将生成的图像输入到 Inception Network 中,但输出是取的是 softmax 前面的 vector,他会是一个分布,假设真实的图像分布和生成图像的分布是 gaussian 分布,求它们之间的 Frechet distance 距离
不希望产生 Dataset 中的图像,而是希望产生新的图像(memory GAN)
5 Conditional GAN
对于先前的分布,输入的只是从 Gaussian Distribution 采样的数据,Conditional 的含义是还有另外的输入 X,可以指导生成的过程。如 text-to-image
此时需要对 Discriminator 进行修改,使得的输入变为 G 生成的图像 y 和文本 x
训练的数据集需要是 pair 类型,即(文本 x,对应满足的图像且质量好),只有这样的图像才为 1,其余的为 0
6 GAN in Unsupervised Learning
没有成对的训练资料,标注很困难且昂贵
图像风格迁移
输入是原始图像的分布,输出是二次元的分布。选一张图像作为输入,让 G 生成一张图像,D 的输入是 G 生成的图像和二次元图像,判断是否是真实的。
但是不能保证生成的图像具有原来图像的特征。如果使用 conditional gan,但是没有足够成对的数据供模型训练
可以使用 Cycle-GAN 实现
第一个生成器 G1 生成一张图像,第二个生成器 G2 将 G1 生成的图像还原回原来的图像。判别器 D1 判断是否是二次元图像。判别器 D2 判断 G2 生成的图像是否真的是原来的图像
- 存在的一个问题:没法从理论上保证中间生成的图像具有原来图像的特征,可能会学到一些奇怪的转化关系(眼镜 ->痣 ->眼镜)
- 实际上上面的问题不容易出现
starGAN
可以在多种风格之间做转化