summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/networks.py31
-rw-r--r--pretrained_models/download_cyclegan_model.sh13
-rw-r--r--test.py2
3 files changed, 45 insertions, 1 deletions
diff --git a/models/networks.py b/models/networks.py
index 3c54138..568f8c9 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -135,6 +135,8 @@ def define_D(input_nc, ndf, which_model_netD,
netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
elif which_model_netD == 'n_layers':
netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
+ elif which_model_netD == 'pixel':
+ netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' %
which_model_netD)
@@ -431,3 +433,32 @@ class NLayerDiscriminator(nn.Module):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
+
+class PixelDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
+ super(PixelDiscriminator, self).__init__()
+ self.gpu_ids = gpu_ids
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ self.net = [
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+ norm_layer(ndf * 2),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
+
+ if use_sigmoid:
+ self.net.append(nn.Sigmoid())
+
+ self.net = nn.Sequential(*self.net)
+
+ def forward(self, input):
+ if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
+ return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
+ else:
+ return self.net(input)
+
diff --git a/pretrained_models/download_cyclegan_model.sh b/pretrained_models/download_cyclegan_model.sh
new file mode 100644
index 0000000..91f0021
--- /dev/null
+++ b/pretrained_models/download_cyclegan_model.sh
@@ -0,0 +1,13 @@
+FILE=$1
+
+echo "Note: available models are horse2zebra, zebra2horse"
+
+echo "Specified [$FILE]"
+
+mkdir -p ./checkpoints/${FILE}_pretrained
+MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth
+URL=https://people.eecs.berkeley.edu/~taesung_park/pytorch-CycleGAN-and-pix2pix/models/$FILE.pth
+
+wget -N $URL -O $MODEL_FILE
+
+
diff --git a/test.py b/test.py
index 89feae9..fc0f1bb 100644
--- a/test.py
+++ b/test.py
@@ -27,7 +27,7 @@ for i, data in enumerate(dataset):
model.test()
visuals = model.get_current_visuals()
img_path = model.get_image_paths()
- print('process image... %s' % img_path)
+ print('%04d: process image... %s' % (i, img_path))
visualizer.save_images(webpage, visuals, img_path)
webpage.save()