path: root/models/networks.py
author    junyanz <junyanz@berkeley.edu>    2017-04-22 08:15:48 -0700
committer junyanz <junyanz@berkeley.edu>    2017-04-22 08:15:48 -0700
commit    6c347282993d2e2db91b376d3113efa3774c3a22 (patch)
tree      75ede706fde3f61e73fbb9e7ea9ed6e97aff2a56 /models/networks.py
parent    a7917caaeaefe51db959b8f3ae50a20e726fbd93 (diff)
add dropout option for G
Diffstat (limited to 'models/networks.py')
-rw-r--r--    models/networks.py    34
1 file changed, 20 insertions, 14 deletions
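
For context: this commit threads a new use_dropout flag (default False) through define_G into both ResnetGenerator and UnetGenerator. A minimal usage sketch of the updated entry point, assuming the repository root is on PYTHONPATH; the channel counts, norm choice, and other argument values are illustrative and not part of this diff:

from models.networks import define_G

# Build a ResNet generator with the new dropout option enabled.
# Signature after this commit:
#   define_G(input_nc, output_nc, ngf, which_model_netG, norm,
#            use_dropout=False, gpu_ids=[])
netG = define_G(input_nc=3, output_nc=3, ngf=64,
                which_model_netG='resnet_9blocks',
                norm='batch',
                use_dropout=True,   # new flag added by this commit
                gpu_ids=[])         # empty list keeps the model on CPU
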
diff --git a/models/networks.py b/models/networks.py
index 2e3ad79..60e1777 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -18,7 +18,7 @@ def weights_init(m):
m.bias.data.fill_(0)
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False, gpu_ids=[]):
netG = None
use_gpu = len(gpu_ids) > 0
if norm == 'batch':
@@ -31,13 +31,13 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
assert(torch.cuda.is_available())
if which_model_netG == 'resnet_9blocks':
- netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
elif which_model_netG == 'resnet_6blocks':
- netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
elif which_model_netG == 'unet_128':
- netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, gpu_ids=gpu_ids)
+ netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
elif which_model_netG == 'unet_256':
- netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, gpu_ids=gpu_ids)
+ netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
else:
print('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
@@ -124,7 +124,7 @@ class GANLoss(nn.Module):
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
- def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[]):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
@@ -146,7 +146,7 @@ class ResnetGenerator(nn.Module):
mult = 2**n_downsampling
for i in range(n_blocks):
- model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer)]
+ model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=use_dropout)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
@@ -170,11 +170,11 @@ class ResnetGenerator(nn.Module):
# Define a resnet block
class ResnetBlock(nn.Module):
- def __init__(self, dim, padding_type, norm_layer):
+ def __init__(self, dim, padding_type, norm_layer, use_dropout):
super(ResnetBlock, self).__init__()
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
- def build_conv_block(self, dim, padding_type, norm_layer):
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
conv_block = []
p = 0
# TODO: support padding types
@@ -185,6 +185,8 @@ class ResnetBlock(nn.Module):
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
nn.ReLU(True)]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim)]
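
With use_dropout=True, the convolutional path of a ResnetBlock becomes conv -> norm -> ReLU -> Dropout(0.5) -> conv -> norm. A self-contained sketch of the resulting stack, assuming zero padding (padding=1), nn.BatchNorm2d as the norm layer, and dim=256; these values are illustrative and not taken from this hunk:

import torch
import torch.nn as nn

dim = 256  # e.g. ngf * 2**n_downsampling with ngf=64 and two downsampling steps (assumed)
conv_block = nn.Sequential(
    nn.Conv2d(dim, dim, kernel_size=3, padding=1),  # 'zero' padding assumed
    nn.BatchNorm2d(dim),
    nn.ReLU(True),
    nn.Dropout(0.5),   # inserted by this commit only when use_dropout=True
    nn.Conv2d(dim, dim, kernel_size=3, padding=1),
    nn.BatchNorm2d(dim),
)
x = torch.randn(1, dim, 64, 64)
print(conv_block(x).shape)  # torch.Size([1, 256, 64, 64]); spatial size is preserved
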
@@ -201,7 +203,7 @@ class ResnetBlock(nn.Module):
# at the bottleneck
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
- norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+ norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
super(UnetGenerator, self).__init__()
self.gpu_ids = gpu_ids
@@ -211,7 +213,7 @@ class UnetGenerator(nn.Module):
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
for i in range(num_downs - 5):
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, use_dropout=use_dropout)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
@@ -231,7 +233,7 @@ class UnetGenerator(nn.Module):
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc,
- submodule=None, outermost=False, innermost=False):
+ submodule=None, outermost=False, innermost=False, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
@@ -262,7 +264,11 @@ class UnetSkipConnectionBlock(nn.Module):
padding=1)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
- model = down + [submodule] + up
+
+ if use_dropout:
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
+ else:
+ model = down + [submodule] + up
self.model = nn.Sequential(*model)
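
Note that in UnetGenerator only the inner ngf * 8 skip blocks built in the for i in range(num_downs - 5) loop receive use_dropout; the outer blocks keep the default False. Within a block, the Dropout(0.5) is appended after the upsampling path. A hedged sketch of that assembly, with assemble_skip_block as a hypothetical helper name (not a function in this file):

import torch.nn as nn

def assemble_skip_block(down, submodule, up, use_dropout=False):
    # Mirrors the change above: dropout, when enabled, follows the upsampling layers.
    if use_dropout:
        model = down + [submodule] + up + [nn.Dropout(0.5)]
    else:
        model = down + [submodule] + up
    return nn.Sequential(*model)
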