| author | hazirbas <hazirbas@cs.tum.edu> | 2017-08-22 20:29:48 +0200 |
|---|---|---|
| committer | hazirbas <hazirbas@cs.tum.edu> | 2017-08-22 20:29:48 +0200 |
| commit | aa26d40a34cfc62cb209bb8c471643e7779e016a (patch) | |
| tree | 2cb34c776539e1d3fb5384485076bce170e3a2b5 /models/networks.py | |
| parent | 104f95345ee9cf51581a915ddd79e79584e2a356 (diff) | |
Initialized the bias of conv layers to zero,
and deactivated the bias for conv layers that are followed by batch normalization.
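The first hunk of the diff below implements the zero-bias initialization: conv weights keep their N(0, 0.02) draw and any existing bias is zeroed, with a `hasattr` guard that skips layers created with `bias=False`. A minimal sketch of the resulting initializer plus a typical `net.apply` call; the small `Sequential` network is only an illustration, not code from this repository:

```python
import torch.nn as nn

def weights_init(m):
    """Init conv weights from N(0, 0.02) and zero any bias;
    batch-norm scale from N(1, 0.02) with zero shift."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if hasattr(m.bias, 'data'):  # bias is None for layers built with bias=False
            m.bias.data.fill_(0)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Typical usage: apply recursively to every submodule of a network.
net = nn.Sequential(nn.Conv2d(3, 64, 4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(64),
                    nn.LeakyReLU(0.2, True))
net.apply(weights_init)
```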
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 38 |
1 file changed, 27 insertions(+), 11 deletions(-)
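The recurring pattern throughout the diff is a `use_bias` flag derived from the normalization layer: `BatchNorm2d` normalizes away a constant per-channel offset and has its own learnable shift, so a preceding conv bias is redundant, while `InstanceNorm2d` (which has no learnable affine shift by default) keeps it. A minimal sketch of the check, assuming `norm_layer` is either a layer class or a `functools.partial` wrapping one; the helper name `conv_needs_bias` is illustrative, since the diff inlines the same check:

```python
import functools
import torch.nn as nn

def conv_needs_bias(norm_layer):
    """True when the following norm layer does not make a conv bias redundant."""
    if type(norm_layer) == functools.partial:
        return norm_layer.func == nn.InstanceNorm2d
    return norm_layer == nn.InstanceNorm2d

# Built the same way the generator/discriminator blocks in the diff do it.
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
conv = nn.Conv2d(3, 64, kernel_size=7, padding=0,
                 bias=conv_needs_bias(norm_layer))  # bias=False under BatchNorm2d
```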
```diff
diff --git a/models/networks.py b/models/networks.py
index 12da13b..db36ac4 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -13,6 +13,8 @@ def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
         m.weight.data.normal_(0.0, 0.02)
+        if hasattr(m.bias, 'data'):
+            m.bias.data.fill_(0)
     elif classname.find('BatchNorm2d') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
@@ -139,9 +141,14 @@ class ResnetGenerator(nn.Module):
         self.output_nc = output_nc
         self.ngf = ngf
         self.gpu_ids = gpu_ids
+        if type(norm_layer) == functools.partial:
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
 
         model = [nn.ReflectionPad2d(3),
-                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
+                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
+                           bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]
 
@@ -149,7 +156,7 @@ class ResnetGenerator(nn.Module):
         for i in range(n_downsampling):
             mult = 2**i
             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
-                                stride=2, padding=1),
+                                stride=2, padding=1, bias=use_bias),
                       norm_layer(ngf * mult * 2),
                       nn.ReLU(True)]
 
@@ -161,7 +168,8 @@ class ResnetGenerator(nn.Module):
             mult = 2**(n_downsampling - i)
             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                          kernel_size=3, stride=2,
-                                         padding=1, output_padding=1),
+                                         padding=1, output_padding=1,
+                                         bias=use_bias),
                       norm_layer(int(ngf * mult / 2)),
                       nn.ReLU(True)]
         model += [nn.ReflectionPad2d(3)]
@@ -179,7 +187,7 @@ class ResnetGenerator(nn.Module):
 
 # Define a resnet block
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, padding_type, norm_layer, use_dropout):
+    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
         super(ResnetBlock, self).__init__()
         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
 
@@ -195,7 +203,7 @@ class ResnetBlock(nn.Module):
         else:
             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
 
-        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                        norm_layer(dim),
                        nn.ReLU(True)]
         if use_dropout:
@@ -210,7 +218,7 @@ class ResnetBlock(nn.Module):
             p = 1
         else:
             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
-        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                        norm_layer(dim)]
 
         return nn.Sequential(*conv_block)
@@ -259,9 +267,13 @@ class UnetSkipConnectionBlock(nn.Module):
                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
+        if type(norm_layer) == functools.partial:
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
 
         downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
-                             stride=2, padding=1)
+                             stride=2, padding=1, bias=use_bias)
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
         uprelu = nn.ReLU(True)
@@ -277,14 +289,14 @@ class UnetSkipConnectionBlock(nn.Module):
         elif innermost:
             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
-                                        padding=1)
+                                        padding=1, bias=use_bias)
             down = [downrelu, downconv]
             up = [uprelu, upconv, upnorm]
             model = down + up
         else:
             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
-                                        padding=1)
+                                        padding=1, bias=use_bias)
             down = [downrelu, downconv, downnorm]
             up = [uprelu, upconv, upnorm]
 
@@ -307,6 +319,10 @@ class NLayerDiscriminator(nn.Module):
     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
         super(NLayerDiscriminator, self).__init__()
         self.gpu_ids = gpu_ids
+        if type(norm_layer) == functools.partial:
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
 
         kw = 4
         padw = int(np.ceil((kw-1)/2))
@@ -322,7 +338,7 @@ class NLayerDiscriminator(nn.Module):
             nf_mult = min(2**n, 8)
             sequence += [
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
-                          kernel_size=kw, stride=2, padding=padw),
+                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                 norm_layer(ndf * nf_mult),
                 nn.LeakyReLU(0.2, True)
             ]
@@ -331,7 +347,7 @@ class NLayerDiscriminator(nn.Module):
         nf_mult = min(2**n_layers, 8)
         sequence += [
             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
-                      kernel_size=kw, stride=1, padding=padw),
+                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
             norm_layer(ndf * nf_mult),
             nn.LeakyReLU(0.2, True)
         ]
```
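Why the bias can be dropped: a conv bias is a per-channel constant, and `BatchNorm2d` in training mode subtracts the per-channel batch mean, so the bias cancels exactly (and BN's own learnable shift covers any needed offset). A small self-contained check of this cancellation, not part of the repository:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two convs with identical weights; one keeps its (per-channel constant) bias.
conv_with_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
conv_no_bias.weight.data.copy_(conv_with_bias.weight.data)

bn = nn.BatchNorm2d(8)  # training mode: normalizes with per-channel batch statistics

x = torch.randn(4, 3, 16, 16)
with torch.no_grad():
    y_with = bn(conv_with_bias(x))
    y_without = bn(conv_no_bias(x))

# The mean subtraction removes the constant per-channel bias, so the
# normalized outputs coincide up to floating-point error.
print(torch.allclose(y_with, y_without, atol=1e-5))  # True
```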
