summaryrefslogtreecommitdiff
path: root/models/pix2pixHD_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/pix2pixHD_model.py')
-rwxr-xr-xmodels/pix2pixHD_model.py260
1 files changed, 260 insertions, 0 deletions
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
new file mode 100755
index 0000000..ba44e53
--- /dev/null
+++ b/models/pix2pixHD_model.py
@@ -0,0 +1,260 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from torch.autograd import Variable
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+
+class Pix2PixHDModel(BaseModel):
+ def name(self):
+ return 'Pix2PixHDModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+ if opt.resize_or_crop != 'none': # when training at full res this causes OOM
+ torch.backends.cudnn.benchmark = True
+ self.isTrain = opt.isTrain
+ self.use_features = opt.instance_feat or opt.label_feat
+ self.gen_features = self.use_features and not self.opt.load_features
+
+ ##### define networks
+ # Generator network
+ netG_input_nc = opt.label_nc
+ if not opt.no_instance:
+ netG_input_nc += 1
+ if self.use_features:
+ netG_input_nc += opt.feat_num
+ self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
+ opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
+ opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
+
+ # Discriminator network
+ if self.isTrain:
+ use_sigmoid = opt.no_lsgan
+ netD_input_nc = opt.label_nc + opt.output_nc
+ if not opt.no_instance:
+ netD_input_nc += 1
+ self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
+ opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
+
+ ### Encoder network
+ if self.gen_features:
+ self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
+ opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
+
+ print('---------- Networks initialized -------------')
+
+ # load networks
+ if not self.isTrain or opt.continue_train or opt.load_pretrain:
+ pretrained_path = '' if not self.isTrain else opt.load_pretrain
+ self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
+ if self.isTrain:
+ self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
+ if self.gen_features:
+ self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
+
+ # set loss functions and optimizers
+ if self.isTrain:
+ if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
+ raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
+ self.fake_pool = ImagePool(opt.pool_size)
+ self.old_lr = opt.lr
+
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+ self.criterionFeat = torch.nn.L1Loss()
+ if not opt.no_vgg_loss:
+ self.criterionVGG = networks.VGGLoss(self.gpu_ids)
+
+ # Names so we can breakout loss
+ self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake']
+
+ # initialize optimizers
+ # optimizer G
+ if opt.niter_fix_global > 0:
+ print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
+ params_dict = dict(self.netG.named_parameters())
+ params = []
+ for key, value in params_dict.items():
+ if key.startswith('model' + str(opt.n_local_enhancers)):
+ params += [{'params':[value],'lr':opt.lr}]
+ else:
+ params += [{'params':[value],'lr':0.0}]
+ else:
+ params = list(self.netG.parameters())
+ if self.gen_features:
+ params += list(self.netE.parameters())
+ self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ # optimizer D
+ params = list(self.netD.parameters())
+ self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
+ # create one-hot vector for label map
+ size = label_map.size()
+ oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
+ input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
+ input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
+
+ # get edges from instance map
+ if not self.opt.no_instance:
+ inst_map = inst_map.data.cuda()
+ edge_map = self.get_edges(inst_map)
+ input_label = torch.cat((input_label, edge_map), dim=1)
+ input_label = Variable(input_label, volatile=infer)
+
+ # real images for training
+ if real_image is not None:
+ real_image = Variable(real_image.data.cuda())
+
+ # instance map for feature encoding
+ if self.use_features:
+ # get precomputed feature maps
+ if self.opt.load_features:
+ feat_map = Variable(feat_map.data.cuda())
+
+ return input_label, inst_map, real_image, feat_map
+
+ def discriminate(self, input_label, test_image, use_pool=False):
+ input_concat = torch.cat((input_label, test_image.detach()), dim=1)
+ if use_pool:
+ fake_query = self.fake_pool.query(input_concat)
+ return self.netD.forward(fake_query)
+ else:
+ return self.netD.forward(input_concat)
+
+ def forward(self, label, inst, image, feat, infer=False):
+ # Encode Inputs
+ input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
+
+ # Fake Generation
+ if self.use_features:
+ if not self.opt.load_features:
+ feat_map = self.netE.forward(real_image, inst_map)
+ input_concat = torch.cat((input_label, feat_map), dim=1)
+ else:
+ input_concat = input_label
+ fake_image = self.netG.forward(input_concat)
+
+ # Fake Detection and Loss
+ pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
+ loss_D_fake = self.criterionGAN(pred_fake_pool, False)
+
+ # Real Detection and Loss
+ pred_real = self.discriminate(input_label, real_image)
+ loss_D_real = self.criterionGAN(pred_real, True)
+
+ # GAN loss (Fake Passability Loss)
+ pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
+ loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+ # GAN feature matching loss
+ loss_G_GAN_Feat = 0
+ if not self.opt.no_ganFeat_loss:
+ feat_weights = 4.0 / (self.opt.n_layers_D + 1)
+ D_weights = 1.0 / self.opt.num_D
+ for i in range(self.opt.num_D):
+ for j in range(len(pred_fake[i])-1):
+ loss_G_GAN_Feat += D_weights * feat_weights * \
+ self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
+
+ # VGG feature matching loss
+ loss_G_VGG = 0
+ if not self.opt.no_vgg_loss:
+ loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
+
+ # Only return the fake_B image if necessary to save BW
+ return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ], None if not infer else fake_image ]
+
+ def inference(self, label, inst):
+ # Encode Inputs
+ input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
+
+ # Fake Generation
+ if self.use_features:
+ # sample clusters from precomputed features
+ feat_map = self.sample_features(inst_map)
+ input_concat = torch.cat((input_label, feat_map), dim=1)
+ else:
+ input_concat = input_label
+ fake_image = self.netG.forward(input_concat)
+ return fake_image
+
+ def sample_features(self, inst):
+ # read precomputed feature clusters
+ cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
+ features_clustered = np.load(cluster_path).item()
+
+ # randomly sample from the feature clusters
+ inst_np = inst.cpu().numpy().astype(int)
+ feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3])
+ for i in np.unique(inst_np):
+ label = i if i < 1000 else i//1000
+ if label in features_clustered:
+ feat = features_clustered[label]
+ cluster_idx = np.random.randint(0, feat.shape[0])
+
+ idx = (inst == i).nonzero()
+ for k in range(self.opt.feat_num):
+ feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+ return feat_map
+
+ def encode_features(self, image, inst):
+ image = Variable(image.cuda(), volatile=True)
+ feat_num = self.opt.feat_num
+ h, w = inst.size()[2], inst.size()[3]
+ block_num = 32
+ feat_map = self.netE.forward(image, inst.cuda())
+ inst_np = inst.cpu().numpy().astype(int)
+ feature = {}
+ for i in range(self.opt.label_nc):
+ feature[i] = np.zeros((0, feat_num+1))
+ for i in np.unique(inst_np):
+ label = i if i < 1000 else i//1000
+ idx = (inst == i).nonzero()
+ num = idx.size()[0]
+ idx = idx[num//2,:]
+ val = np.zeros((1, feat_num+1))
+ for k in range(feat_num):
+ val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
+ val[0, feat_num] = float(num) / (h * w // block_num)
+ feature[label] = np.append(feature[label], val, axis=0)
+ return feature
+
+ def get_edges(self, t):
+ edge = torch.cuda.ByteTensor(t.size()).zero_()
+ edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ return edge.float()
+
+ def save(self, which_epoch):
+ self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
+ self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
+ if self.gen_features:
+ self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
+
+ def update_fixed_params(self):
+ # after fixing the global generator for a number of iterations, also start finetuning it
+ params = list(self.netG.parameters())
+ if self.gen_features:
+ params += list(self.netE.parameters())
+ self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+ print('------------ Now also finetuning global generator -----------')
+
+ def update_learning_rate(self):
+ lrd = self.opt.lr / self.opt.niter_decay
+ lr = self.old_lr - lrd
+ for param_group in self.optimizer_D.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_G.param_groups:
+ param_group['lr'] = lr
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ self.old_lr = lr