diff options
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 288 |
1 file changed, 288 insertions, 0 deletions
import torch
import torch.nn as nn
from torch.autograd import Variable
from pdb import set_trace as st

###############################################################################
# Functions
###############################################################################


def weights_init(m):
    """DCGAN-style weight initialization, applied via ``net.apply(weights_init)``.

    Conv/ConvTranspose weights ~ N(0, 0.02); BatchNorm/InstanceNorm scale
    ~ N(1, 0.02) with zero shift.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
    """Build, optionally move to GPU, and weight-initialize a generator.

    Args:
        input_nc / output_nc: channel counts of input / output images.
        ngf: base number of generator filters.
        which_model_netG: 'resnet_9blocks' | 'resnet_6blocks' | 'unet'.
        norm: 'batch' | 'instance'.
        gpu_ids: list of GPU ids; non-empty enables CUDA + data_parallel.
    """
    netG = None
    use_gpu = len(gpu_ids) > 0
    if norm == 'batch':
        norm_layer = nn.BatchNorm2d
    elif norm == 'instance':
        norm_layer = InstanceNormalization
    else:
        # NOTE(review): norm_layer stays unbound here, so an unrecognized
        # norm later fails with NameError; raising ValueError would be clearer.
        print('normalization layer [%s] is not found' % norm)

    # Bug fix: was `assert(torch.cuda.is_available() == use_gpu)`, which also
    # rejected CPU runs on CUDA-capable machines.  CUDA is only *required*
    # when GPU ids were actually requested.
    if use_gpu:
        assert torch.cuda.is_available()
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet':
        netG = UnetGenerator(input_nc, output_nc, ngf, norm_layer, gpu_ids=gpu_ids)
    else:
        # NOTE(review): netG stays None here and the lines below raise
        # AttributeError; consider raising a ValueError instead.
        print('Generator model name [%s] is not recognized' % which_model_netG)
    if use_gpu:
        netG.cuda()
    netG.apply(weights_init)
    return netG


def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
    """Build, optionally move to GPU, and weight-initialize a discriminator.

    'basic' is shorthand for the 3-layer PatchGAN ('n_layers' with the default
    n_layers_D=3).  use_sigmoid=True appends a final Sigmoid (vanilla GAN
    loss); leave False for LSGAN, which needs raw scores.
    """
    netD = None
    use_gpu = len(gpu_ids) > 0
    # Bug fix: same over-strict CUDA assertion as in define_G (see there).
    if use_gpu:
        assert torch.cuda.is_available()
    if which_model_netD == 'basic':
        # Recursive call; the inner call already runs .cuda()/weights_init,
        # so the tail below executes twice for 'basic' (harmless: it just
        # re-samples the random init).
        netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, gpu_ids=gpu_ids)
    else:
        # NOTE(review): netD stays None here and the lines below raise
        # AttributeError; consider raising a ValueError instead.
        print('Discriminator model name [%s] is not recognized' %
              which_model_netD)
    if use_gpu:
        netD.cuda()
    netD.apply(weights_init)
    return netD


def print_network(net):
    """Print a network's structure and its total trainable parameter count."""
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Classes
##############################################################################


# Defines the GAN loss used in LSGAN.
# It is basically same as MSELoss, but it abstracts away the need to create
# the target label tensor that has the same size as the input
class GANLoss(nn.Module):
    """GAN objective that lazily builds (and caches) the target label tensor.

    use_lsgan=True -> MSELoss (LSGAN); False -> BCELoss (vanilla GAN, expects
    the discriminator to end in a Sigmoid).
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached label tensors; rebuilt only when the prediction size changes.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a label tensor of `input`'s size filled with the right label."""
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """7x7 stem -> 2 stride-2 downsamples -> n_blocks residual blocks ->
    2 upsamples -> 7x7 head + Tanh.  Output is spatially the same size as
    the input."""

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

        model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [Resnet_block(ngf * mult, 'zero', norm_layer=norm_layer)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        # gpu_ids checked first so CPU-only builds never touch torch.cuda
        # types; the boolean result is unchanged.
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Define a resnet block
class Resnet_block(nn.Module):
    """Residual block: two 3x3 convs (each followed by norm, first by ReLU)
    added to the input; spatial size and channel count are preserved."""

    def __init__(self, dim, padding_type, norm_layer):
        super(Resnet_block, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)

    def build_conv_block(self, dim, padding_type, norm_layer):
        """Build the conv->norm->ReLU->conv->norm sub-network."""
        conv_block = []
        p = 0
        # TODO: support padding types (only zero-padding implemented)
        assert(padding_type == 'zero')
        p = 1

        # TODO: InstanceNorm
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       nn.ReLU(True)]
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


# Defines the Unet geneator.
class UnetGenerator(nn.Module):
    """U-Net generator (stub).

    NOTE(review): __init__ never assigns self.model, so forward() raises
    AttributeError -- this class is an unfinished placeholder in this
    revision; the actual U-Net architecture still needs to be built here.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
        super(UnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of strided conv blocks mapping an
    image to a 1-channel map of real/fake scores (one score per patch)."""

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=2),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)  # cap channel growth at 8 * ndf
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=2),
                # TODO: use InstanceNorm
                nn.BatchNorm2d(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            # NOTE(review): kernel_size=1 with stride=2, padding=2 is unusual
            # (padding exceeds the kernel); the canonical PatchGAN uses
            # kernel_size=kw, stride=1 here -- confirm before changing, as a
            # change alters the output map size.
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=1, stride=2, padding=2),
            # TODO: use InstanceNorm
            nn.BatchNorm2d(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)]
        # Bug fix: nn.Sigmoid() was appended unconditionally, which made the
        # use_sigmoid flag a no-op and broke LSGAN (it needs raw scores).
        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        # gpu_ids checked first so CPU-only builds never touch torch.cuda
        # types; the boolean result is unchanged.
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Instance Normalization layer from
# https://github.com/darkstar112358/fast-neural-style

class InstanceNormalization(torch.nn.Module):
    """InstanceNormalization
    Improves convergence of neural-style.
    ref: https://arxiv.org/pdf/1607.08022.pdf

    Normalizes each (sample, channel) plane by its own mean/variance, then
    applies a learned per-channel affine (weight, bias).
    """

    def __init__(self, dim, eps=1e-5):
        super(InstanceNormalization, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(dim))
        self.bias = nn.Parameter(torch.FloatTensor(dim))
        self.eps = eps
        self._reset_parameters()

    def _reset_parameters(self):
        # Random positive scale, zero shift (overwritten if weights_init is
        # applied afterwards, which matches on 'InstanceNorm').
        self.weight.data.uniform_()
        self.bias.data.zero_()

    def forward(self, x):
        n = x.size(2) * x.size(3)
        t = x.view(x.size(0), x.size(1), n)
        # Bug fix: the original used mean(t, 2).unsqueeze(2).expand_as(x),
        # which only yields an (N, C, 1, 1) stat tensor on very old torch
        # versions where reductions kept the reduced dim; view(N, C, 1, 1)
        # is correct regardless of keepdim behavior.
        mean = torch.mean(t, 2).view(x.size(0), x.size(1), 1, 1).expand_as(x)
        # Calculate the biased var. torch.var returns unbiased var
        var = torch.var(t, 2).view(x.size(0), x.size(1), 1, 1).expand_as(x) * ((n - 1) / float(n))
        scale_broadcast = self.weight.view(1, -1, 1, 1).expand_as(x)
        shift_broadcast = self.bias.view(1, -1, 1, 1).expand_as(x)
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = out * scale_broadcast + shift_broadcast
        return out
