summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py288
1 files changed, 288 insertions, 0 deletions
diff --git a/models/networks.py b/models/networks.py
new file mode 100644
index 0000000..d41bd0e
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from pdb import set_trace as st
+
+###############################################################################
+# Functions
+###############################################################################
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ m.weight.data.normal_(0.0, 0.02)
+ elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
+ m.weight.data.normal_(1.0, 0.02)
+ m.bias.data.fill_(0)
+
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+ netG = None
+ use_gpu = len(gpu_ids) > 0
+ if norm == 'batch':
+ norm_layer = nn.BatchNorm2d
+ elif norm == 'instance':
+ norm_layer = InstanceNormalization
+ else:
+ print('normalization layer [%s] is not found' % norm)
+
+ assert(torch.cuda.is_available() == use_gpu)
+ if which_model_netG == 'resnet_9blocks':
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+ elif which_model_netG == 'resnet_6blocks':
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+ elif which_model_netG == 'unet':
+ netG = UnetGenerator(input_nc, output_nc, ngf, norm_layer, gpu_ids=gpu_ids)
+ else:
+ print('Generator model name [%s] is not recognized' % which_model_netG)
+ if use_gpu:
+ netG.cuda()
+ netG.apply(weights_init)
+ return netG
+
+
+def define_D(input_nc, ndf, which_model_netD,
+ n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
+ netD = None
+ use_gpu = len(gpu_ids) > 0
+ assert(torch.cuda.is_available() == use_gpu)
+ if which_model_netD == 'basic':
+ netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
+ elif which_model_netD == 'n_layers':
+ netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, gpu_ids=gpu_ids)
+ else:
+ print('Discriminator model name [%s] is not recognized' %
+ which_model_netD)
+ if use_gpu:
+ netD.cuda()
+ netD.apply(weights_init)
+ return netD
+
+
+def print_network(net):
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ print(net)
+ print('Total number of parameters: %d' % num_params)
+
+
+##############################################################################
+# Classes
+##############################################################################
+
+
+# Defines the GAN loss used in LSGAN.
+# It is basically same as MSELoss, but it abstracts away the need to create
+# the target label tensor that has the same size as the input
+class GANLoss(nn.Module):
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
+ tensor=torch.FloatTensor):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_var = None
+ self.fake_label_var = None
+ self.Tensor = tensor
+ if use_lsgan:
+ self.loss = nn.MSELoss()
+ else:
+ self.loss = nn.BCELoss()
+
+ def get_target_tensor(self, input, target_is_real):
+ target_tensor = None
+ if target_is_real:
+ create_label = ((self.real_label_var is None) or
+ (self.real_label_var.numel() != input.numel()))
+ if create_label:
+ real_tensor = self.Tensor(input.size()).fill_(self.real_label)
+ self.real_label_var = Variable(real_tensor, requires_grad=False)
+ target_tensor = self.real_label_var
+ else:
+ create_label = ((self.fake_label_var is None) or
+ (self.fake_label_var.numel() != input.numel()))
+ if create_label:
+ fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
+ self.fake_label_var = Variable(fake_tensor, requires_grad=False)
+ target_tensor = self.fake_label_var
+ return target_tensor
+
+ def __call__(self, input, target_is_real):
+ target_tensor = self.get_target_tensor(input, target_is_real)
+ return self.loss(input, target_tensor)
+
+
+# Defines the generator that consists of Resnet blocks between a few
+# downsampling/upsampling operations.
+# Code and idea originally from Justin Johnson's architecture.
+# https://github.com/jcjohnson/fast-neural-style/
+class ResnetGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+ assert(n_blocks >= 0)
+ super(ResnetGenerator, self).__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.ngf = ngf
+ self.gpu_ids = gpu_ids
+
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
+ stride=2, padding=1),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+
+ mult = 2**n_downsampling
+ for i in range(n_blocks):
+ model += [Resnet_block(ngf * mult, 'zero', norm_layer=norm_layer)]
+
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
+ model += [nn.Tanh()]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+
+
+# Define a resnet block
+class Resnet_block(nn.Module):
+ def __init__(self, dim, padding_type, norm_layer):
+ super(Resnet_block, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+
+ def build_conv_block(self, dim, padding_type, norm_layer):
+ conv_block = []
+ p = 0
+ # TODO: support padding types
+ assert(padding_type == 'zero')
+ p = 1
+
+ # TODO: InstanceNorm
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim),
+ nn.ReLU(True)]
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ out = x + self.conv_block(x)
+ return out
+
+
+# Defines the Unet geneator.
+class UnetGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+ super(UnetGenerator, self).__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.ngf = ngf
+ self.gpu_ids = gpu_ids
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
+ super(NLayerDiscriminator, self).__init__()
+ self.gpu_ids = gpu_ids
+
+ kw = 4
+ sequence = [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=2),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers):
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
+ kernel_size=kw, stride=2, padding=2),
+ # TODO: use InstanceNorm
+ nn.BatchNorm2d(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
+ kernel_size=1, stride=2, padding=2),
+ # TODO: useInstanceNorm
+ nn.BatchNorm2d(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)]
+ sequence += [nn.Sigmoid()]
+
+ self.model = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+# Instance Normalization layer from
+# https://github.com/darkstar112358/fast-neural-style
+
+class InstanceNormalization(torch.nn.Module):
+ """InstanceNormalization
+ Improves convergence of neural-style.
+ ref: https://arxiv.org/pdf/1607.08022.pdf
+ """
+
+ def __init__(self, dim, eps=1e-5):
+ super(InstanceNormalization, self).__init__()
+ self.weight = nn.Parameter(torch.FloatTensor(dim))
+ self.bias = nn.Parameter(torch.FloatTensor(dim))
+ self.eps = eps
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ self.weight.data.uniform_()
+ self.bias.data.zero_()
+
+ def forward(self, x):
+ n = x.size(2) * x.size(3)
+ t = x.view(x.size(0), x.size(1), n)
+ mean = torch.mean(t, 2).unsqueeze(2).expand_as(x)
+ # Calculate the biased var. torch.var returns unbiased var
+ var = torch.var(t, 2).unsqueeze(2).expand_as(x) * ((n - 1) / float(n))
+ scale_broadcast = self.weight.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+ scale_broadcast = scale_broadcast.expand_as(x)
+ shift_broadcast = self.bias.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+ shift_broadcast = shift_broadcast.expand_as(x)
+ out = (x - mean) / torch.sqrt(var + self.eps)
+ out = out * scale_broadcast + shift_broadcast
+ return out