| author | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:56:29 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:56:29 -0700 |
| commit | efef2906b0a3f4fd265823ff4b4b99ccebeb6d05 | |
| tree | 0b041af61c33b2262b336598b9ff402595ed5f3b | |
| parent | e17139daa9f3af07acee72dbf186d7eaf9ad089e | |
add unet and update README
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | README.md | 3 |
| -rw-r--r-- | models/networks.py | 98 |
| -rw-r--r-- | util/visualizer.py | 2 |

3 files changed, 84 insertions, 19 deletions
```diff
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -79,11 +79,10 @@ The test results will be saved to a html file here: `./results/facades_pix2pix/l
 More example scripts can be found at `scripts` directory.
-
 ## Training/test Details
 - See `options/train_options.py` and `options/base_options.py` for training flags; see `optoins/test_options.py` and `options/base_options.py` for test flags.
 - CPU/GPU: Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
-- If you set `--display_id 0`, we will save the training results to `../checkpoints/name/web/index.html`. If you set `--display_id` > 0, we will use a browser-based graphics server. You need to call `th -ldisplay.start 8000 0.0.0.0` to start the server. See [[szym/display]](https://github.com/szym/display) for more details.
+- During training, you can visualize the result of current training. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [szym/display: a lightweight display server for Torch](https://github.com/szym/display). To do this, you should have Torch, Python 3, and the display package installed. You need to invoke `th -ldisplay.start 8000 0.0.0.0` to start the server.
 
 ### CycleGAN Datasets
 Download the CycleGAN datasets using the following script:
diff --git a/models/networks.py b/models/networks.py
index edbe972..9cb6222 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -26,18 +26,19 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
         norm_layer = InstanceNormalization
     else:
         print('normalization layer [%s] is not found' % norm)
-    if use_gpu:
-        assert(torch.cuda.is_available())
-
+
+    assert(torch.cuda.is_available() == use_gpu)
     if which_model_netG == 'resnet_9blocks':
         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
-    elif which_model_netG == 'unet':
-        netG = UnetGenerator(input_nc, output_nc, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, 6, gpu_ids=gpu_ids)
+    elif which_model_netG == 'unet_128':
+        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, gpu_ids=gpu_ids)
+    elif which_model_netG == 'unet_256':
+        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, gpu_ids=gpu_ids)
     else:
         print('Generator model name [%s] is not recognized' % which_model_netG)
-    if use_gpu:
+    if len(gpu_ids) > 0:
         netG.cuda()
     netG.apply(weights_init)
     return netG
@@ -47,9 +48,7 @@ def define_D(input_nc, ndf, which_model_netD,
              n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
     netD = None
     use_gpu = len(gpu_ids) > 0
-    if use_gpu:
-        assert(torch.cuda.is_available())
-
+    assert(torch.cuda.is_available() == use_gpu)
     if which_model_netD == 'basic':
         netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     elif which_model_netD == 'n_layers':
@@ -143,7 +142,7 @@ class ResnetGenerator(nn.Module):
         mult = 2**n_downsampling
         for i in range(n_blocks):
-            model += [Resnet_block(ngf * mult, 'zero', norm_layer=norm_layer)]
+            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer)]
 
         for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
@@ -165,11 +164,10 @@ class ResnetGenerator(nn.Module):
         return self.model(input)
 
-
 # Define a resnet block
-class Resnet_block(nn.Module):
+class ResnetBlock(nn.Module):
     def __init__(self, dim, padding_type, norm_layer):
-        super(Resnet_block, self).__init__()
+        super(ResnetBlock, self).__init__()
         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
 
     def build_conv_block(self, dim, padding_type, norm_layer):
@@ -193,15 +191,35 @@ class Resnet_block(nn.Module):
         return out
 
-# Defines the Unet geneator.
+# Defines the Unet generator.
+# |num_downs|: number of downsamplings in UNet. For example,
+# if |num_downs| == 7, image of size 128x128 will become of size 1x1
+# at the bottleneck
 class UnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
+                 norm_layer=nn.BatchNorm2d, gpu_ids=[]):
         super(UnetGenerator, self).__init__()
         self.input_nc = input_nc
         self.output_nc = output_nc
         self.ngf = ngf
         self.gpu_ids = gpu_ids
+        # currently support only input_nc == output_nc
+        assert(input_nc == output_nc)
+
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
+
+        for i in range(num_downs - 5):
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
+
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
+        unet_block = UnetSkipConnectionBlock(input_nc, ngf, unet_block,
+                                             outermost=True)
+
+        self.model = unet_block
+
     def forward(self, input):
         if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
@@ -209,6 +227,54 @@ class UnetGenerator(nn.Module):
             return self.model(input)
 
 
+# Defines the submodule with skip connection.
+# X -------------------identity---------------------- X
+#   |-- downsampling -- |submodule| -- upsampling --|
+class UnetSkipConnectionBlock(nn.Module):
+    def __init__(self, outer_nc, inner_nc,
+                 submodule=None, outermost=False, innermost=False):
+        super(UnetSkipConnectionBlock, self).__init__()
+        self.outer_nc = outer_nc
+        self.inner_nc = inner_nc
+        self.innermost = innermost
+        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1)
+        downrelu = nn.LeakyReLU(0.2, True)
+        downnorm = nn.BatchNorm2d(inner_nc)
+        uprelu = nn.ReLU(True)
+        upnorm = nn.BatchNorm2d(outer_nc)
+
+        if outermost:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, output_padding=0)
+            down = [downconv]
+            up = [uprelu, upconv, nn.Tanh()]
+            model = down + [submodule] + up
+        elif innermost:
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, output_padding=0)
+            down = [downrelu, downconv]
+            up = [uprelu, upconv, upnorm]
+            model = down + up
+        else:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, output_padding=0)
+            down = [downrelu, downconv, downnorm]
+            up = [uprelu, upconv, upnorm]
+            model = down + [submodule] + up
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        #print(self.outer_nc, self.inner_nc, self.innermost)
+        #print(x.size())
+        #print(self.model(x).size())
+        return torch.cat([self.model(x), x], 1)
+
+
 # Defines the PatchGAN discriminator with the specified arguments.
 class NLayerDiscriminator(nn.Module):
diff --git a/util/visualizer.py b/util/visualizer.py
index 0b8578e..4daf506 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -27,7 +27,7 @@ class Visualizer():
     def display_current_results(self, visuals, epoch):
         if self.display_id > 0:  # show images in the browser
             idx = 0
-            for label, image_numpy in visuals:
+            for label, image_numpy in visuals.items():
                 image_numpy = np.flipud(image_numpy)
                 self.display.image(image_numpy, title=label,
                                    win=self.display_id + idx)
```
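The recursive construction above is the heart of this commit: `UnetGenerator` builds the innermost bottleneck block first, then each `UnetSkipConnectionBlock` wraps the previous one with one more downsample/upsample pair, and every block's `forward` concatenates its input onto its output as the skip connection. Counting levels, that is 1 innermost + (`num_downs` - 5) middle wrappers + 4 outer wrappers = `num_downs` downsamplings, which is why `unet_256` (`num_downs` == 8) takes a 256x256 input down to 1x1 at the bottleneck. The sketch below is a minimal standalone illustration of the same pattern, not code from this commit; `ToySkipBlock` is a hypothetical stand-in that strips out the norm and activation layers to make the shape bookkeeping visible:

```python
import torch
import torch.nn as nn

class ToySkipBlock(nn.Module):
    """Simplified stand-in for UnetSkipConnectionBlock (no norms/activations)."""
    def __init__(self, outer_nc, inner_nc, submodule=None, innermost=False):
        super(ToySkipBlock, self).__init__()
        # halve the spatial size on the way down ...
        down = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        # ... and double it on the way up; a non-innermost block receives
        # inner_nc * 2 channels because its submodule concatenates a skip
        up = nn.ConvTranspose2d(inner_nc if innermost else inner_nc * 2,
                                outer_nc, kernel_size=4, stride=2, padding=1)
        body = [down] if innermost else [down, submodule]
        self.model = nn.Sequential(*(body + [up]))

    def forward(self, x):
        # skip connection: concatenate input with output along the channel
        # dimension, so every block returns 2 * outer_nc channels
        return torch.cat([self.model(x), x], 1)

ngf = 8
block = ToySkipBlock(ngf * 8, ngf * 8, innermost=True)  # bottleneck, built first
block = ToySkipBlock(ngf * 4, ngf * 8, block)           # wrap with one more level
x = torch.randn(1, ngf * 4, 8, 8)
print(block(x).size())  # torch.Size([1, 64, 8, 8]) == 2 * (ngf * 4) channels
```

The key invariant is that a block built with `outer_nc` input channels returns `2 * outer_nc` channels at the same spatial size, which is exactly why every non-innermost `upconv` in the commit takes `inner_nc * 2` input channels.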

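The one-line `util/visualizer.py` change is easy to overlook but load-bearing: iterating a Python dict directly yields only its keys, so the old loop tried to tuple-unpack each key string instead of (label, image) pairs. A standalone illustration, using hypothetical placeholder values rather than the repo's real image arrays:

```python
# Iterating a dict yields keys, not (key, value) pairs.
visuals = {'real_A': 'image_array_A', 'fake_B': 'image_array_B'}

for label, image_numpy in visuals.items():  # new form: unpacks pairs
    print(label, '->', image_numpy)

try:
    for label, image_numpy in visuals:      # old form: unpacks each key string
        pass
except ValueError as err:
    print('old form fails:', err)           # too many values to unpack
```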