author    junyanz <junyanz@berkeley.edu>  2017-07-05 19:08:39 -0400
committer junyanz <junyanz@berkeley.edu>  2017-07-05 19:08:39 -0400
commit    c7f7d1979a35b443dba7e776203ed7084efecf77 (patch)
tree      65652aaad410182eb6cb8bbc1accb8a87d75dc92 /models/networks.py
parent    e77d1352c0618adf8abf348b04647dd86e8890c1 (diff)
add reflection padding layer
Diffstat (limited to 'models/networks.py')
-rw-r--r--  models/networks.py  43
1 file changed, 28 insertions(+), 15 deletions(-)
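
Reflection padding mirrors the pixels just inside the border instead of filling with zeros, which reduces edge artifacts in generated images. A minimal standalone sketch (not part of this commit) showing that nn.ReflectionPad2d(3) followed by a 7x7 conv with padding=0 preserves spatial size exactly like the old padding=3 conv:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 3, 256, 256)
    # old: zero padding built into the conv
    old = nn.Conv2d(3, 64, kernel_size=7, padding=3)
    # new: mirror 3 border pixels first, then convolve without padding
    new = nn.Sequential(nn.ReflectionPad2d(3),
                        nn.Conv2d(3, 64, kernel_size=7, padding=0))
    print(old(x).shape)  # torch.Size([1, 64, 256, 256])
    print(new(x).shape)  # torch.Size([1, 64, 256, 256])
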
diff --git a/models/networks.py b/models/networks.py
index a2ddbdf..12da13b 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -24,7 +24,7 @@ def get_norm_layer(norm_type='instance'):
     elif norm_type == 'instance':
         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
     else:
-        print('normalization layer [%s] is not found' % norm)
+        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
     return norm_layer
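
For reference, the partial returned by get_norm_layer is later called with a channel count, exactly like a layer constructor; a short sketch of that call pattern (assuming the functools and nn imports at the top of networks.py):

    import functools
    import torch.nn as nn

    norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    layer = norm_layer(64)  # same call shape as nn.BatchNorm2d(64)
    print(layer)            # InstanceNorm2d(64, ..., affine=False)
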
@@ -45,7 +45,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
     elif which_model_netG == 'unet_256':
         netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     else:
-        print('Generator model name [%s] is not recognized' % which_model_netG)
+        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
         netG.cuda(device_id=gpu_ids[0])
     netG.apply(weights_init)
@@ -65,8 +65,8 @@ def define_D(input_nc, ndf, which_model_netD,
     elif which_model_netD == 'n_layers':
         netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     else:
-        print('Discriminator model name [%s] is not recognized' %
-              which_model_netD)
+        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
+                                  which_model_netD)
     if use_gpu:
         netD.cuda(device_id=gpu_ids[0])
     netD.apply(weights_init)
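
Replacing print with raise makes a mistyped model name fail fast, instead of leaving netD as None and crashing later with an opaque AttributeError at netD.apply(weights_init). A hypothetical caller sketch ('basic_typo' is a made-up name; define_D as defined in this file):

    from models.networks import define_D

    try:
        netD = define_D(3, 64, 'basic_typo')  # made-up model name
    except NotImplementedError as e:
        print(e)  # Discriminator model name [basic_typo] is not recognized
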
@@ -132,7 +132,7 @@ class GANLoss(nn.Module):
 # Code and idea originally from Justin Johnson's architecture.
 # https://github.com/jcjohnson/fast-neural-style/
 class ResnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[]):
+    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
         assert(n_blocks >= 0)
         super(ResnetGenerator, self).__init__()
         self.input_nc = input_nc
@@ -140,7 +140,8 @@ class ResnetGenerator(nn.Module):
         self.ngf = ngf
         self.gpu_ids = gpu_ids
 
-        model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
+        model = [nn.ReflectionPad2d(3),
+                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                  norm_layer(ngf),
                  nn.ReLU(True)]
@@ -154,7 +155,7 @@ class ResnetGenerator(nn.Module):
         mult = 2**n_downsampling
         for i in range(n_blocks):
-            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=use_dropout)]
+            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)]
 
         for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
@@ -163,8 +164,8 @@ class ResnetGenerator(nn.Module):
                                          padding=1, output_padding=1),
                       norm_layer(int(ngf * mult / 2)),
                       nn.ReLU(True)]
-
-        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
+        model += [nn.ReflectionPad2d(3)]
+        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
         model += [nn.Tanh()]
 
         self.model = nn.Sequential(*model)
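
A quick end-to-end shape check (assuming this file is importable as models.networks and a CPU-only run on a reasonably recent PyTorch): the generator should still map 256x256 inputs to 256x256 outputs after the padding change:

    import torch
    from models.networks import ResnetGenerator

    netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64,
                           n_blocks=6, padding_type='reflect')
    x = torch.randn(1, 3, 256, 256)
    print(netG(x).shape)  # torch.Size([1, 3, 256, 256])
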
@@ -185,16 +186,30 @@ class ResnetBlock(nn.Module):
     def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
         conv_block = []
         p = 0
-        # TODO: support padding types
-        assert(padding_type == 'zero')
-        p = 1
+        if padding_type == 'reflect':
+            conv_block += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv_block += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
 
-        # TODO: InstanceNorm
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim),
                        nn.ReLU(True)]
         if use_dropout:
             conv_block += [nn.Dropout(0.5)]
+
+        p = 0
+        if padding_type == 'reflect':
+            conv_block += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv_block += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim)]
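
The padding dispatch above appears twice in build_conv_block, once per conv. A hypothetical helper (not in this commit) that would centralize it, returning the explicit pad layers plus the implicit padding to hand to the conv:

    import torch.nn as nn

    def get_pad(padding_type):
        # explicit pad layers, plus padding to pass to the conv itself
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        elif padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        elif padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    dim = 256  # e.g. ngf * 4 after two downsampling steps
    pad_layers, p = get_pad('reflect')
    conv_block = pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=p)]
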
@@ -308,7 +323,6 @@ class NLayerDiscriminator(nn.Module):
             sequence += [
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                           kernel_size=kw, stride=2, padding=padw),
-                # TODO: use InstanceNorm
                 norm_layer(ndf * nf_mult),
                 nn.LeakyReLU(0.2, True)
             ]
@@ -318,7 +332,6 @@ class NLayerDiscriminator(nn.Module):
         sequence += [
             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                       kernel_size=kw, stride=1, padding=padw),
-            # TODO: useInstanceNorm
             norm_layer(ndf * nf_mult),
             nn.LeakyReLU(0.2, True)
         ]