Diffstat (limited to 'models')
-rw-r--r--  models/cycle_gan_model.py |  4 ++--
-rw-r--r--  models/networks.py        | 34 ++++++++++++++++++++++++--------------
-rw-r--r--  models/pix2pix_model.py   |  2 +-
3 files changed, 23 insertions(+), 17 deletions(-)
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index eb1c443..d361e47 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -28,9 +28,9 @@ class CycleGANModel(BaseModel):
# Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
- opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+ opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
- opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+ opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
if self.isTrain:
use_sigmoid = opt.no_lsgan
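
A note on the signature change: use_dropout is inserted before gpu_ids, so any caller still passing gpu_ids positionally in the old sixth slot would silently bind it to use_dropout instead, which is why all three call sites are updated in this commit. A minimal sketch of the safer keyword-argument form (a hypothetical call mirroring the updated code above, not part of the diff):

    # Keyword arguments guard against the positional shift introduced by
    # inserting use_dropout ahead of gpu_ids in define_G's signature.
    netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                               opt.which_model_netG, opt.norm,
                               use_dropout=opt.use_dropout,
                               gpu_ids=self.gpu_ids)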
diff --git a/models/networks.py b/models/networks.py
index 2e3ad79..60e1777 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -18,7 +18,7 @@ def weights_init(m):
m.bias.data.fill_(0)
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False, gpu_ids=[]):
netG = None
use_gpu = len(gpu_ids) > 0
if norm == 'batch':
@@ -31,13 +31,13 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
assert(torch.cuda.is_available())
if which_model_netG == 'resnet_9blocks':
- netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
elif which_model_netG == 'resnet_6blocks':
- netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
elif which_model_netG == 'unet_128':
- netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, gpu_ids=gpu_ids)
+ netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
elif which_model_netG == 'unet_256':
- netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, gpu_ids=gpu_ids)
+ netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
else:
print('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
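
For orientation, the dispatch above maps each model name to a generator class and a fixed depth; a hedged summary (reading the unet_128/unet_256 names as 2**num_downs matching the input size is the usual interpretation, not something stated in this diff):

    # 'resnet_9blocks' -> ResnetGenerator(..., n_blocks=9)
    # 'resnet_6blocks' -> ResnetGenerator(..., n_blocks=6)
    # 'unet_128'       -> UnetGenerator(..., num_downs=7)   # 2**7 = 128
    # 'unet_256'       -> UnetGenerator(..., num_downs=8)   # 2**8 = 256
    # All four branches now forward use_dropout, e.g.:
    netG = define_G(3, 3, 64, 'resnet_9blocks', 'batch', use_dropout=True)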
@@ -124,7 +124,7 @@ class GANLoss(nn.Module):
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
- def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[]):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
@@ -146,7 +146,7 @@ class ResnetGenerator(nn.Module):
mult = 2**n_downsampling
for i in range(n_blocks):
- model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer)]
+ model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=use_dropout)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
@@ -170,11 +170,11 @@ class ResnetGenerator(nn.Module):
# Define a resnet block
class ResnetBlock(nn.Module):
- def __init__(self, dim, padding_type, norm_layer):
+ def __init__(self, dim, padding_type, norm_layer, use_dropout):
super(ResnetBlock, self).__init__()
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
- def build_conv_block(self, dim, padding_type, norm_layer):
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
conv_block = []
p = 0
# TODO: support padding types
@@ -185,6 +185,8 @@ class ResnetBlock(nn.Module):
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
nn.ReLU(True)]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim)]
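
With the two new lines, a block built with use_dropout=True has the layer order Conv -> Norm -> ReLU -> Dropout -> Conv -> Norm, i.e. the dropout sits between the block's two convolutions. A self-contained sketch of the same layout (a simplification, not the repository class; padding=1 so the residual shapes match):

    import torch.nn as nn

    def conv_block(dim, use_dropout=False, norm_layer=nn.BatchNorm2d):
        # Conv -> Norm -> ReLU, optional Dropout(0.5), then Conv -> Norm.
        layers = [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                  norm_layer(dim),
                  nn.ReLU(True)]
        if use_dropout:
            layers += [nn.Dropout(0.5)]   # stochastic only in train() mode
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                   norm_layer(dim)]
        return nn.Sequential(*layers)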
@@ -201,7 +203,7 @@ class ResnetBlock(nn.Module):
# at the bottleneck
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
- norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+ norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
super(UnetGenerator, self).__init__()
self.gpu_ids = gpu_ids
@@ -211,7 +213,7 @@ class UnetGenerator(nn.Module):
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
for i in range(num_downs - 5):
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, use_dropout=use_dropout)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
@@ -231,7 +233,7 @@ class UnetGenerator(nn.Module):
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc,
- submodule=None, outermost=False, innermost=False):
+ submodule=None, outermost=False, innermost=False, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
@@ -262,7 +264,11 @@ class UnetSkipConnectionBlock(nn.Module):
padding=1)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
- model = down + [submodule] + up
+
+ if use_dropout:
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
+ else:
+ model = down + [submodule] + up
self.model = nn.Sequential(*model)
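
In the U-Net path the dropout is appended after the upsampling half of a skip block, and define_G only passes use_dropout=True to the intermediate ngf * 8 blocks; the outer blocks keep the default False. A minimal, hypothetical version of the block showing where the layers end up (not the repository class):

    import torch
    import torch.nn as nn

    class SkipBlockSketch(nn.Module):
        # Simplified stand-in for UnetSkipConnectionBlock (intermediate case).
        def __init__(self, outer_nc, inner_nc, submodule, use_dropout=False):
            super(SkipBlockSketch, self).__init__()
            down = [nn.LeakyReLU(0.2, True),
                    nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1),
                    nn.BatchNorm2d(inner_nc)]
            up = [nn.ReLU(True),
                  nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4,
                                     stride=2, padding=1),
                  nn.BatchNorm2d(outer_nc)]
            if use_dropout:
                up += [nn.Dropout(0.5)]   # follows the whole up path
            self.model = nn.Sequential(*(down + [submodule] + up))

        def forward(self, x):
            # Skip connection: concatenate input with the block's output.
            return torch.cat([x, self.model(x)], 1)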
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 3bdd237..0e02ebf 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -25,7 +25,7 @@ class Pix2PixModel(BaseModel):
# load/define networks
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
- opt.which_model_netG, opt.norm, self.gpu_ids)
+ opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
if self.isTrain:
use_sigmoid = opt.no_lsgan
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
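
Pix2Pix threads the same option through. Since use_dropout defaults to False everywhere, code built against the old signatures behaves identically; and because the added layers are plain nn.Dropout, the module's train()/eval() mode decides whether they fire (the pix2pix paper notably applies dropout at both training and test time as the generator's only noise source). A short hedged sketch:

    g = define_G(3, 3, 64, 'unet_256', 'batch')          # old behavior, no dropout
    g_drop = define_G(3, 3, 64, 'unet_256', 'batch', use_dropout=True)

    g_drop.train()   # Dropout(0.5) layers are active
    g_drop.eval()    # Dropout layers act as identity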