diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-07-03 17:18:13 -0400 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-07-03 17:18:13 -0400 |
| commit | 233630e79d79901faff420eb0ae481b35d952f97 (patch) | |
| tree | 66b98747d7c0a97b37e2921ecbc378ae994aef35 | |
| parent | 11690eaffc7dcdc0f64267263f5d7a3b4fc735cf (diff) | |
fix instancenorm & batchnorm
| -rw-r--r-- | README.md | 4 | ||||
| -rw-r--r-- | models/networks.py | 40 | ||||
| -rw-r--r-- | models/pix2pix_model.py | 7 | ||||
| -rw-r--r-- | scripts/test_pix2pix.sh | 2 | ||||
| -rw-r--r-- | scripts/test_single.sh | 2 | ||||
| -rw-r--r-- | scripts/train_pix2pix.sh | 2 |
6 files changed, 31 insertions, 26 deletions
@@ -87,13 +87,13 @@ bash ./datasets/download_pix2pix_dataset.sh facades - Train a model: ```bash #!./scripts/train_pix2pix.sh -python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --use_dropout --no_lsgan +python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --use_dropout --no_lsgan --norm batch ``` - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/facades_pix2pix/web/index.html` - Test the model (`bash ./scripts/test_pix2pix.sh`): ```bash #!./scripts/test_pix2pix.sh -python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned +python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --norm batch ``` The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`. diff --git a/models/networks.py b/models/networks.py index 1a0bc1c..a2ddbdf 100644 --- a/models/networks.py +++ b/models/networks.py @@ -1,5 +1,7 @@ import torch import torch.nn as nn +from torch.nn import init +import functools from torch.autograd import Variable import numpy as np ############################################################################### @@ -11,19 +13,21 @@ def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) - elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1: + elif classname.find('BatchNorm2d') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) -def get_norm_layer(norm_type): + +def get_norm_layer(norm_type='instance'): if norm_type == 'batch': - norm_layer = nn.BatchNorm2d + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == 'instance': - norm_layer = nn.InstanceNorm2d + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) else: print('normalization layer [%s] is not found' % norm) return norm_layer + def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]): netG = None use_gpu = len(gpu_ids) > 0 @@ -137,7 +141,7 @@ class ResnetGenerator(nn.Module): self.gpu_ids = gpu_ids model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3), - norm_layer(ngf, affine=True), + norm_layer(ngf), nn.ReLU(True)] n_downsampling = 2 @@ -145,7 +149,7 @@ class ResnetGenerator(nn.Module): mult = 2**i model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), - norm_layer(ngf * mult * 2, affine=True), + norm_layer(ngf * mult * 2), nn.ReLU(True)] mult = 2**n_downsampling @@ -157,7 +161,7 @@ class ResnetGenerator(nn.Module): model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), - norm_layer(int(ngf * mult / 2), affine=True), + norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)] @@ -187,12 +191,12 @@ class ResnetBlock(nn.Module): # TODO: InstanceNorm conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), - norm_layer(dim, affine=True), + norm_layer(dim), nn.ReLU(True)] if use_dropout: conv_block += [nn.Dropout(0.5)] conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), - norm_layer(dim, affine=True)] + norm_layer(dim)] return nn.Sequential(*conv_block) @@ -215,7 +219,7 @@ class UnetGenerator(nn.Module): assert(input_nc == output_nc) # construct unet structure - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True) + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True) for i in range(num_downs - 5): unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout) unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer) @@ -226,7 +230,7 @@ class UnetGenerator(nn.Module): self.model = unet_block def forward(self, input): - if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -244,9 +248,9 @@ class UnetSkipConnectionBlock(nn.Module): downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1) downrelu = nn.LeakyReLU(0.2, True) - downnorm = norm_layer(inner_nc, affine=True) + downnorm = norm_layer(inner_nc) uprelu = nn.ReLU(True) - upnorm = norm_layer(outer_nc, affine=True) + upnorm = norm_layer(outer_nc) if outermost: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, @@ -303,9 +307,9 @@ class NLayerDiscriminator(nn.Module): nf_mult = min(2**n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, - kernel_size=kw, stride=2, padding=padw), + kernel_size=kw, stride=2, padding=padw), # TODO: use InstanceNorm - norm_layer(ndf * nf_mult, affine=True), + norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] @@ -313,9 +317,9 @@ class NLayerDiscriminator(nn.Module): nf_mult = min(2**n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, - kernel_size=kw, stride=1, padding=padw), + kernel_size=kw, stride=1, padding=padw), # TODO: useInstanceNorm - norm_layer(ndf * nf_mult, affine=True), + norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] @@ -327,7 +331,7 @@ class NLayerDiscriminator(nn.Module): self.model = nn.Sequential(*sequence) def forward(self, input): - if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index e44529b..3ab45fd 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -48,10 +48,11 @@ class Pix2PixModel(BaseModel): self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) - print('---------- Networks initialized -------------') - networks.print_network(self.netG) + print('---------- Networks initialized -------------') + networks.print_network(self.netG) + if self.isTrain: networks.print_network(self.netD) - print('-----------------------------------------------') + print('-----------------------------------------------') def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh index b821878..c6d6b32 100644 --- a/scripts/test_pix2pix.sh +++ b/scripts/test_pix2pix.sh @@ -1 +1 @@ -python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --use_dropout +python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --use_dropout --norm batch diff --git a/scripts/test_single.sh b/scripts/test_single.sh index 6157b29..1eb8580 100644 --- a/scripts/test_single.sh +++ b/scripts/test_single.sh @@ -1 +1 @@ -python test.py --dataroot ./datasets/facades/testB/ --name facades_pix2pix --model test --which_model_netG unet_256 --which_direction BtoA --dataset_mode single --use_dropout +python test.py --dataroot ./datasets/facades/testB/ --name facades_pix2pix --model test --which_model_netG unet_256 --which_direction BtoA --dataset_mode single --use_dropout --norm batch diff --git a/scripts/train_pix2pix.sh b/scripts/train_pix2pix.sh index bf45c84..88031dd 100644 --- a/scripts/train_pix2pix.sh +++ b/scripts/train_pix2pix.sh @@ -1 +1 @@ -python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --no_lsgan --use_dropout +python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --no_lsgan --use_dropout --norm batch |
