diff options
| -rw-r--r-- | models/networks.py | 31 | ||||
| -rw-r--r-- | pretrained_models/download_cyclegan_model.sh | 13 | ||||
| -rw-r--r-- | test.py | 2 |
3 files changed, 45 insertions, 1 deletions
diff --git a/models/networks.py b/models/networks.py index 3c54138..568f8c9 100644 --- a/models/networks.py +++ b/models/networks.py @@ -135,6 +135,8 @@ def define_D(input_nc, ndf, which_model_netD, netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) elif which_model_netD == 'n_layers': netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'pixel': + netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) @@ -431,3 +433,32 @@ class NLayerDiscriminator(nn.Module): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) + +class PixelDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]): + super(PixelDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + if use_sigmoid: + self.net.append(nn.Sigmoid()) + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + return nn.parallel.data_parallel(self.net, input, self.gpu_ids) + else: + return self.net(input) + diff --git a/pretrained_models/download_cyclegan_model.sh b/pretrained_models/download_cyclegan_model.sh new file mode 100644 index 0000000..91f0021 --- /dev/null +++ b/pretrained_models/download_cyclegan_model.sh @@ -0,0 +1,13 @@ +FILE=$1 + +echo "Note: available models are horse2zebra, zebra2horse" + +echo "Specified [$FILE]" + +mkdir -p ./checkpoints/${FILE}_pretrained +MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth +URL=https://people.eecs.berkeley.edu/~taesung_park/pytorch-CycleGAN-and-pix2pix/models/$FILE.pth + +wget -N $URL -O $MODEL_FILE + + @@ -27,7 +27,7 @@ for i, data in enumerate(dataset): model.test() visuals = model.get_current_visuals() img_path = model.get_image_paths() - print('process image... %s' % img_path) + print('%04d: process image... %s' % (i, img_path)) visualizer.save_images(webpage, visuals, img_path) webpage.save() |
