| author | junyanz <junyanz@berkeley.edu> | 2017-04-22 08:15:48 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-22 08:15:48 -0700 |
| commit | 6c347282993d2e2db91b376d3113efa3774c3a22 | |
| tree | 75ede706fde3f61e73fbb9e7ea9ed6e97aff2a56 | |
| parent | a7917caaeaefe51db959b8f3ae50a20e726fbd93 | |
add dropout option for G
| mode | file | changes |
|---|---|---|
| -rw-r--r-- | README.md | 2 |
| -rw-r--r-- | models/cycle_gan_model.py | 4 |
| -rw-r--r-- | models/networks.py | 34 |
| -rw-r--r-- | models/pix2pix_model.py | 2 |
| -rw-r--r-- | options/base_options.py | 2 |
| -rw-r--r-- | scripts/train_pix2pix.sh | 2 |

6 files changed, 26 insertions, 20 deletions
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -81,7 +81,7 @@ bash ./datasets/download_pix2pix_dataset.sh facades
 - Train a model:
 ```bash
 #!./scripts/train_pix2pix.sh
-python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --no_lsgan
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --use_dropout --no_lsgan
 ```
 - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/facades_pix2pix/web/index.html`
 - Test the model (`bash ./scripts/test_pix2pix.sh`):
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index eb1c443..d361e47 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -28,9 +28,9 @@ class CycleGANModel(BaseModel):
         # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
         self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
-                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+                                        opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
         self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
-                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+                                        opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
 
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
diff --git a/models/networks.py b/models/networks.py
index 2e3ad79..60e1777 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -18,7 +18,7 @@ def weights_init(m):
         m.bias.data.fill_(0)
 
 
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False, gpu_ids=[]):
     netG = None
     use_gpu = len(gpu_ids) > 0
     if norm == 'batch':
@@ -31,13 +31,13 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
         assert(torch.cuda.is_available())
 
     if which_model_netG == 'resnet_9blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_128':
-        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_256':
-        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     else:
         print('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
@@ -124,7 +124,7 @@ class GANLoss(nn.Module):
 # Code and idea originally from Justin Johnson's architecture.
 # https://github.com/jcjohnson/fast-neural-style/
 class ResnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[]):
         assert(n_blocks >= 0)
         super(ResnetGenerator, self).__init__()
         self.input_nc = input_nc
@@ -146,7 +146,7 @@ class ResnetGenerator(nn.Module):
         mult = 2**n_downsampling
         for i in range(n_blocks):
-            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer)]
+            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer, use_dropout=use_dropout)]
 
         for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
@@ -170,11 +170,11 @@ class ResnetGenerator(nn.Module):
 # Define a resnet block
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, padding_type, norm_layer):
+    def __init__(self, dim, padding_type, norm_layer, use_dropout):
         super(ResnetBlock, self).__init__()
-        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
 
-    def build_conv_block(self, dim, padding_type, norm_layer):
+    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
         conv_block = []
         p = 0
         # TODO: support padding types
@@ -185,6 +185,8 @@ class ResnetBlock(nn.Module):
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim),
                        nn.ReLU(True)]
+        if use_dropout:
+            conv_block += [nn.Dropout(0.5)]
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                        norm_layer(dim)]
@@ -201,7 +203,7 @@ class ResnetBlock(nn.Module):
 # at the bottleneck
 class UnetGenerator(nn.Module):
     def __init__(self, input_nc, output_nc, num_downs, ngf=64,
-                 norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
         super(UnetGenerator, self).__init__()
         self.gpu_ids = gpu_ids
@@ -211,7 +213,7 @@ class UnetGenerator(nn.Module):
         # construct unet structure
         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
         for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, use_dropout=use_dropout)
         unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
@@ -231,7 +233,7 @@ class UnetGenerator(nn.Module):
 # |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
     def __init__(self, outer_nc, inner_nc,
-                 submodule=None, outermost=False, innermost=False):
+                 submodule=None, outermost=False, innermost=False, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
@@ -262,7 +264,11 @@ class UnetSkipConnectionBlock(nn.Module):
                                         padding=1)
             down = [downrelu, downconv, downnorm]
             up = [uprelu, upconv, upnorm]
-            model = down + [submodule] + up
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
 
         self.model = nn.Sequential(*model)
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 3bdd237..0e02ebf 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -25,7 +25,7 @@ class Pix2PixModel(BaseModel):
         # load/define networks
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
-                                      opt.which_model_netG, opt.norm, self.gpu_ids)
+                                      opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
diff --git a/options/base_options.py b/options/base_options.py
index ec0f439..4074746 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -34,7 +34,7 @@ class BaseOptions():
         self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
         self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
         self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
-
+        self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
         self.initialized = True
 
     def parse(self):
diff --git a/scripts/train_pix2pix.sh b/scripts/train_pix2pix.sh
index 188050b..f14e7da 100644
--- a/scripts/train_pix2pix.sh
+++ b/scripts/train_pix2pix.sh
@@ -1 +1 @@
-python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --no_lsgan
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --no_lsgan --use_dropout
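For orientation, the sketch below is a minimal standalone illustration of the pattern this commit introduces (it is not the repository's exact class: padding-type handling, the U-Net path, and GPU plumbing are omitted). When `use_dropout` is set, an `nn.Dropout(0.5)` layer is placed between the two 3x3 convolutions of each residual block.

```python
import torch
import torch.nn as nn


class ResnetBlockSketch(nn.Module):
    """Simplified residual block mirroring the change in this commit:
    an optional nn.Dropout(0.5) between the two 3x3 convolutions."""

    def __init__(self, dim, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(ResnetBlockSketch, self).__init__()
        conv_block = [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                      norm_layer(dim),
                      nn.ReLU(True)]
        if use_dropout:
            # behaviour added by this commit, gated on the new flag
            conv_block += [nn.Dropout(0.5)]
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                       norm_layer(dim)]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        # residual (skip) connection around the conv block
        return x + self.conv_block(x)


if __name__ == '__main__':
    block = ResnetBlockSketch(64, use_dropout=True)
    out = block(torch.randn(1, 64, 32, 32))
    print(out.shape)  # torch.Size([1, 64, 32, 32])
```

With the change applied, dropout in the generator is opt-in and enabled at training time via the new `--use_dropout` flag, as shown in the updated README command and `scripts/train_pix2pix.sh` above.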
