GAN生成对抗网络学习笔记

00. 概念

GAN,全称为 Generative Adversarial Nets,直译为生成式对抗网络,是一种非监督式模型。
2014年由Ian Goodfellow提出,业内另一位大牛 Yan Lecun 也对它交口称赞,称其为“20 年来机器学习领域最酷的想法” ,至今为止GAN依然是炙手可热的研究方向。
“生成对抗网络是一种生成模型(Generative Model),其背后基本思想是从训练库里获取很多训练样本,从而学习这些训练案例生成的概率分布。
而实现的方法,是让两个网络相互竞争,‘玩一个游戏’。其中一个叫做生成器网络( Generator Network),它不断捕捉训练库里真实图片的概率分布,将输入的随机噪声(Random Noise)转变成新的样本(也就是假数据)。另一个叫做判别器网络(Discriminator Network),它可以同时观察真实和假造的数据,判断这个数据到底是不是真的。”
— Ian Goodfellow

《GAN生成对抗网络学习笔记》 Ian Goodfellow
《GAN生成对抗网络学习笔记》 GAN模型原理图

01. 例子

生成sin 正玄曲线(使用keras)

import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
from keras.models import Model
from keras.layers import Input, Reshape
from keras.layers.core import Dense, Activation, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling1D, Conv1D
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam, SGD
from keras.callbacks import TensorBoard
  1. Generative model:
    输入:noise data
    输出:尝试生成真实的 sin 数据
def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
    x = Dense(dense_dim)(G_in)
    x = Activation('tanh')(x)
    G_out = Dense(out_dim, activation='tanh')(x)
    G = Model(G_in, G_out)
    opt = SGD(lr=lr)
    G.compile(loss='binary_crossentropy', optimizer=opt)
    return G, G_out
  1. Discriminative model:
    输出:识别此数据是真实的,还是由 Generative model 生成的
def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=50, conv_sz=5, leak=.2):
    x = Reshape((-1, 1))(D_in)
    x = Conv1D(n_channels, conv_sz, activation='relu')(x)
    x = Dropout(drate)(x)
    x = Flatten()(x)
    x = Dense(n_channels)(x)
    D_out = Dense(2, activation='sigmoid')(x)
    D = Model(D_in, D_out)
    dopt = Adam(lr=lr)
    D.compile(loss='binary_crossentropy', optimizer=dopt)
    return D, D_out
  1. chain the two models into a GAN:
    set_trainability 的作用是每次训练 generator 时要冻住 discriminator。
def set_trainability(model, trainable=False):
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable

def make_gan(GAN_in, G, D):
    set_trainability(D, False)
    x = G(GAN_in)
    GAN_out = D(x)
    GAN = Model(GAN_in, GAN_out)
    GAN.compile(loss='binary_crossentropy', optimizer=G.optimizer)
    return GAN, GAN_out
  1. Training:
    交替训练 discriminator 和 chained GAN,在训练 chained GAN 时要冻住 discriminator 的参数:

《GAN生成对抗网络学习笔记》 交替训练 discriminator 和 chained GAN

def sample_noise(G, noise_dim=10, n_samples=10000):
    X = np.random.uniform(0, 1, size=[n_samples, noise_dim])
    y = np.zeros((n_samples, 2))
    y[:, 1] = 1
    return X, y

def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):
    d_loss = []
    g_loss = []
    e_range = range(epochs)
    if verbose:
        e_range = tqdm(e_range)
    for epoch in e_range:
        X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, True)
        d_loss.append(D.train_on_batch(X, y))

        X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, False)
        g_loss.append(GAN.train_on_batch(X, y))
        if verbose and (epoch + 1) % v_freq == 0:
            print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(epoch + 1, g_loss[-1], d_loss[-1]))
    return d_loss, g_loss

d_loss, g_loss = train(GAN, G, D, verbose=True)
  1. Results:
N_VIEWED_SAMPLES = 2
data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).rolling(5).mean()[5:].plot()

《GAN生成对抗网络学习笔记》 训练结果图

  1. 后记

这篇文章是在微信公号里看到的,原链接如下:https://mp.weixin.qq.com/s/8vw5LpOPAnNKQmQ_ck-oWg
但是原文中的代码和描述并不完整,原作者是Robin Ricard的blog中翻译的,详细内容参见这篇文章:http://www.rricard.me/machine/learning/generative/adversarial/networks/keras/tensorflow/2017/04/05/gans-part2.html
我把代码重新整理了一下,做在 jupyter笔记中,这个是可以运行的代码(python3,tensorflow>=1.0,keras>=2.0),如果需要发邮件索取(我的邮箱582711548@qq.com)。

    原文作者:gaoshine
    原文地址: https://www.jianshu.com/p/881f15aa0cfa
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞