import math

import torch
import torch.nn as nn
def conv_out_size_same(size, stride):
    """Return the output spatial size of a stride-`stride` conv with 'same'
    padding: ceil(size / stride)."""
    return int(math.ceil(float(size) / float(stride)))


class discriminator(nn.Module):
    """DCGAN-style discriminator: four stride-2 convolutions (widths d, 2d,
    4d, 8d) followed by a linear layer and a sigmoid, mapping a 3-channel
    image to a per-sample real/fake probability.

    Args:
        d: base channel width of the first convolution.
        input_shape: (height, width) of the input images.
    """

    def __init__(self, d=128, input_shape=(64, 64)):
        super(discriminator, self).__init__()
        s_h, s_w = input_shape[0], input_shape[1]
        # Spatial size after each of the four stride-2 convolutions
        # (ceil division, matching 'same'-style downsampling).
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        self.s_h16, self.s_w16 = (conv_out_size_same(s_h8, 2),
                                  conv_out_size_same(s_w8, 2))

        self.conv1 = nn.Conv2d(3, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d * 4)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(d * 8)
        # Final features are (8d, s_h16, s_w16), flattened into one logit.
        self.linear = nn.Linear(self.s_h16 * self.s_w16 * d * 8, 1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid = nn.Sigmoid()

    def weight_init(self):
        """Initialize weights per the DCGAN recipe: conv/linear weights
        ~ N(0, 0.02); BatchNorm scale ~ N(1.0, 0.02) with zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                # Fixed: was normal_(0.1, 0.02) — a near-zero BN scale
                # crushes activations; the DCGAN reference uses mean 1.0.
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self, x):
        """Map images x of shape (batch, 3, H, W) to probabilities of
        shape (batch,) in [0, 1]."""
        bs = x.size(0)
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2_bn(self.conv2(x)))
        x = self.leaky_relu(self.conv3_bn(self.conv3(x)))
        x = self.leaky_relu(self.conv4_bn(self.conv4(x)))
        x = x.view(bs, -1)
        x = self.sigmoid(self.linear(x))
        # Fixed: squeeze(1) instead of squeeze() so a batch of 1 still
        # yields a 1-D tensor rather than a 0-d scalar.
        return x.squeeze(1)