参考博文:

1.pytorch搭建DCGAN

论文地址:https://arxiv.53yu.com/pdf/1511.06434.pdf

论文源码:略

文章引用源码:https://github.com/bubbliiiing/dcgan-pytorch

网络构建

DCGAN

DCGAN的全称是Deep Convolutional Generative Adversarial Networks,即深度卷积对抗生成网络。

它是由Alec Radford在论文Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks中提出的。

实际上它就是在GAN的基础上增加深度卷积网络结构。

论文中给出的DCGAN结构如图所示。其使用反卷积将特征层的高宽不断扩大,整体结构看起来像普通神经网络的逆过程。

生成网络的构建

对于生成网络来讲,它的目的是生成假图片,它的输入是正态分布随机数。输出是假图片。

在GAN当中,我们将这个正态分布随机数长度定义为100,在经过处理后,我们会得到一个(64,64,3)的假图片。

在处理过程中,我们会使用到反卷积,反卷积的概念是相对于正常卷积的,在正常卷积下,我们的特征层的高宽会不断被压缩;在反卷积下,我们的特征层的高宽会不断变大。

在DCGAN的生成网络中,我们首先利用一个全连接,将输入长条全连接到16,384(4x4x1024)这样一个长度上,这样我们才可以对这个全连接的结果进行reshape,使它变成(4,4,1024)的特征层。

在获得这个特征层之后,我们就可以利用反卷积进行上采样了。

在每次反卷积后,特征层的高和宽会变为原来的两倍,在四次反卷积后,我们特征层的shape变化是这样的:( 4 , 4 , 1024 ) − > ( 8 , 8 , 512 ) − > ( 16 , 16 , 256 ) − > ( 32 , 32 , 128 ) − > ( 64 , 64 , 3 )

此时我们再进行一次tanh激活函数,我们就可以获得一张假图片了。
实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import math
import torch
import torch.nn as nn

# 如果stride=2,就是宽高减半,下采样操作
def conv_out_size_same(size, stride):
return int(math.ceil(float(size) / float(stride)))


# 反卷积公式:H_out=(H_in −1)×stride[0]−2×padding[0]+dilation[0]×(kernel_size[0]−1)+output_padding[0]+1
class generator(nn.Module):
def __init__(self, d=128, input_shape=[64, 64]):
super(generator, self).__init__()
# 64, 64
s_h, s_w = input_shape[0], input_shape[1]
# 32, 32
s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
# 16, 16
s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
# 8, 8
s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
# 4, 4
self.s_h16, self.s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)
# (bs, 100)-> (bs, 4*4*128*8)
self.linear = nn.Linear(100, self.s_h16 * self.s_w16 * d * 8)
self.linear_bn = nn.BatchNorm2d(d * 8)

# (bs, 1024, 4, 4)->(bs, 512, 8, 8)
self.deconv1 = nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1)
self.deconv1_bn = nn.BatchNorm2d(d * 4)

# (bs, 512, 8, 8)->(bs, 256, 16, 16)
self.deconv2 = nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1)
self.deconv2_bn = nn.BatchNorm2d(d * 2)

# (bs, 256, 16, 16)->(bs, 128, 32, 32)
self.deconv3 = nn.ConvTranspose2d(d * 2, d, 4, 2, 1)
self.deconv3_bn = nn.BatchNorm2d(d)

# (bs, 128, 8, 8)->(bs, 3, 64, 64)
self.deconv4 = nn.ConvTranspose2d(d, 3, 4, 2, 1)

self.relu = nn.ReLU()

def weight_init(self):
for m in self.modules():
if isinstance(m, nn.ConvTranspose2d):
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(0.1, 0.02)
m.bias.data.fill_(0)
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)

def forward(self, x):
# (bs, 100)
bs, _ = x.size()
# (bs, 16*1024)
x = self.linear(x)
# (bs, 1024, 4, 4)
x = x.view([bs, -1, self.s_h16, self.s_w16])
x = self.relu(self.linear_bn(x))
# (bs, 1024, 4, 4)->(bs, 512, 8, 8)
x = self.relu(self.deconv1_bn(self.deconv1(x)))
# (bs, 512, 8, 8)->(bs, 256, 16, 16)
x = self.relu(self.deconv2_bn(self.deconv2(x)))
# (bs, 256, 16, 16)->(bs, 128, 32, 32)
x = self.relu(self.deconv3_bn(self.deconv3(x)))
# (bs, 128, 32, 32)->(bs, 3, 64, 64)
x = torch.tanh(self.deconv4(x))
return x

判别网络的构建

对于生成网络来讲,它的目的是生成假图片,它的输入是正态分布随机数。输出是假图片。

对于判别网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果。

判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。

判别网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import math


def conv_out_size_same(size, stride):
return int(math.ceil(float(size) / float(stride)))


class discriminator(nn.Module):
def __init__(self, d=128, input_shape=[64, 64]):
super(discriminator, self).__init__()
s_h, s_w = input_shape[0], input_shape[1]
s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
self.s_h16, self.s_w16 = conv_out_size_same(
s_h8, 2), conv_out_size_same(s_w8, 2)

# 64,64,3 -> 32,32,128
self.conv1 = nn.Conv2d(3, d, 4, 2, 1)

# 32,32,128 -> 16,16,256
self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
self.conv2_bn = nn.BatchNorm2d(d * 2)

# 16,16,256 -> 8,8,512
self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
self.conv3_bn = nn.BatchNorm2d(d * 4)

# 8,8,512 -> 4,4,1024
self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 2, 1)
self.conv4_bn = nn.BatchNorm2d(d * 8)

# 4,4,1024 -> 1
self.linear = nn.Linear(self.s_h16 * self.s_w16 * d * 8, 1)

self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
self.sigmoid = nn.Sigmoid()

def weight_init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(0.1, 0.02)
m.bias.data.fill_(0)
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)

def forward(self, x):
bs, _, _, _ = x.size()
# (3, 64, 64)->(128, 32, 32)
x = self.leaky_relu(self.conv1(x))
# (128, 32, 32)->(256, 16, 16)
x = self.leaky_relu(self.conv2_bn(self.conv2(x)))
# (256, 16, 16)->(512, 8, 8)
x = self.leaky_relu(self.conv3_bn(self.conv3(x)))
# (512, 8, 8)->(1024, 4, 4)
x = self.leaky_relu(self.conv4_bn(self.conv4(x)))
# (1024, 4, 4)->(bs, 16*1024)
x = x.view([bs, -1])
# (bs, 16*1024)->(bs, 1)
x = self.sigmoid(self.linear(x))

return x.squeeze()

训练思路

DCGAN的训练可以分为生成器训练和判别器训练,每一个step中一般先训练判别器,然后训练生成器

判别器的训练

在训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签。

因此判别器的训练步骤如下:

1、随机选取batch_size个真实的图片。
2、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。

生成器训练

在训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。

因此生成器的训练步骤如下:

1、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。
2、将虚假图片的Discriminator预测结果与1的对比作为loss对Generator进行训练(与1对比的意思是,让生成器根据判别器判别的结果进行训练)

利用DCGAN生成图片

详情见源码和参考博文