diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-11-08 11:37:42 -0800 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-11-08 11:37:42 -0800 |
| commit | b6f5966eb8224dfc7be68b1b67a87f006e42730d (patch) | |
| tree | c34a3f4746f1453fd83aeeeb2135eb7b6f0afb63 /models/networks.py | |
| parent | 5e0f7d6980ed1a1aaac8593351028d320e5f0a94 (diff) | |
working version with handwritten GAN loss. Shift value can be changed
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/models/networks.py b/models/networks.py index 2df58fe..965bacb 100644 --- a/models/networks.py +++ b/models/networks.py @@ -136,6 +136,8 @@ def define_D(input_nc, ndf, which_model_netD, netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) elif which_model_netD == 'n_layers': netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'pixel': + netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) @@ -432,3 +434,32 @@ class NLayerDiscriminator(nn.Module): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) + +class PixelDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]): + super(PixelDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + if use_sigmoid: + self.net.append(nn.Sigmoid()) + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + return nn.parallel.data_parallel(self.net, input, self.gpu_ids) + else: + return self.net(input) + |
