Diffstat (limited to 'models')
| -rw-r--r-- | models/cycle_gan_model.py | 4 |
| -rw-r--r-- | models/networks.py | 34 |
| -rw-r--r-- | models/pix2pix_model.py | 2 |
3 files changed, 23 insertions, 17 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index eb1c443..d361e47 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -28,9 +28,9 @@ class CycleGANModel(BaseModel):
         # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
 
         self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
-                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+                                        opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
         self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
-                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+                                        opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
 
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
diff --git a/models/networks.py b/models/networks.py
index 2e3ad79..60e1777 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -18,7 +18,7 @@ def weights_init(m):
         m.bias.data.fill_(0)
 
 
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False, gpu_ids=[]):
     netG = None
     use_gpu = len(gpu_ids) > 0
     if norm == 'batch':
@@ -31,13 +31,13 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
         assert(torch.cuda.is_available())
 
     if which_model_netG == 'resnet_9blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_128':
-        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_256':
-        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     else:
         print('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
@@ -124,7 +124,7 @@ class GANLoss(nn.Module):
 # Code and idea originally from Justin Johnson's architecture.
 # https://github.com/jcjohnson/fast-neural-style/
 class ResnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[]):
         assert(n_blocks >= 0)
         super(ResnetGenerator, self).__init__()
         self.input_nc = input_nc
@@ -146,7 +146,7 @@ class ResnetGenerator(nn.Module):
 
         mult = 2**n_downsampling
         for i in range(n_blocks):
-            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer)]
+            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=use_dropout)]
 
         for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
@@ -170,11 +170,11 @@ class ResnetGenerator(nn.Module):
 
 # Define a resnet block
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, padding_type, norm_layer):
+    def __init__(self, dim, padding_type, norm_layer, use_dropout):
         super(ResnetBlock, self).__init__()
-        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
 
-    def build_conv_block(self, dim, padding_type, norm_layer):
+    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
         conv_block = []
         p = 0
         # TODO: support padding types
@@ -185,6 +185,8 @@ class ResnetBlock(nn.Module):
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim),
                        nn.ReLU(True)]
+        if use_dropout:
+            conv_block += [nn.Dropout(0.5)]
 
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim)]
@@ -201,7 +203,7 @@ class ResnetBlock(nn.Module):
 # at the bottleneck
 class UnetGenerator(nn.Module):
     def __init__(self, input_nc, output_nc, num_downs, ngf=64,
-                 norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
         super(UnetGenerator, self).__init__()
         self.gpu_ids = gpu_ids
 
@@ -211,7 +213,7 @@ class UnetGenerator(nn.Module):
         # construct unet structure
         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
         for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, use_dropout=use_dropout)
         unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
@@ -231,7 +233,7 @@ class UnetGenerator(nn.Module):
 # |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
     def __init__(self, outer_nc, inner_nc,
-                 submodule=None, outermost=False, innermost=False):
+                 submodule=None, outermost=False, innermost=False, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
 
@@ -262,7 +264,11 @@ class UnetSkipConnectionBlock(nn.Module):
                                         padding=1)
             down = [downrelu, downconv, downnorm]
             up = [uprelu, upconv, upnorm]
-            model = down + [submodule] + up
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
 
         self.model = nn.Sequential(*model)
 
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 3bdd237..0e02ebf 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -25,7 +25,7 @@ class Pix2PixModel(BaseModel):
         # load/define networks
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
-                                      opt.which_model_netG, opt.norm, self.gpu_ids)
+                                      opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
 
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
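For context, a minimal standalone sketch of the residual block this patch produces. The `ResnetBlockSketch` name and the fixed `padding=1` are illustrative assumptions, not code from this repository; the layer ordering (Conv -> Norm -> ReLU -> optional Dropout -> Conv -> Norm, plus a skip connection) follows the hunks above.

```python
import torch
import torch.nn as nn

# Hypothetical standalone version of the patched ResnetBlock conv path.
class ResnetBlockSketch(nn.Module):
    def __init__(self, dim, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(ResnetBlockSketch, self).__init__()
        layers = [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                  norm_layer(dim),
                  nn.ReLU(True)]
        if use_dropout:
            # Dropout sits between the two convolutions, as in the patch.
            layers += [nn.Dropout(0.5)]
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                   norm_layer(dim)]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        # Residual connection around the conv block.
        return x + self.conv_block(x)

# Smoke test: shapes are preserved; dropout is stochastic in train mode
# and a no-op after .eval().
block = ResnetBlockSketch(64, use_dropout=True)
x = torch.randn(1, 64, 32, 32)
block.train()
y_train = block(x)
block.eval()
y_eval = block(x)
print(y_train.shape, y_eval.shape)  # torch.Size([1, 64, 32, 32]) twice
```

The same `use_dropout` flag also reaches the U-Net path, where the patch appends `nn.Dropout(0.5)` after the upsampling layers of the intermediate skip-connection blocks; in both generators the dropout layers are disabled automatically once the module is switched to `eval()` mode.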
