diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-07-05 19:08:39 -0400 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-07-05 19:08:39 -0400 |
| commit | c7f7d1979a35b443dba7e776203ed7084efecf77 (patch) | |
| tree | 65652aaad410182eb6cb8bbc1accb8a87d75dc92 /models/networks.py | |
| parent | e77d1352c0618adf8abf348b04647dd86e8890c1 (diff) | |
add reflection padding layer
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 43 |
1 files changed, 28 insertions, 15 deletions
diff --git a/models/networks.py b/models/networks.py index a2ddbdf..12da13b 100644 --- a/models/networks.py +++ b/models/networks.py @@ -24,7 +24,7 @@ def get_norm_layer(norm_type='instance'): elif norm_type == 'instance': norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) else: - print('normalization layer [%s] is not found' % norm) + raise NotImplementedError('normalization layer [%s] is not found' % norm) return norm_layer @@ -45,7 +45,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo elif which_model_netG == 'unet_256': netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids) else: - print('Generator model name [%s] is not recognized' % which_model_netG) + raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) if len(gpu_ids) > 0: netG.cuda(device_id=gpu_ids[0]) netG.apply(weights_init) @@ -65,8 +65,8 @@ def define_D(input_nc, ndf, which_model_netD, elif which_model_netD == 'n_layers': netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) else: - print('Discriminator model name [%s] is not recognized' % - which_model_netD) + raise NotImplementedError('Discriminator model name [%s] is not recognized' % + which_model_netD) if use_gpu: netD.cuda(device_id=gpu_ids[0]) netD.apply(weights_init) @@ -132,7 +132,7 @@ class GANLoss(nn.Module): # Code and idea originally from Justin Johnson's architecture. # https://github.com/jcjohnson/fast-neural-style/ class ResnetGenerator(nn.Module): - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[]): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'): assert(n_blocks >= 0) super(ResnetGenerator, self).__init__() self.input_nc = input_nc @@ -140,7 +140,8 @@ class ResnetGenerator(nn.Module): self.ngf = ngf self.gpu_ids = gpu_ids - model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3), + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), nn.ReLU(True)] @@ -154,7 +155,7 @@ class ResnetGenerator(nn.Module): mult = 2**n_downsampling for i in range(n_blocks): - model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=use_dropout)] + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)] for i in range(n_downsampling): mult = 2**(n_downsampling - i) @@ -163,8 +164,8 @@ class ResnetGenerator(nn.Module): padding=1, output_padding=1), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] - - model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] model += [nn.Tanh()] self.model = nn.Sequential(*model) @@ -185,16 +186,30 @@ class ResnetBlock(nn.Module): def build_conv_block(self, dim, padding_type, norm_layer, use_dropout): conv_block = [] p = 0 - # TODO: support padding types - assert(padding_type == 'zero') - p = 1 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) - # TODO: InstanceNorm conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), norm_layer(dim), nn.ReLU(True)] if use_dropout: conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), norm_layer(dim)] @@ -308,7 +323,6 @@ class NLayerDiscriminator(nn.Module): sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw), - # TODO: use InstanceNorm norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] @@ -318,7 +332,6 @@ class NLayerDiscriminator(nn.Module): sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw), - # TODO: useInstanceNorm norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] |
