diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-07-03 17:18:13 -0400 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-07-03 17:18:13 -0400 |
| commit | 233630e79d79901faff420eb0ae481b35d952f97 (patch) | |
| tree | 66b98747d7c0a97b37e2921ecbc378ae994aef35 /models/networks.py | |
| parent | 11690eaffc7dcdc0f64267263f5d7a3b4fc735cf (diff) | |
fix instancenorm & batchnorm
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 40 |
1 files changed, 22 insertions, 18 deletions
diff --git a/models/networks.py b/models/networks.py index 1a0bc1c..a2ddbdf 100644 --- a/models/networks.py +++ b/models/networks.py @@ -1,5 +1,7 @@ import torch import torch.nn as nn +from torch.nn import init +import functools from torch.autograd import Variable import numpy as np ############################################################################### @@ -11,19 +13,21 @@ def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) - elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1: + elif classname.find('BatchNorm2d') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) -def get_norm_layer(norm_type): + +def get_norm_layer(norm_type='instance'): if norm_type == 'batch': - norm_layer = nn.BatchNorm2d + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == 'instance': - norm_layer = nn.InstanceNorm2d + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) else: print('normalization layer [%s] is not found' % norm) return norm_layer + def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]): netG = None use_gpu = len(gpu_ids) > 0 @@ -137,7 +141,7 @@ class ResnetGenerator(nn.Module): self.gpu_ids = gpu_ids model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3), - norm_layer(ngf, affine=True), + norm_layer(ngf), nn.ReLU(True)] n_downsampling = 2 @@ -145,7 +149,7 @@ class ResnetGenerator(nn.Module): mult = 2**i model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), - norm_layer(ngf * mult * 2, affine=True), + norm_layer(ngf * mult * 2), nn.ReLU(True)] mult = 2**n_downsampling @@ -157,7 +161,7 @@ class ResnetGenerator(nn.Module): model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), - norm_layer(int(ngf * mult / 2), affine=True), + norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)] @@ -187,12 +191,12 @@ class ResnetBlock(nn.Module): # TODO: InstanceNorm conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), - norm_layer(dim, affine=True), + norm_layer(dim), nn.ReLU(True)] if use_dropout: conv_block += [nn.Dropout(0.5)] conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), - norm_layer(dim, affine=True)] + norm_layer(dim)] return nn.Sequential(*conv_block) @@ -215,7 +219,7 @@ class UnetGenerator(nn.Module): assert(input_nc == output_nc) # construct unet structure - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True) + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True) for i in range(num_downs - 5): unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout) unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer) @@ -226,7 +230,7 @@ class UnetGenerator(nn.Module): self.model = unet_block def forward(self, input): - if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -244,9 +248,9 @@ class UnetSkipConnectionBlock(nn.Module): downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1) downrelu = nn.LeakyReLU(0.2, True) - downnorm = norm_layer(inner_nc, affine=True) + downnorm = norm_layer(inner_nc) uprelu = nn.ReLU(True) - upnorm = norm_layer(outer_nc, affine=True) + upnorm = norm_layer(outer_nc) if outermost: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, @@ -303,9 +307,9 @@ class NLayerDiscriminator(nn.Module): nf_mult = min(2**n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, - kernel_size=kw, stride=2, padding=padw), + kernel_size=kw, stride=2, padding=padw), # TODO: use InstanceNorm - norm_layer(ndf * nf_mult, affine=True), + norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] @@ -313,9 +317,9 @@ class NLayerDiscriminator(nn.Module): nf_mult = min(2**n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, - kernel_size=kw, stride=1, padding=padw), + kernel_size=kw, stride=1, padding=padw), # TODO: useInstanceNorm - norm_layer(ndf * nf_mult, affine=True), + norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] @@ -327,7 +331,7 @@ class NLayerDiscriminator(nn.Module): self.model = nn.Sequential(*sequence) def forward(self, input): - if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) |
