【生成对抗网络】GAN入门与代码实现(一)

人工智能31

文章目录

*
- 1. 生成对抗网络介绍
- 2. 基于TensorFlow2的GAN的简单实现
-
+ 2.1 导包与参数设置
+ 2.2 生成器
+ 2.3 判别器
+ 2.4 搭建生成对抗网络
+ 2.5 数据准备与预处理
+ 2.6 主训练方法
+ 2.7 绘图函数
+ 2.8 开始训练
+ 2.9 loss与acc绘图
+ 2.10 结果

生成对抗网络系列
【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(二)
【生成对抗网络】基于DCGAN的二次元人物头像生成(TensorFlow2)
【生成对抗网络】ACGAN的代码实现

1. 生成对抗网络介绍

生成对抗网络(Generative Adversarial Network)于2014年被Goodfellow等人提出,然后迅速流行。GAN能通过学习特定领域知识创造出新的图像、文本等。2016年,GAN热潮席卷人工智能领域顶级会议,从ICLR到NIPS,大量论文被发表和探讨。Yann LeCun曾评价GAN是"20年来机器学习领域最酷的想法"。

在GAN中主要由生成器G(generator)与判别器D(discriminator)构成。其中生成器用于生成逼真的假数据、判别器则需要在判别出真实数据与假数据,生成器与判别器相互博弈,在能力上有所提升,生成器生成的数据越来越像是真实的数据,判别器则能更好地将两者分辨出来,直到两者达到一种平衡。

假如以小狗图片作为生成的目标:

  • 生成器:接收随机噪声(随机变量)作为输入,并输出一张小狗的图片(假图片)。
    [En]

    Generator: receives a random noise (random variable) as input and outputs a picture of a puppy (fake picture).*

  • 鉴别器:区分原始的真实小狗图像和生成器生成的小狗图像,以确定谁是真的,谁是假的。
    [En]

    Discriminator: distinguish the original real puppy image from the puppy image generated by the generator to determine who is true and who is false.*

在模型训练的过程中:

​ 生成器:学习如何更好的将生成的小狗图片更加像真实,从而让判别器误认为是真实的。

​ 判别器:不断地将生成器生成的图片与真实的图片用于判别器模型的训练,提高自己的判别准确率。

GAN的整个训练过程如下:

  1. ​ 生成器接收随机噪声,并生成假图像;
  2. ​ 判别器接收假图像和真实图像组合的数据,学习如何判别真假图像;
  3. ​ 生成器生成新的图像,并使用判别器来判别真假,同时通过判别器来判别此次造假的水平;
  4. ​ 重复步骤 1-3。

2. 基于TensorFlow2的GAN的简单实现

我们以手写数据集MNIST为例进行演示。让GAN学习生成一些新的手写数字图片,每张图片的尺寸为28*28。
【生成对抗网络】GAN入门与代码实现(一)

代码实现步骤如下:

  1. 定义生成器,接收随机噪声,输出图像张量
  2. 定义判别器,接收图像张量,输出真假张量
  3. 定义生成对抗网络,接收随机噪声,输出真假张量。生成对抗网络由前面定义的生成器的模型层和判别器的模型创建( 它们共享权重),同时需要冻结判别器的权重。
  4. 将随机噪声输入生成器,生成一批图像
  5. 使用生成的图像与真实图像训练判别器(假图像的目标为0,真图像的目标为1)
  6. 使用新随机噪声输入生成对抗网络,输出真假(使生成的假图像判别为1),提高"造假"水平
  7. 重复4-6步骤

; 2.1 导包与参数设置

import numpy as np
import tensorflow as tf
from tensorflow import keras
import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
LATENT_DTM = 100
IMAGE_SHAPE = (28,28,1)

2.2 生成器

生成器接收随机向量,然后通过模型生成一张手写数字图片。

关键点:

  • 使用随机噪声作为输入,保证模型具有一定的随机性
  • 使用tanh作为最后一层的激活函数,可以获得更好的效果
  • 使用LeakyReLU激活函数来代替ReLU激活函数
generator_net = [
    keras.layers.Input(shape=(LATENT_DTM,)),
    keras.layers.Dense(256),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(512),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(1024),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(np.prod(IMAGE_SHAPE),activation='tanh'),
    keras.layers.Reshape(IMAGE_SHAPE)
]
generator = keras.models.Sequential(generator_net)

2.3 判别器

判别器是一个二分类问题,接收一个图片,输出真假。

discriminator_net =[
    keras.layers.Input(shape=IMAGE_SHAPE),
    keras.layers.Flatten(),
    keras.layers.Dense(512),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.Dense(256),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.Dense(1,activation='sigmoid')
]
discriminator = keras.models.Sequential(discriminator_net)

优化器:

optimizer = keras.optimizers.Adam(0.0002,0.5)

模型编译:

discriminator.compile(loss=keras.losses.binary_crossentropy,optimizer=optimizer,metrics=['acc'])

2.4 搭建生成对抗网络

将生成器与判别器组合在一起,同时冻结判别器的权重。

该过程将生成器生成的图片直接送入鉴别器模型,从而直接输出结果。在这个网络中,我们需要冻结鉴别器的权重,因为我们在过程中需要训练生成器,这样鉴别器的结果就会输出为真,从而不断提高生成器生成图像的水平,所以我们只需要训练生成器的层。

[En]

This process feeds the pictures generated by the generator directly into the discriminator model, thus directly outputting the results. In this network, we need to freeze the weight of the discriminator, because we need to train the generator in the process, so that the result of the discriminator is output as "true", so as to continuously improve the level of the image generated by the generator, so we only need to train the layer of the generator.


adversarial_net = generator_net + discriminator_net

for layer in discriminator_net:
    layer.trainable = False
adversarial = keras.models.Sequential(adversarial_net)

优化器:

optimizer = keras.optimizers.Adam(0.0002,0.5)

模型编译:

adversarial.compile(loss=keras.losses.binary_crossentropy,optimizer=optimizer,metrics=['acc'])

2.5 数据准备与预处理

加载keras中内置的手写数据集

(image_set,_),_ = keras.datasets.mnist.load_data()
image_set = image_set/127.5 - 1
image_set = image_set.reshape((image_set.shape[0],28,28,1))

准备训练过程中可视化的随机向量seed

num_example_to_generate = 6

seed = np.random.normal(0,1,(num_example_to_generate,LATENT_DTM))

用于记录训练过程中的准确率与损失


g_loss_list = []
d_loss_list = []

g_acc_list = []
d_acc_list = []

2.6 主训练方法

def train(batch = 30000,batch_size = 300):

    valid = np.ones((batch_size))
    fake = np.zeros((batch_size))

    batch_tqdm = tqdm.trange(batch)
    for index in batch_tqdm:

        idx = np.random.randint(0,image_set.shape[0],batch_size)
        imgs = image_set[idx]

        noise = np.random.normal(0,1,(batch_size,LATENT_DTM))

        gen_imgs = generator.predict(noise)

        d_state_real = discriminator.train_on_batch(imgs,valid)
        d_state_fake = discriminator.train_on_batch(gen_imgs,fake)

        d_state = 0.5*(np.add(d_state_real,d_state_fake))

        noise = np.random.normal(0,1,(batch_size,LATENT_DTM))

        adv_state = adversarial.train_on_batch(noise,valid)

        state = f"[D loss:{d_state[0]:.4f} acc: {d_state[1]:.4f}]" \
                f"[G loss:{adv_state[0]:.4f} acc: {adv_state[1]:.4f}]"
        batch_tqdm.set_postfix(state=state)

        g_loss_list.append(adv_state[0])
        g_acc_list.append(adv_state[1])
        d_loss_list.append(d_state[0])
        d_acc_list.append(d_state[1])

        if index%500 == 0:
            generate_plot_image(seed)

注意model的train_on_batch方法的使用。

2.7 绘图函数

用固定的noise绘制6张图片,以便观察训练效果。


def generate_plot_image(test_noise):

    pre_image = generator(test_noise,training = False)

    fig = plt.figure(figsize=(16,3))
    for i in range(pre_image.shape[0]):
        plt.subplot(1,6,i+1)
        plt.imshow((pre_image[i,:,:,:] + 1)/2)
        plt.axis('off')
    plt.show()

2.8 开始训练

训练30000个batch,每个batch随机拿出300个图片用于训练。

batch = 30000
batch_size = 300
train(batch,batch_size)

2.9 loss与acc绘图

损失Loss:

plt.plot(range(1, batch+1), g_loss_list, label='g_loss')
plt.plot(range(1, batch+1), d_loss_list, label='d_loss')
plt.legend()

准确率Acc:

plt.plot(range(1, batch+1), g_acc_list, label='g_acc')
plt.plot(range(1, batch+1), d_acc_list, label='d_acc')
plt.legend()

2.10 结果

可以看到生成器生成图片的效果越来越好

【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(一)

loss:

【生成对抗网络】GAN入门与代码实现(一)

acc:

【生成对抗网络】GAN入门与代码实现(一)

更新GAN的另一种实现方法:使用TensorFlow2中 求导机制进行自定义训练的GAN代码实现,可对比进行学习。
博客链接:【生成对抗网络】GAN入门与代码实现(二)

参考文献:《TensorFlow2实战》艾力

Original: https://blog.csdn.net/AwesomeP/article/details/124330150
Author: 宛如近在咫尺
Title: 【生成对抗网络】GAN入门与代码实现(一)



相关阅读

Title: (保姆教程)Spyder 配置Tensorflow(2.5.0)和keras(2.4.3)

(保姆教程)Spyder 配置Tensorflow(2.5.0)和keras(2.4.3)

前言

其实安装Tensorflow和keras的过程不难,但是寻找匹配的版本,以及使得Spyder适应花了自己很长时间!

安装Tensorflow

打开anaconda prompt

【生成对抗网络】GAN入门与代码实现(一)

; 首先查看自己的python版本

python --version

接着创建一个新的环境

conda create -n tf python=3.9.7

进入这个环境

conda activate tf

安装Tensorflow

pip install tensorflow==2.5.0

安装Keras

首先输入以下两行代码()

conda install mingw libpython
pip install theano

安装对应的Keras(版本一定要对应)

pip install keras==2.4.3

适应Spyder

在安装完成之后打开Spyder会发现其实用的解释器还是base,这个时候我们需要切换到tf中的python.exe解释器,但是切换之后重启内核,会显示内核错误!

【生成对抗网络】GAN入门与代码实现(一)

解决:

这个时候需要我们先在刚才的命令行中输入

conda install spyder

重新安装spyder,再次打开会发现可以正常运行

输入以下代码测试

import keras
from keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels) = mnist.load_data()
print('shape of train images is ',train_images.shape)
print('shape of train labels is ',train_labels.shape)
print('train labels is ',train_labels)
print('shape of test images is ',test_images.shape)
print('shape of test labels is',test_labels.shape)
print('test labels is',test_labels)

【生成对抗网络】GAN入门与代码实现(一)

成功运行!

但美中不足的是,使用tensorflow之后之前conda自带的很多库就用不了了,需要重新pip或者conda

Original: https://blog.csdn.net/qq_43102225/article/details/124833331
Author: skyoung13
Title: (保姆教程)Spyder 配置Tensorflow(2.5.0)和keras(2.4.3)

Original: https://blog.csdn.net/AwesomeP/article/details/124330150
Author: 宛如近在咫尺
Title: 【生成对抗网络】GAN入门与代码实现(一)



相关阅读

Title: Anaconda安装Tensorflow-GPU

1.0 安装前的准备

第一步:查看显卡支持的CUDA版本

如图我的显卡支持的最高版本为11.6
【生成对抗网络】GAN入门与代码实现(一)

【生成对抗网络】GAN入门与代码实现(一)

如下图
【生成对抗网络】GAN入门与代码实现(一)
对于我来说 我可以安装的CUDA是11.6以下的所有版本
我选择了 10.1CUDA,则对应的 7.6cuDNN以及 2.3tensorflow

; 2.0 开始安装

首先创建一个虚拟环境

我选择2.3tensorflow+3.8python

conda create -n tf2.3 python=3.8

然后激活环境进行下一步安装(切记!)

conda install cudatoolkit=10.1
conda install cudnn=7.6
pip install tensorflow==2.3

验证是否安装成功

python
import tensorflow as tf
tf.__version__

验证GPU是否可用

import tensorflow as tf
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

参考文章

conda安装tensorflow(python).

anaconda安装TensorFlow GPU版本(2021.11 3060亲测).

Original: https://blog.csdn.net/m0_56023586/article/details/123489752
Author: 阿昌涩涩发逗
Title: Anaconda安装Tensorflow-GPU

相关文章
R 聚类分析 人工智能

R 聚类分析

聚类分析 1. 数据描述 2. 调入数据,并对数据标准化。 3.系统聚类(类间距离为默认最长距离法) * 3.1. 分2类进行系统聚类,画系统聚类图,添加分类框,查看分类结果。 3.2.分3类进行系统...
如何使用Anaconda创建Tensorflow环境? 人工智能

如何使用Anaconda创建Tensorflow环境?

Tensorflow框架与Python之间存在着明确的版本对应关系,若安装版本不匹配,则后期会出现各种报错情况。为此,在安装之前要确定好所要安装的版本。若是不明确对应版本的同学,可参考这篇文章:Ten...
机器学习17 -- GAN 生成对抗网络 人工智能

机器学习17 — GAN 生成对抗网络

1 什么是GAN 1.1 组成部分:生成器和判别器 GAN诞生于2014年,由深度学习三巨头之一的Bengio团队提出。是目前为止机器学习中最令人兴奋的技术之一。目前有几百种不同构架的GAN,论文也是...