author    junyanz <junyanz@berkeley.edu>    2017-07-03 17:18:13 -0400
committer junyanz <junyanz@berkeley.edu>    2017-07-03 17:18:13 -0400
commit    233630e79d79901faff420eb0ae481b35d952f97 (patch)
tree      66b98747d7c0a97b37e2921ecbc378ae994aef35 /models/networks.py
parent    11690eaffc7dcdc0f64267263f5d7a3b4fc735cf (diff)
fix instancenorm & batchnorm
Diffstat (limited to 'models/networks.py')
-rw-r--r--  models/networks.py  40
1 file changed, 22 insertions(+), 18 deletions(-)
diff --git a/models/networks.py b/models/networks.py
index 1a0bc1c..a2ddbdf 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
+from torch.nn import init
+import functools
from torch.autograd import Variable
import numpy as np
###############################################################################
@@ -11,19 +13,21 @@ def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
- elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
+ elif classname.find('BatchNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
-def get_norm_layer(norm_type):
+
+def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
- norm_layer = nn.BatchNorm2d
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
- norm_layer = nn.InstanceNorm2d
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
print('normalization layer [%s] is not found' % norm_type)
return norm_layer
+
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
netG = None
use_gpu = len(gpu_ids) > 0
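The substance of this hunk: get_norm_layer now returns a functools.partial with the affine flag pre-bound (affine=True for BatchNorm2d, affine=False for InstanceNorm2d), so every call site below can invoke norm_layer(num_features) without repeating the flag. It also motivates the weights_init change above: with affine=False, InstanceNorm2d carries no weight or bias tensors, so the init hook must stop touching it. A minimal standalone sketch of the pattern (illustration only, not part of the commit; the function name is hypothetical):

import functools
import torch.nn as nn

def make_norm_layer(norm_type='instance'):
    # Bind the affine flag once; callers then pass only the channel count.
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

norm_layer = make_norm_layer('batch')
bn = norm_layer(64)  # equivalent to nn.BatchNorm2d(64, affine=True)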
@@ -137,7 +141,7 @@ class ResnetGenerator(nn.Module):
self.gpu_ids = gpu_ids
model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
- norm_layer(ngf, affine=True),
+ norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
@@ -145,7 +149,7 @@ class ResnetGenerator(nn.Module):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1),
- norm_layer(ngf * mult * 2, affine=True),
+ norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
@@ -157,7 +161,7 @@ class ResnetGenerator(nn.Module):
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1),
- norm_layer(int(ngf * mult / 2), affine=True),
+ norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
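These three ResnetGenerator hunks only drop the now-redundant affine=True from the call sites; the architecture is untouched. For orientation, the mult = 2**i bookkeeping produces a symmetric channel schedule: assuming the usual defaults ngf=64 and n_downsampling=2, the encoder widens 64 -> 128 -> 256 and the ConvTranspose2d path mirrors it back to 64 before the final 7x7 conv. A quick check of that arithmetic (illustrative values):

ngf, n_downsampling = 64, 2
down = [ngf * (2 ** i) * 2 for i in range(n_downsampling)]
up = [int(ngf * (2 ** (n_downsampling - i)) / 2) for i in range(n_downsampling)]
print(down, up)  # [128, 256] [128, 64]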
@@ -187,12 +191,12 @@ class ResnetBlock(nn.Module):
# TODO: InstanceNorm
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
- norm_layer(dim, affine=True),
+ norm_layer(dim),
nn.ReLU(True)]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
- norm_layer(dim, affine=True)]
+ norm_layer(dim)]
return nn.Sequential(*conv_block)
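This hunk assembles only the residual branch; the skip connection lives in the block's forward, which the diff does not touch. A hedged sketch of the complete block under that assumption (the standard ResNet pattern, input plus branch output; the class name is hypothetical):

import torch.nn as nn

class ResnetBlockSketch(nn.Module):
    # Illustration: a residual block whose branch matches the conv_block above.
    def __init__(self, dim, norm_layer, use_dropout=False):
        super(ResnetBlockSketch, self).__init__()
        conv_block = [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                      norm_layer(dim),
                      nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                       norm_layer(dim)]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)  # identity skip connection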
@@ -215,7 +219,7 @@ class UnetGenerator(nn.Module):
assert(input_nc == output_nc)
# construct unet structure
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
for i in range(num_downs - 5):
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
@@ -226,7 +230,7 @@ class UnetGenerator(nn.Module):
self.model = unet_block
def forward(self, input):
- if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
+ if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
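Context for the two hunks above: the U-Net is built from the inside out, each UnetSkipConnectionBlock wrapping the previous one as its submodule, and the fix threads norm_layer into the innermost block, which had been left at its default. The skip connection itself is not visible in this diff; in this codebase's convention an inner block concatenates its input with its upsampled output along the channel axis, which is why the upconv in the next hunk takes inner_nc * 2 input channels. A hedged sketch of that forward (assumed, not shown in the commit):

import torch

def unet_skip_forward(self, x):
    # Hypothetical stand-in for UnetSkipConnectionBlock.forward: the outermost
    # block returns its output directly; inner blocks concatenate input and
    # output channels to form the skip connection.
    if self.outermost:
        return self.model(x)
    return torch.cat([self.model(x), x], 1)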
@@ -244,9 +248,9 @@ class UnetSkipConnectionBlock(nn.Module):
downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
stride=2, padding=1)
downrelu = nn.LeakyReLU(0.2, True)
- downnorm = norm_layer(inner_nc, affine=True)
+ downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
- upnorm = norm_layer(outer_nc, affine=True)
+ upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
@@ -303,9 +307,9 @@ class NLayerDiscriminator(nn.Module):
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
- kernel_size=kw, stride=2, padding=padw),
+ kernel_size=kw, stride=2, padding=padw),
# TODO: use InstanceNorm
- norm_layer(ndf * nf_mult, affine=True),
+ norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
@@ -313,9 +317,9 @@ class NLayerDiscriminator(nn.Module):
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
- kernel_size=kw, stride=1, padding=padw),
+ kernel_size=kw, stride=1, padding=padw),
# TODO: use InstanceNorm
- norm_layer(ndf * nf_mult, affine=True),
+ norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
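The discriminator hunks again just drop affine=True (plus an indentation touch-up on the Conv2d continuation lines). The surrounding min(2**n, 8) logic caps channel growth: widths double per layer but saturate at 8 * ndf. With ndf=64 and a deeper-than-default n_layers=5, the cap becomes visible (illustrative values):

ndf, n_layers = 64, 5
widths = [ndf * min(2 ** n, 8) for n in range(1, n_layers + 1)]
print(widths)  # [128, 256, 512, 512, 512]: doubling, capped at 8 * ndf = 512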
@@ -327,7 +331,7 @@ class NLayerDiscriminator(nn.Module):
self.model = nn.Sequential(*sequence)
def forward(self, input):
- if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
+ if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
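Both forward methods touched by this diff (UnetGenerator above and NLayerDiscriminator here) share one dispatch guard: if GPU ids are configured and the input is already a CUDA tensor, replicate the module across those devices with nn.parallel.data_parallel; otherwise call it directly. A standalone sketch of the guard, keeping the Variable-era .data check this 2017-vintage code relies on (helper name is hypothetical):

import torch
import torch.nn as nn

def dispatch_forward(model, input, gpu_ids):
    # Replicate across the listed GPUs only when the input already lives on
    # the GPU; otherwise fall back to a plain single-device call.
    if gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
        return nn.parallel.data_parallel(model, input, gpu_ids)
    return model(input)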