diff options
| author | tingchunw <tingchunw@nvidia.com> | 2017-12-04 16:52:46 -0800 |
|---|---|---|
| committer | tingchunw <tingchunw@nvidia.com> | 2017-12-04 16:52:46 -0800 |
| commit | 9054cf9b0c327a5077fd0793abe178f400da3315 (patch) | |
| tree | 3c69c07bdcba86c47d8442648fd69c0434e04136 /models | |
| parent | f9e9999541d67a908a169cc88407675133130e1f (diff) | |
first commit
Diffstat (limited to 'models')
| -rwxr-xr-x | models/__init__.py | 0 | ||||
| -rwxr-xr-x | models/base_model.py | 86 | ||||
| -rwxr-xr-x | models/models.py | 14 | ||||
| -rwxr-xr-x | models/networks.py | 421 | ||||
| -rwxr-xr-x | models/pix2pixHD_model.py | 260 |
5 files changed, 781 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000..e69de29 --- /dev/null +++ b/models/__init__.py diff --git a/models/base_model.py b/models/base_model.py new file mode 100755 index 0000000..d3879d0 --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,86 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os +import torch + +class BaseModel(torch.nn.Module): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda() + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label, save_dir=''): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + if not save_dir: + save_dir = self.save_dir + save_path = os.path.join(save_dir, save_filename) + if not os.path.isfile(save_path): + print('%s not exists yet!' % save_path) + if network_label == 'G': + raise('Generator must exist!') + else: + #network.load_state_dict(torch.load(save_path)) + try: + network.load_state_dict(torch.load(save_path)) + except: + pretrained_dict = torch.load(save_path) + model_dict = network.state_dict() + try: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + network.load_state_dict(pretrained_dict) + print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) + except: + print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) + from sets import Set + not_initialized = Set() + for k, v in pretrained_dict.items(): + if v.size() == model_dict[k].size(): + model_dict[k] = v + + for k, v in model_dict.items(): + if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): + not_initialized.add(k.split('.')[0]) + print(sorted(not_initialized)) + network.load_state_dict(model_dict) + + def update_learning_rate(): + pass diff --git a/models/models.py b/models/models.py new file mode 100755 index 0000000..351483c --- /dev/null +++ b/models/models.py @@ -0,0 +1,14 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import torch + +def create_model(opt): + from .pix2pixHD_model import Pix2PixHDModel + model = Pix2PixHDModel() + model.initialize(opt) + print("model [%s] was created" % (model.name())) + + if opt.isTrain and len(opt.gpu_ids): + model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) + + return model diff --git a/models/networks.py b/models/networks.py new file mode 100755 index 0000000..a673a56 --- /dev/null +++ b/models/networks.py @@ -0,0 +1,421 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.autograd import Variable +import numpy as np +import math +import torch.nn.functional as F +import copy + +############################################################################### +# Functions +############################################################################### +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + +def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1, + n_blocks_local=3, norm='instance', gpu_ids=[]): + norm_layer = get_norm_layer(norm_type=norm) + if netG == 'global': + netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer) + elif netG == 'local': + netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, + n_local_enhancers, n_blocks_local, norm_layer) + elif netG == 'encoder': + netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer) + else: + raise('generator not implemented!') + print(netG) + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + netG.cuda(device_id=gpu_ids[0]) + netG.apply(weights_init) + return netG + +def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]): + norm_layer = get_norm_layer(norm_type=norm) + netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) + print(netD) + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + netD.cuda(device_id=gpu_ids[0]) + netD.apply(weights_init) + return netD + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + +############################################################################## +# Losses +############################################################################## +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + if isinstance(input[0], list): + loss = 0 + for input_i in input: + pred = input_i[-1] + target_tensor = self.get_target_tensor(pred, target_is_real) + loss += self.loss(pred, target_tensor) + return loss + else: + target_tensor = self.get_target_tensor(input[-1], target_is_real) + return self.loss(input[-1], target_tensor) + +class VGGLoss(nn.Module): + def __init__(self, gpu_ids): + super(VGGLoss, self).__init__() + self.vgg = Vgg19().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss + +############################################################################## +# Generator +############################################################################## +class LocalEnhancer(nn.Module): + def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9, + n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'): + super(LocalEnhancer, self).__init__() + self.n_local_enhancers = n_local_enhancers + + ###### global generator model ##### + ngf_global = ngf * (2**n_local_enhancers) + model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model + model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers + self.model = nn.Sequential(*model_global) + + ###### local enhancer layers ##### + for n in range(1, n_local_enhancers+1): + ### downsample + ngf_global = ngf * (2**(n_local_enhancers-n)) + model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), + norm_layer(ngf_global), nn.ReLU(True), + nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf_global * 2), nn.ReLU(True)] + ### residual blocks + model_upsample = [] + for i in range(n_blocks_local): + model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)] + + ### upsample + model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(ngf_global), nn.ReLU(True)] + + ### final convolution + if n == n_local_enhancers: + model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + + setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample)) + setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample)) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def forward(self, input): + ### create input pyramid + input_downsampled = [input] + for i in range(self.n_local_enhancers): + input_downsampled.append(self.downsample(input_downsampled[-1])) + + ### output at coarest level + output_prev = self.model(input_downsampled[-1]) + ### build up one layer at a time + for n_local_enhancers in range(1, self.n_local_enhancers+1): + model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1') + model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2') + input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers] + output_prev = model_upsample(model_downsample(input_i) + output_prev) + return output_prev + +class GlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect'): + assert(n_blocks >= 0) + super(GlobalGenerator, self).__init__() + activation = nn.ReLU(True) + + model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation] + ### downsample + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), activation] + + ### resnet blocks + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)] + + ### upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), activation] + model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + +class Encoder(nn.Module): + def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): + super(Encoder, self).__init__() + self.output_nc = output_nc + + model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), nn.ReLU(True)] + ### downsample + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), nn.ReLU(True)] + + ### upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] + + model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*model) + + def forward(self, input, inst): + outputs = self.model(input) + + # instance-wise average pooling + outputs_mean = outputs.clone() + inst_list = np.unique(inst.cpu().numpy().astype(int)) + for i in inst_list: + indices = (inst == i).nonzero() # n x 4 + for j in range(self.output_nc): + output_ins = outputs[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] + mean_feat = torch.mean(output_ins).expand_as(output_ins) + outputs_mean[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat + return outputs_mean + +class MultiscaleDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, + use_sigmoid=False, num_D=3, getIntermFeat=False): + super(MultiscaleDiscriminator, self).__init__() + self.num_D = num_D + self.n_layers = n_layers + self.getIntermFeat = getIntermFeat + + for i in range(num_D): + netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) + if getIntermFeat: + for j in range(n_layers+2): + setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) + else: + setattr(self, 'layer'+str(i), netD.model) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def singleD_forward(self, model, input): + if self.getIntermFeat: + result = [input] + for i in range(len(model)): + result.append(model[i](result[-1])) + return result[1:] + else: + return [model(input)] + + def forward(self, input): + num_D = self.num_D + result = [] + input_downsampled = input + for i in range(num_D): + if self.getIntermFeat: + model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] + else: + model = getattr(self, 'layer'+str(num_D-1-i)) + result.append(self.singleD_forward(model, input_downsampled)) + if i != (num_D-1): + input_downsampled = self.downsample(input_downsampled) + return result + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[1:] + else: + return self.model(input) + +from torchvision import models +class Vgg19(torch.nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py new file mode 100755 index 0000000..ba44e53 --- /dev/null +++ b/models/pix2pixHD_model.py @@ -0,0 +1,260 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import numpy as np +import torch +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks + +class Pix2PixHDModel(BaseModel): + def name(self): + return 'Pix2PixHDModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + if opt.resize_or_crop != 'none': # when training at full res this causes OOM + torch.backends.cudnn.benchmark = True + self.isTrain = opt.isTrain + self.use_features = opt.instance_feat or opt.label_feat + self.gen_features = self.use_features and not self.opt.load_features + + ##### define networks + # Generator network + netG_input_nc = opt.label_nc + if not opt.no_instance: + netG_input_nc += 1 + if self.use_features: + netG_input_nc += opt.feat_num + self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, + opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, + opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids) + + # Discriminator network + if self.isTrain: + use_sigmoid = opt.no_lsgan + netD_input_nc = opt.label_nc + opt.output_nc + if not opt.no_instance: + netD_input_nc += 1 + self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, + opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) + + ### Encoder network + if self.gen_features: + self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', + opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) + + print('---------- Networks initialized -------------') + + # load networks + if not self.isTrain or opt.continue_train or opt.load_pretrain: + pretrained_path = '' if not self.isTrain else opt.load_pretrain + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + if self.isTrain: + self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) + if self.gen_features: + self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) + + # set loss functions and optimizers + if self.isTrain: + if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: + raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") + self.fake_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + + # define loss functions + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionFeat = torch.nn.L1Loss() + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss(self.gpu_ids) + + # Names so we can breakout loss + self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'] + + # initialize optimizers + # optimizer G + if opt.niter_fix_global > 0: + print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) + params_dict = dict(self.netG.named_parameters()) + params = [] + for key, value in params_dict.items(): + if key.startswith('model' + str(opt.n_local_enhancers)): + params += [{'params':[value],'lr':opt.lr}] + else: + params += [{'params':[value],'lr':0.0}] + else: + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + # optimizer D + params = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + + # get edges from instance map + if not self.opt.no_instance: + inst_map = inst_map.data.cuda() + edge_map = self.get_edges(inst_map) + input_label = torch.cat((input_label, edge_map), dim=1) + input_label = Variable(input_label, volatile=infer) + + # real images for training + if real_image is not None: + real_image = Variable(real_image.data.cuda()) + + # instance map for feature encoding + if self.use_features: + # get precomputed feature maps + if self.opt.load_features: + feat_map = Variable(feat_map.data.cuda()) + + return input_label, inst_map, real_image, feat_map + + def discriminate(self, input_label, test_image, use_pool=False): + input_concat = torch.cat((input_label, test_image.detach()), dim=1) + if use_pool: + fake_query = self.fake_pool.query(input_concat) + return self.netD.forward(fake_query) + else: + return self.netD.forward(input_concat) + + def forward(self, label, inst, image, feat, infer=False): + # Encode Inputs + input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) + + # Fake Generation + if self.use_features: + if not self.opt.load_features: + feat_map = self.netE.forward(real_image, inst_map) + input_concat = torch.cat((input_label, feat_map), dim=1) + else: + input_concat = input_label + fake_image = self.netG.forward(input_concat) + + # Fake Detection and Loss + pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + pred_real = self.discriminate(input_label, real_image) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) + loss_G_GAN = self.criterionGAN(pred_fake, True) + + # GAN feature matching loss + loss_G_GAN_Feat = 0 + if not self.opt.no_ganFeat_loss: + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(self.opt.num_D): + for j in range(len(pred_fake[i])-1): + loss_G_GAN_Feat += D_weights * feat_weights * \ + self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat + + # VGG feature matching loss + loss_G_VGG = 0 + if not self.opt.no_vgg_loss: + loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat + + # Only return the fake_B image if necessary to save BW + return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ], None if not infer else fake_image ] + + def inference(self, label, inst): + # Encode Inputs + input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True) + + # Fake Generation + if self.use_features: + # sample clusters from precomputed features + feat_map = self.sample_features(inst_map) + input_concat = torch.cat((input_label, feat_map), dim=1) + else: + input_concat = input_label + fake_image = self.netG.forward(input_concat) + return fake_image + + def sample_features(self, inst): + # read precomputed feature clusters + cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path) + features_clustered = np.load(cluster_path).item() + + # randomly sample from the feature clusters + inst_np = inst.cpu().numpy().astype(int) + feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3]) + for i in np.unique(inst_np): + label = i if i < 1000 else i//1000 + if label in features_clustered: + feat = features_clustered[label] + cluster_idx = np.random.randint(0, feat.shape[0]) + + idx = (inst == i).nonzero() + for k in range(self.opt.feat_num): + feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + return feat_map + + def encode_features(self, image, inst): + image = Variable(image.cuda(), volatile=True) + feat_num = self.opt.feat_num + h, w = inst.size()[2], inst.size()[3] + block_num = 32 + feat_map = self.netE.forward(image, inst.cuda()) + inst_np = inst.cpu().numpy().astype(int) + feature = {} + for i in range(self.opt.label_nc): + feature[i] = np.zeros((0, feat_num+1)) + for i in np.unique(inst_np): + label = i if i < 1000 else i//1000 + idx = (inst == i).nonzero() + num = idx.size()[0] + idx = idx[num//2,:] + val = np.zeros((1, feat_num+1)) + for k in range(feat_num): + val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0] + val[0, feat_num] = float(num) / (h * w // block_num) + feature[label] = np.append(feature[label], val, axis=0) + return feature + + def get_edges(self, t): + edge = torch.cuda.ByteTensor(t.size()).zero_() + edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1]) + edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) + edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) + edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) + return edge.float() + + def save(self, which_epoch): + self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) + self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) + if self.gen_features: + self.save_network(self.netE, 'E', which_epoch, self.gpu_ids) + + def update_fixed_params(self): + # after fixing the global generator for a number of iterations, also start finetuning it + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) + print('------------ Now also finetuning global generator -----------') + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr |
