| author | junyanz <junyanz@berkeley.edu> | 2017-04-22 08:15:48 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-22 08:15:48 -0700 |
| commit | 6c347282993d2e2db91b376d3113efa3774c3a22 | |
| tree | 75ede706fde3f61e73fbb9e7ea9ed6e97aff2a56 /models/networks.py | |
| parent | a7917caaeaefe51db959b8f3ae50a20e726fbd93 | |
add dropout option for G
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 34 |
1 file changed, 20 insertions(+), 14 deletions(-)
```diff
diff --git a/models/networks.py b/models/networks.py
index 2e3ad79..60e1777 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -18,7 +18,7 @@ def weights_init(m):
         m.bias.data.fill_(0)
 
 
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False, gpu_ids=[]):
     netG = None
     use_gpu = len(gpu_ids) > 0
     if norm == 'batch':
@@ -31,13 +31,13 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
         assert(torch.cuda.is_available())
 
     if which_model_netG == 'resnet_9blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_128':
-        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_256':
-        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     else:
         print('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
@@ -124,7 +124,7 @@ class GANLoss(nn.Module):
 # Code and idea originally from Justin Johnson's architecture.
 # https://github.com/jcjohnson/fast-neural-style/
 class ResnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[]):
         assert(n_blocks >= 0)
         super(ResnetGenerator, self).__init__()
         self.input_nc = input_nc
@@ -146,7 +146,7 @@ class ResnetGenerator(nn.Module):
 
         mult = 2**n_downsampling
         for i in range(n_blocks):
-            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer)]
+            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=use_dropout)]
 
         for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
@@ -170,11 +170,11 @@ class ResnetGenerator(nn.Module):
 
 # Define a resnet block
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, padding_type, norm_layer):
+    def __init__(self, dim, padding_type, norm_layer, use_dropout):
         super(ResnetBlock, self).__init__()
-        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
 
-    def build_conv_block(self, dim, padding_type, norm_layer):
+    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
         conv_block = []
         p = 0
         # TODO: support padding types
@@ -185,6 +185,8 @@ class ResnetBlock(nn.Module):
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim),
                        nn.ReLU(True)]
+        if use_dropout:
+            conv_block += [nn.Dropout(0.5)]
 
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim)]
@@ -201,7 +203,7 @@ class ResnetBlock(nn.Module):
 # at the bottleneck
 class UnetGenerator(nn.Module):
     def __init__(self, input_nc, output_nc, num_downs, ngf=64,
-                 norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
         super(UnetGenerator, self).__init__()
         self.gpu_ids = gpu_ids
 
@@ -211,7 +213,7 @@ class UnetGenerator(nn.Module):
         # construct unet structure
         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
         for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, use_dropout=use_dropout)
         unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
@@ -231,7 +233,7 @@ class UnetGenerator(nn.Module):
 # |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
     def __init__(self, outer_nc, inner_nc,
-                 submodule=None, outermost=False, innermost=False):
+                 submodule=None, outermost=False, innermost=False, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
 
@@ -262,7 +264,11 @@ class UnetSkipConnectionBlock(nn.Module):
                                         padding=1)
             down = [downrelu, downconv, downnorm]
             up = [uprelu, upconv, upnorm]
-            model = down + [submodule] + up
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
 
         self.model = nn.Sequential(*model)
```
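For context, a minimal usage sketch of the new flag, assuming the caller imports `models/networks.py` the way the rest of this repository does; the channel counts, `ngf`, and norm choice below are illustrative, and only the `use_dropout` keyword itself comes from this commit:

```python
from models import networks

# Build a ResNet generator with the new dropout option enabled.
# input_nc/output_nc/ngf/norm values here are illustrative, not
# mandated by the commit; use_dropout is the flag it introduces.
netG = networks.define_G(input_nc=3, output_nc=3, ngf=64,
                         which_model_netG='resnet_9blocks', norm='batch',
                         use_dropout=True, gpu_ids=[])

# With use_dropout=True, each ResnetBlock gains an nn.Dropout(0.5)
# between its two 3x3 convolutions; for the U-Net variants, the
# intermediate skip blocks append nn.Dropout(0.5) after the up path.
print(netG)
```

Note that `nn.Dropout` is only active in training mode: calling `netG.eval()` disables it automatically, and since dropout has no parameters, the flag changes training-time stochasticity without altering the model's parameter count.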
