-rw-r--r--  README.md                  4
-rw-r--r--  models/cycle_gan_model.py  4
-rw-r--r--  models/networks.py       103
-rw-r--r--  models/pix2pix_model.py    2
-rw-r--r--  options/base_options.py    2
-rw-r--r--  train.py                   3
-rw-r--r--  util/visualizer.py         4
7 files changed, 45 insertions(+), 77 deletions(-)
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -143,11 +143,7 @@ This will combine each pair of images (A,B) into a single image file, ready for
 
 ## TODO
 - add reflection and other padding layers.
-- add one-direction test mode for CycleGAN.
 - add more preprocessing options.
-- fully test Unet architecture.
-- fully test instance normalization layer from [fast-neural-style project](https://github.com/darkstar112358/fast-neural-style).
-- fully test CPU mode and multi-GPU mode.
 
 ## Related Projects:
 [CycleGAN](https://github.com/junyanz/CycleGAN): Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 451002d..f8c4f9f 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -35,10 +35,10 @@ class CycleGANModel(BaseModel):
             use_sigmoid = opt.no_lsgan
             self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                             opt.which_model_netD,
-                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
+                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
             self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
-                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
+                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             which_epoch = opt.which_epoch
             self.load_network(self.netG_A, 'G_A', which_epoch)
diff --git a/models/networks.py b/models/networks.py
index b0f3b11..1a0bc1c 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 from torch.autograd import Variable
 import numpy as np
-
 ###############################################################################
 # Functions
 ###############################################################################
@@ -12,31 +11,35 @@ def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
         m.weight.data.normal_(0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1:
+    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
-
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False, gpu_ids=[]):
-    netG = None
-    use_gpu = len(gpu_ids) > 0
-    if norm == 'batch':
+
+def get_norm_layer(norm_type):
+    if norm_type == 'batch':
         norm_layer = nn.BatchNorm2d
-    elif norm == 'instance':
-        norm_layer = InstanceNormalization
+    elif norm_type == 'instance':
+        norm_layer = nn.InstanceNorm2d
     else:
-        print('normalization layer [%s] is not found' % norm)
+        print('normalization layer [%s] is not found' % norm_type)
+    return norm_layer
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
+    netG = None
+    use_gpu = len(gpu_ids) > 0
+    norm_layer = get_norm_layer(norm_type=norm)
+
     if use_gpu:
         assert(torch.cuda.is_available())
     if which_model_netG == 'resnet_9blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_128':
-        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_256':
-        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     else:
         print('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
@@ -46,15 +49,17 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False
 
 def define_D(input_nc, ndf, which_model_netD,
-             n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
+             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
     netD = None
     use_gpu = len(gpu_ids) > 0
+    norm_layer = get_norm_layer(norm_type=norm)
+
     if use_gpu:
         assert(torch.cuda.is_available())
     if which_model_netD == 'basic':
-        netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
+        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     elif which_model_netD == 'n_layers':
-        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, gpu_ids=gpu_ids)
+        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     else:
         print('Discriminator model name [%s] is not recognized' % which_model_netD)
@@ -132,7 +137,7 @@ class ResnetGenerator(nn.Module):
         self.gpu_ids = gpu_ids
 
         model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
-                 norm_layer(ngf),
+                 norm_layer(ngf, affine=True),
                  nn.ReLU(True)]
 
         n_downsampling = 2
@@ -140,7 +145,7 @@
             mult = 2**i
             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                 stride=2, padding=1),
-                      norm_layer(ngf * mult * 2),
+                      norm_layer(ngf * mult * 2, affine=True),
                       nn.ReLU(True)]
 
         mult = 2**n_downsampling
@@ -152,7 +157,7 @@
             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                          kernel_size=3, stride=2,
                                          padding=1, output_padding=1),
-                      norm_layer(int(ngf * mult / 2)),
+                      norm_layer(int(ngf * mult / 2), affine=True),
                       nn.ReLU(True)]
         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
@@ -182,12 +187,12 @@ class ResnetBlock(nn.Module):
         # TODO: InstanceNorm
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
-                       norm_layer(dim),
+                       norm_layer(dim, affine=True),
                        nn.ReLU(True)]
         if use_dropout:
             conv_block += [nn.Dropout(0.5)]
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
-                       norm_layer(dim)]
+                       norm_layer(dim, affine=True)]
 
         return nn.Sequential(*conv_block)
@@ -212,11 +217,11 @@ class UnetGenerator(nn.Module):
         # construct unet structure
         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
         for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, use_dropout=use_dropout)
-        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
-        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
-        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
-        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
 
         self.model = unet_block
@@ -232,16 +237,16 @@ class UnetGenerator(nn.Module):
 # |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
     def __init__(self, outer_nc, inner_nc,
-                 submodule=None, outermost=False, innermost=False, use_dropout=False):
+                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
         downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1)
         downrelu = nn.LeakyReLU(0.2, True)
-        downnorm = nn.BatchNorm2d(inner_nc)
+        downnorm = norm_layer(inner_nc, affine=True)
         uprelu = nn.ReLU(True)
-        upnorm = nn.BatchNorm2d(outer_nc)
+        upnorm = norm_layer(outer_nc, affine=True)
 
         if outermost:
             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
@@ -280,7 +285,7 @@ class UnetSkipConnectionBlock(nn.Module):
 
 # Defines the PatchGAN discriminator with the specified arguments.
 class NLayerDiscriminator(nn.Module):
-    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
         super(NLayerDiscriminator, self).__init__()
         self.gpu_ids = gpu_ids
@@ -300,7 +305,7 @@ class NLayerDiscriminator(nn.Module):
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                           kernel_size=kw, stride=2, padding=padw),
                 # TODO: use InstanceNorm
-                nn.BatchNorm2d(ndf * nf_mult),
+                norm_layer(ndf * nf_mult, affine=True),
                 nn.LeakyReLU(0.2, True)
             ]
@@ -310,7 +315,7 @@ class NLayerDiscriminator(nn.Module):
             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                       kernel_size=kw, stride=1, padding=padw),
             # TODO: use InstanceNorm
-            nn.BatchNorm2d(ndf * nf_mult),
+            norm_layer(ndf * nf_mult, affine=True),
             nn.LeakyReLU(0.2, True)
         ]
@@ -326,37 +331,3 @@ class NLayerDiscriminator(nn.Module):
             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
         else:
             return self.model(input)
-
-# Instance Normalization layer from
-# https://github.com/darkstar112358/fast-neural-style
-
-class InstanceNormalization(torch.nn.Module):
-    """InstanceNormalization
-    Improves convergence of neural-style.
-    ref: https://arxiv.org/pdf/1607.08022.pdf
-    """
-
-    def __init__(self, dim, eps=1e-5):
-        super(InstanceNormalization, self).__init__()
-        self.weight = nn.Parameter(torch.FloatTensor(dim))
-        self.bias = nn.Parameter(torch.FloatTensor(dim))
-        self.eps = eps
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        self.weight.data.uniform_()
-        self.bias.data.zero_()
-
-    def forward(self, x):
-        n = x.size(2) * x.size(3)
-        t = x.view(x.size(0), x.size(1), n)
-        mean = torch.mean(t, 2).unsqueeze(2).expand_as(x)
-        # Calculate the biased var. torch.var returns unbiased var
-        var = torch.var(t, 2).unsqueeze(2).expand_as(x) * ((n - 1) / float(n))
-        scale_broadcast = self.weight.unsqueeze(1).unsqueeze(1).unsqueeze(0)
-        scale_broadcast = scale_broadcast.expand_as(x)
-        shift_broadcast = self.bias.unsqueeze(1).unsqueeze(1).unsqueeze(0)
-        shift_broadcast = shift_broadcast.expand_as(x)
-        out = (x - mean) / torch.sqrt(var + self.eps)
-        out = out * scale_broadcast + shift_broadcast
-        return out
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 34e0bac..4581d33 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -28,7 +28,7 @@ class Pix2PixModel(BaseModel):
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                           opt.which_model_netD,
-                                          opt.n_layers_D, use_sigmoid, self.gpu_ids)
+                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             self.load_network(self.netG, 'G', opt.which_epoch)
         if self.isTrain:
diff --git a/options/base_options.py b/options/base_options.py
index cce6aae..9ec7c9a 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -29,7 +29,7 @@ class BaseOptions():
         self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
         self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
-        self.parser.add_argument('--norm', type=str, default='batch', help='batch normalization or instance normalization')
+        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
        self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
         self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
         self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -30,7 +30,8 @@ for epoch in range(1, opt.niter + opt.niter_decay + 1):
         if total_steps % opt.print_freq == 0:
             errors = model.get_current_errors()
-            visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
+            t = (time.time() - iter_start_time) / opt.batchSize
+            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
             if opt.display_id > 0:
                 visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
diff --git a/util/visualizer.py b/util/visualizer.py
index 9a6da2a..46348be 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -70,8 +70,8 @@ class Visualizer():
                             win=self.display_id)
 
     # errors: same format as |errors| of plotCurrentErrors
-    def print_current_errors(self, epoch, i, errors, start_time):
-        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, time.time() - start_time)
+    def print_current_errors(self, epoch, i, errors, t):
+        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
         for k, v in errors.items():
             message += '%s: %.3f ' % (k, v)
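
Usage sketch (not part of the commit): after this change, callers select the normalization through the --norm string, which get_norm_layer resolves to nn.BatchNorm2d or nn.InstanceNorm2d before it is threaded into the generator and discriminator constructors. A minimal CPU smoke test, assuming it is run from the repository root; the channel and filter counts are illustrative:

import torch
from torch.autograd import Variable
from models import networks

# Build a generator and a PatchGAN discriminator with instance normalization.
netG = networks.define_G(input_nc=3, output_nc=3, ngf=64,
                         which_model_netG='resnet_9blocks',
                         norm='instance', use_dropout=False, gpu_ids=[])
netD = networks.define_D(input_nc=3, ndf=64, which_model_netD='basic',
                         n_layers_D=3, norm='instance',
                         use_sigmoid=False, gpu_ids=[])

x = Variable(torch.randn(1, 3, 256, 256))
print(netG(x).size())  # same spatial size as the input: (1, 3, 256, 256)
print(netD(x).size())  # a coarse patch map of real/fake scores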
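
The removed InstanceNormalization class standardized each (sample, channel) plane with a biased variance estimate plus a learned per-channel scale and shift, which is what nn.InstanceNorm2d(dim, affine=True) computes. A standalone check of that equivalence, written against current PyTorch semantics; the affine parameters are pinned to 1 and 0 here because the old class initialized its scale with uniform_(), so old checkpoints are not directly interchangeable:

import torch
import torch.nn as nn

x = torch.randn(4, 8, 16, 16)
inorm = nn.InstanceNorm2d(8, affine=True)
inorm.weight.data.fill_(1.0)  # identity scale for the comparison
inorm.bias.data.fill_(0.0)    # zero shift

# Per-sample, per-channel standardization with the biased variance,
# matching the removed forward() above:
mean = x.mean(dim=(2, 3), keepdim=True)
var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + inorm.eps)
print(torch.allclose(inorm(x), manual, atol=1e-5))  # True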
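
Timing sketch (illustrative values, not part of the commit): with the train.py and util/visualizer.py changes, the caller computes a per-image time and print_current_errors only formats it, instead of receiving a raw start timestamp:

import time

batch_size = 1  # stand-in for opt.batchSize
iter_start_time = time.time()
# ... one training iteration would run here ...
t = (time.time() - iter_start_time) / batch_size  # seconds per image, not per batch

# Mirrors the new Visualizer.print_current_errors formatting; loss values are made up.
errors = {'D_A': 0.298, 'G_A': 0.512}
message = '(epoch: %d, iters: %d, time: %.3f) ' % (1, 100, t)
for k, v in errors.items():
    message += '%s: %.3f ' % (k, v)
print(message)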
