生成对抗网络(GAN)原理和实现

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。论文《Generative Adversarial Nets》首次提出GAN。


GAN的思想

    GAN由生成器G和判别器D组成。生成器G根据输入先验分布的随机向量(一般使用随机分布,论文里用的是高斯分布)得到符合数据集数据分布。判别器D判别输入数据来源于G还是真实数据。框架如下:

1.png

    GAN的训练过程:刚开始G和D里面的参数随机初始化,第一步使用真实图片训练D,则D能轻易判断G生成的图片和真实图片。接着训练G,使得G生成的图片更加逼真,直到D无法判断G生成的图片和真实图片。接着训练D,使D能轻易判断真实图片。。。以此类推,最终G生成的图片和真实图片很相似。就好像论文里的例子,G就像货币的伪造者,D就像警察,G造假币,D识别假币,两者互相对抗。G造的假币越来越逼真,D识别的手段越来越高超。最终G生成的东西跟真的差不多。过程示意:

2.png


目标函数

    从本质上看,GAN的训练目标就是使G恢复训练数据的数据分布。数据分布可以理解为数据的概率函数。

    用更数学的表示,G将先验分布(比如高斯分布)P_prior(z)中的z通过G映射到x,即x的分布为P_G(x,theta),theta即G的参数。将生成器分布P_G(x,theta)和真实数据分布P_data(x)对比即可得到损失loss,如下图:

3.png

    判别器D评估P_G(x)和已知数据分布P_data(x)的差异,即判别输入的x是来自真实分布还是生成器。

    根据GAN的思想,其优化过程可以表示为以下公式:

4.png

    将上式拆分,得到D和G的目标函数。

    优化D,D的目标是将G(z)判断为真的概率D(G(z))尽可能小,即1-D(G(z))尽可能大,且将真实数据x判断为真的概率D(x)尽可能大。因此得公式如下:

5.jpg

    优化G,G的目标是令D判断为真的概率尽可能大,即D(G(z))尽可能大。因此得公式如下:

6.jpg

    实际训练G的时候,早期要求V的初始斜率大,因此需要替换V:

7.png

    实际的数据是离散的,因此计算分布的期望是通过采样计算得到的。论文里提出,迭代的优化k步D和1步G。这可以让D保持在最优解附近,可得迭代优化参数的算法:

8.png


理论证明

    以下的理论分析将证明GAN的目标函数有一个最优解p_g=p_data。

    对于目标函数:

    9.png

    上式是求两个期望的相加,等价于:

10.png

    我们想要在找到最优的D*,使得V(G,D)最大:

11.png

    求V(G,D)等价与求以下公式最大:

12.png

    上式中, P_data(x)是一个常量,表示x对应的概率分布中的值,这里设为a,P_G(x)也是如此,设为b。因此可以对上式进行求导,即可得到D*,过程如下:

13.png

代入D目标函数,得:

14.png

    当且仅当p_g=p_data,C(G)取得最大值,此时C(G)=-log4,如下:

    15.png

    将D*代入V(G,D)的积分表达式,得:

16.png

    分子分母同时除以2。再将1/2提出来,且17.png18.png等于1,则:

19.png

    上式的KL是KL散度。KL散度:相对熵(relative entropy),又被称为Kullback-Leibler散度(Kullback-Leibler divergence)或信息散度(information divergence),是两个概率分布(probability distribution)间差异的非对称性度量 [1]  。在在信息理论中,相对熵等价于两个概率分布的信息熵(Shannon entropy)的差值。离散和连续随机变量的公式如下:

20.png

    继续得到:

21.png

    JSD是Jesen-Shannon散度。由于两个分布之间的Jensen-Shannon散度总是非负的,只有当它们相等时才为零,所以我们已经证明了C∗=−log(4)是C(G)的全局最小值,而唯一的解决方案是p_g = p_data,即,生成模型完美地复制了数据生成过程。


代码实现

    论文里生成器使用全连接网络,使用relu和sigmoid激活函数;判别器也是全连接网络,使用maxout激活函数,同时应用了dropout。生成器的输入是符合高斯分布的随机向量,theta值根据交叉验证得到。在MNIST、TFD和CIFAR-10上面测试。

    我这里的代码也是参照了网上开源的代码,判别器和生成器均使用relu和sigmoid激活函数。经过我的测试,发现每轮迭代的时候,生成器应该要比判别器训练更多,否则会发散,这跟论文里的描述相反。

    网络的计算图如下:

graph.jpg

    代码如下:

定义参数

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import name_scope as namespace
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# 训练参数
num_steps = 20000
batch_size = 128
learning_rate = 0.0002

# 网络参数
image_dim = 784  # 28*28
gen_hidden_dim = 256
disc_hidden_dim = 256
noise_dim = 100 # Noise data points

k=1

# 保存隐藏层的权重和偏置,用于变量共享
with namespace('var'):
    weights = {
        'gen_hidden1': tf.Variable(tf.truncated_normal([noise_dim, gen_hidden_dim],stddev=0.1)),
        'gen_out': tf.Variable(tf.truncated_normal([gen_hidden_dim, image_dim],stddev=0.1)),
        'disc_hidden1': tf.Variable(tf.truncated_normal([image_dim, disc_hidden_dim],stddev=0.1)),
        'disc_out': tf.Variable(tf.truncated_normal([disc_hidden_dim, 1],stddev=0.1)),
    }
    biases = {
        'gen_hidden1': tf.Variable(tf.zeros([gen_hidden_dim])),
        'gen_out': tf.Variable(tf.zeros([image_dim])),
        'disc_hidden1': tf.Variable(tf.zeros([disc_hidden_dim])),
        'disc_out': tf.Variable(tf.zeros([1])),
    }

定义网络和优化器

# 生成网络
def generator(x):
    with namespace('gen_hidden1'):
        hidden_layer = tf.matmul(x, weights['gen_hidden1'])
        hidden_layer = tf.add(hidden_layer, biases['gen_hidden1'])
        hidden_layer = tf.nn.relu(hidden_layer)
    with namespace('gen_out'):
        out_layer = tf.matmul(hidden_layer, weights['gen_out'])
        out_layer = tf.add(out_layer, biases['gen_out'])
        out_layer = tf.nn.sigmoid(out_layer)
    return out_layer

# 判别网络
def discriminator(x):
    with namespace('disc_hidden1'):
        hidden_layer = tf.matmul(x, weights['disc_hidden1'])
        hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])
        hidden_layer = tf.nn.relu(hidden_layer)
    with namespace('disc_output'):
        out_layer = tf.matmul(hidden_layer, weights['disc_out'])
        out_layer = tf.add(out_layer, biases['disc_out'])
        out_layer = tf.nn.sigmoid(out_layer)
    return out_layer

# 网络输入
gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise')
disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name='disc_input')
# 创建生成网络
with namespace('generator'):
    gen_sample = generator(gen_input)
# 创建两个判别网络 (一个来自噪声输入, 一个来自生成的样本)
with namespace('discriminator'):
    with namespace('discriminator_real'):
        disc_real = discriminator(disc_input)
    with namespace('discriminator_fake'):
        disc_fake = discriminator(gen_sample)
    
with namespace('loss'):
    # 定义损失函数
    gen_loss = -tf.reduce_mean(tf.log(disc_fake))
    disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake))
#将变量的损失值写入Loss
tf.summary.scalar('gen_loss', gen_loss)
tf.summary.scalar('disc_loss', disc_loss)
merged_summary = tf.summary.merge_all()
    
with namespace('train'):  
    # 定义优化器
    optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)
    optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # 训练每个优化器的变量
    # 生成网络变量
    gen_vars = [weights['gen_hidden1'], weights['gen_out'],biases['gen_hidden1'], biases['gen_out']]
    # 判别网络变量
    disc_vars = [weights['disc_hidden1'], weights['disc_out'],biases['disc_hidden1'], biases['disc_out']]
    # 最小损失函数
    train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
    train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)
# 初始化变量
init = tf.global_variables_initializer()

训练网络

def getData(batch_size=128):
    batch_x, _ = mnist.train.next_batch(batch_size)# 准备数据
    z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])# 产生噪声给生成网络
    return batch_x,z
    

# 开始训练
with tf.Session() as sess:
    sess.run(init)
    writer=tf.summary.FileWriter('D:/Jupyter/GAN/mnist_gan_train_log/log',sess.graph)
    saver=tf.train.Saver()
    
    for i in range(1, num_steps+1):
        for j in range(k):
            x,z=getData()
            _,dl = sess.run([train_disc, disc_loss], feed_dict={disc_input: x, gen_input: z})
        x,z=getData()
        _,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})
        x,z=getData()
        _,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})
        x,z=getData()
        _,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})
        
        x,z=getData()
        summary,g = sess.run([merged_summary,gen_sample], feed_dict={disc_input:x,gen_input:z})
        writer.add_summary(summary,i)#写summary和i到文件
        
        if i % 1000 == 0 or i == 1:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))

    # 使用生成器网络从噪声生成图像
    f, a = plt.subplots(4, 10, figsize=(10, 4))
    for i in range(10):
        
        # 噪声输入.
        z = np.random.uniform(-1., 1., size=[4, noise_dim])
        g = sess.run([gen_sample], feed_dict={gen_input: z})
        g = np.reshape(g, newshape=(4, 28, 28, 1))
        # 将原来黑底白字转换成白底黑字,更好的显示
        g = -1 * (g - 1)
        for j in range(4):
            # 从噪音中生成图像。 扩展到3个通道,用于matplotlib
            img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),newshape=(28, 28, 3))
            a[j][i].imshow(img)

    plt.savefig('test.png')#保存图片
    #f.show()
    #plt.draw()
    #plt.waitforbuttonpress()

    生成的结果:

result.jpg

    训练过程的损失值:

loss.jpg

    从损失曲线可以看到,GAN的训练过程是不稳定的。


参考文献

[1]Ian J. Goodfellow,etc.Generative Adversarial Nets.2014.arXiv:1406.2661v1

[2]小白的成长. GAN之V(D,G)函数. https://blog.csdn.net/qq_42413820/article/details/80673857. 2018-06-13

首页 所有文章 机器人 计算机视觉 自然语言处理 机器学习 编程随笔 关于