 -rw-r--r--  README.md                 |   4
 -rw-r--r--  models/cycle_gan_model.py |   4
 -rw-r--r--  models/networks.py        | 103
 -rw-r--r--  models/pix2pix_model.py   |   2
 -rw-r--r--  options/base_options.py   |   2
 -rw-r--r--  train.py                  |   3
 -rw-r--r--  util/visualizer.py        |   4
 7 files changed, 45 insertions(+), 77 deletions(-)
diff --git a/README.md b/README.md
index 40445ba..0be5a3b 100644
--- a/README.md
+++ b/README.md
@@ -143,11 +143,7 @@ This will combine each pair of images (A,B) into a single image file, ready for
 ## TODO
 - add reflection and other padding layers.
-- add one-direction test mode for CycleGAN.
 - add more preprocessing options.
-- fully test Unet architecture.
-- fully test instance normalization layer from [fast-neural-style project](https://github.com/darkstar112358/fast-neural-style).
-- fully test CPU mode and multi-GPU mode.
 
 ## Related Projects:
 [CycleGAN](https://github.com/junyanz/CycleGAN): Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 451002d..f8c4f9f 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -35,10 +35,10 @@ class CycleGANModel(BaseModel):
             use_sigmoid = opt.no_lsgan
             self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                             opt.which_model_netD,
-                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
+                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
             self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
-                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
+                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             which_epoch = opt.which_epoch
             self.load_network(self.netG_A, 'G_A', which_epoch)
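
Both discriminators now pass opt.norm through to define_D, so their normalization follows the --norm option instead of silently defaulting to batch norm. A minimal sketch of the updated call, with illustrative values (RGB inputs, 64 base filters, CPU-only) that are not taken from this commit:

    from models import networks

    # 'basic' selects the default 3-layer PatchGAN; norm now reaches it.
    netD = networks.define_D(3, 64, 'basic',
                             n_layers_D=3, norm='instance',
                             use_sigmoid=False, gpu_ids=[])
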
diff --git a/models/networks.py b/models/networks.py
index b0f3b11..1a0bc1c 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 from torch.autograd import Variable
 import numpy as np
-
 ###############################################################################
 # Functions
 ###############################################################################
@@ -12,31 +11,35 @@ def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
         m.weight.data.normal_(0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1:
+    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
-
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False, gpu_ids=[]):
-    netG = None
-    use_gpu = len(gpu_ids) > 0
-    if norm == 'batch':
+
+def get_norm_layer(norm_type):
+    if norm_type == 'batch':
         norm_layer = nn.BatchNorm2d
-    elif norm == 'instance':
-        norm_layer = InstanceNormalization
+    elif norm_type == 'instance':
+        norm_layer = nn.InstanceNorm2d
     else:
-        print('normalization layer [%s] is not found' % norm)
+        print('normalization layer [%s] is not found' % norm_type)
+    return norm_layer
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
+    netG = None
+    use_gpu = len(gpu_ids) > 0
+    norm_layer = get_norm_layer(norm_type=norm)
+
     if use_gpu:
         assert(torch.cuda.is_available())
     if which_model_netG == 'resnet_9blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_128':
-        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     elif which_model_netG == 'unet_256':
-        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
+        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
     else:
         print('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
@@ -46,15 +49,17 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, use_dropout=False
 
 def define_D(input_nc, ndf, which_model_netD,
-             n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
+             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
     netD = None
     use_gpu = len(gpu_ids) > 0
+    norm_layer = get_norm_layer(norm_type=norm)
+
     if use_gpu:
         assert(torch.cuda.is_available())
     if which_model_netD == 'basic':
-        netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
+        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     elif which_model_netD == 'n_layers':
-        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, gpu_ids=gpu_ids)
+        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     else:
         print('Discriminator model name [%s] is not recognized' %
               which_model_netD)
@@ -132,7 +137,7 @@ class ResnetGenerator(nn.Module):
         self.gpu_ids = gpu_ids
 
         model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
-                 norm_layer(ngf),
+                 norm_layer(ngf, affine=True),
                  nn.ReLU(True)]
 
         n_downsampling = 2
@@ -140,7 +145,7 @@ class ResnetGenerator(nn.Module):
             mult = 2**i
             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                 stride=2, padding=1),
-                      norm_layer(ngf * mult * 2),
+                      norm_layer(ngf * mult * 2, affine=True),
                       nn.ReLU(True)]
 
         mult = 2**n_downsampling
@@ -152,7 +157,7 @@ class ResnetGenerator(nn.Module):
             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                          kernel_size=3, stride=2,
                                          padding=1, output_padding=1),
-                      norm_layer(int(ngf * mult / 2)),
+                      norm_layer(int(ngf * mult / 2), affine=True),
                       nn.ReLU(True)]
 
         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
@@ -182,12 +187,12 @@ class ResnetBlock(nn.Module):
 
         # TODO: InstanceNorm
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
-                       norm_layer(dim),
+                       norm_layer(dim, affine=True),
                        nn.ReLU(True)]
         if use_dropout:
             conv_block += [nn.Dropout(0.5)]
 
         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
-                       norm_layer(dim)]
+                       norm_layer(dim, affine=True)]
 
         return nn.Sequential(*conv_block)
@@ -212,11 +217,11 @@ class UnetGenerator(nn.Module):
 
         # construct unet structure
         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
         for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, use_dropout=use_dropout)
-        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
-        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
-        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
-        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
 
         self.model = unet_block
@@ -232,16 +237,16 @@ class UnetGenerator(nn.Module):
 #   |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
     def __init__(self, outer_nc, inner_nc,
-                 submodule=None, outermost=False, innermost=False, use_dropout=False):
+                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
 
         downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1)
         downrelu = nn.LeakyReLU(0.2, True)
-        downnorm = nn.BatchNorm2d(inner_nc)
+        downnorm = norm_layer(inner_nc, affine=True)
         uprelu = nn.ReLU(True)
-        upnorm = nn.BatchNorm2d(outer_nc)
+        upnorm = norm_layer(outer_nc, affine=True)
 
         if outermost:
             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
@@ -280,7 +285,7 @@ class UnetSkipConnectionBlock(nn.Module):
 
 # Defines the PatchGAN discriminator with the specified arguments.
 class NLayerDiscriminator(nn.Module):
-    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
         super(NLayerDiscriminator, self).__init__()
         self.gpu_ids = gpu_ids
@@ -300,7 +305,7 @@ class NLayerDiscriminator(nn.Module):
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                           kernel_size=kw, stride=2, padding=padw),
                 # TODO: use InstanceNorm
-                nn.BatchNorm2d(ndf * nf_mult),
+                norm_layer(ndf * nf_mult, affine=True),
                 nn.LeakyReLU(0.2, True)
             ]
@@ -310,7 +315,7 @@ class NLayerDiscriminator(nn.Module):
                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                           kernel_size=kw, stride=1, padding=padw),
                 # TODO: use InstanceNorm
-                nn.BatchNorm2d(ndf * nf_mult),
+                norm_layer(ndf * nf_mult, affine=True),
                 nn.LeakyReLU(0.2, True)
             ]
@@ -326,37 +331,3 @@ class NLayerDiscriminator(nn.Module):
             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
         else:
             return self.model(input)
-
-# Instance Normalization layer from
-# https://github.com/darkstar112358/fast-neural-style
-
-class InstanceNormalization(torch.nn.Module):
-    """InstanceNormalization
-    Improves convergence of neural-style.
-    ref: https://arxiv.org/pdf/1607.08022.pdf
-    """
-
-    def __init__(self, dim, eps=1e-5):
-        super(InstanceNormalization, self).__init__()
-        self.weight = nn.Parameter(torch.FloatTensor(dim))
-        self.bias = nn.Parameter(torch.FloatTensor(dim))
-        self.eps = eps
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        self.weight.data.uniform_()
-        self.bias.data.zero_()
-
-    def forward(self, x):
-        n = x.size(2) * x.size(3)
-        t = x.view(x.size(0), x.size(1), n)
-        mean = torch.mean(t, 2).unsqueeze(2).expand_as(x)
-        # Calculate the biased var. torch.var returns unbiased var
-        var = torch.var(t, 2).unsqueeze(2).expand_as(x) * ((n - 1) / float(n))
-        scale_broadcast = self.weight.unsqueeze(1).unsqueeze(1).unsqueeze(0)
-        scale_broadcast = scale_broadcast.expand_as(x)
-        shift_broadcast = self.bias.unsqueeze(1).unsqueeze(1).unsqueeze(0)
-        shift_broadcast = shift_broadcast.expand_as(x)
-        out = (x - mean) / torch.sqrt(var + self.eps)
-        out = out * scale_broadcast + shift_broadcast
-        return out
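
The hand-rolled InstanceNormalization module is dropped in favor of the built-in nn.InstanceNorm2d, which performs the same per-(sample, channel) normalization over H x W that the removed forward() computed by hand with the biased variance. A quick sketch of the replacement, assuming the affine=True flag used throughout this commit (learnable per-channel scale and shift, matching the removed layer's weight and bias); on the PyTorch of this era the input would additionally be wrapped in a Variable:

    import torch
    import torch.nn as nn

    # Normalizes each (sample, channel) plane over its spatial extent.
    norm = nn.InstanceNorm2d(64, affine=True, eps=1e-5)
    x = torch.randn(4, 64, 32, 32)   # (batch, channels, H, W)
    y = norm(x)                      # same shape, normalized per plane
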
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 34e0bac..4581d33 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -28,7 +28,7 @@ class Pix2PixModel(BaseModel):
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                           opt.which_model_netD,
-                                          opt.n_layers_D, use_sigmoid, self.gpu_ids)
+                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             self.load_network(self.netG, 'G', opt.which_epoch)
         if self.isTrain:
diff --git a/options/base_options.py b/options/base_options.py
index cce6aae..9ec7c9a 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -29,7 +29,7 @@ class BaseOptions():
         self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
         self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
-        self.parser.add_argument('--norm', type=str, default='batch', help='batch normalization or instance normalization')
+        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
         self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
         self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
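
Since the default normalization flips from batch to instance (the choice used in the CycleGAN paper), runs that relied on the old default must now request it explicitly. Illustrative command lines, with dataset and experiment names as placeholders:

    # new default: instance normalization
    python train.py --dataroot ./datasets/horse2zebra --name h2z --model cycle_gan

    # restore the previous behavior explicitly
    python train.py --dataroot ./datasets/horse2zebra --name h2z --model cycle_gan --norm batch
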
diff --git a/train.py b/train.py
index ae129b9..12c9dbc 100644
--- a/train.py
+++ b/train.py
@@ -30,7 +30,8 @@ for epoch in range(1, opt.niter + opt.niter_decay + 1):
 
         if total_steps % opt.print_freq == 0:
             errors = model.get_current_errors()
-            visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
+            t = (time.time() - iter_start_time) / opt.batchSize
+            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
 
         if opt.display_id > 0:
             visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
diff --git a/util/visualizer.py b/util/visualizer.py
index 9a6da2a..46348be 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -70,8 +70,8 @@ class Visualizer():
                        win=self.display_id)
 
     # errors: same format as |errors| of plotCurrentErrors
-    def print_current_errors(self, epoch, i, errors, start_time):
-        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, time.time() - start_time)
+    def print_current_errors(self, epoch, i, errors, t):
+        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
         for k, v in errors.items():
             message += '%s: %.3f ' % (k, v)
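
Together, the train.py and visualizer.py changes move timing out of the printer: the caller measures the elapsed time, divides by opt.batchSize to get a per-image figure that is comparable across batch sizes, and passes the result in, so print_current_errors no longer depends on when it is invoked relative to iter_start_time. A minimal sketch of the new contract, with variable names as in train.py:

    import time

    iter_start_time = time.time()
    # ... run one training iteration on a batch of opt.batchSize images ...
    t = (time.time() - iter_start_time) / opt.batchSize   # seconds per image
    visualizer.print_current_errors(epoch, epoch_iter, errors, t)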