summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rwxr-xr-xmodels/__init__.py0
-rwxr-xr-xmodels/base_model.py86
-rwxr-xr-xmodels/models.py14
-rwxr-xr-xmodels/networks.py421
-rwxr-xr-xmodels/pix2pixHD_model.py260
5 files changed, 781 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100755
index 0000000..d3879d0
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,86 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os
+import torch
+
+class BaseModel(torch.nn.Module):
+ def name(self):
+ return 'BaseModel'
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+
+ def set_input(self, input):
+ self.input = input
+
+ def forward(self):
+ pass
+
+ # used in test time, no backprop
+ def test(self):
+ pass
+
+ def get_image_paths(self):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ return self.input
+
+ def get_current_errors(self):
+ return {}
+
+ def save(self, label):
+ pass
+
+ # helper saving function that can be used by subclasses
+ def save_network(self, network, network_label, epoch_label, gpu_ids):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ torch.save(network.cpu().state_dict(), save_path)
+ if len(gpu_ids) and torch.cuda.is_available():
+ network.cuda()
+
+ # helper loading function that can be used by subclasses
+ def load_network(self, network, network_label, epoch_label, save_dir=''):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ if not save_dir:
+ save_dir = self.save_dir
+ save_path = os.path.join(save_dir, save_filename)
+ if not os.path.isfile(save_path):
+ print('%s not exists yet!' % save_path)
+ if network_label == 'G':
+ raise('Generator must exist!')
+ else:
+ #network.load_state_dict(torch.load(save_path))
+ try:
+ network.load_state_dict(torch.load(save_path))
+ except:
+ pretrained_dict = torch.load(save_path)
+ model_dict = network.state_dict()
+ try:
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+ network.load_state_dict(pretrained_dict)
+ print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
+ except:
+ print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
+ from sets import Set
+ not_initialized = Set()
+ for k, v in pretrained_dict.items():
+ if v.size() == model_dict[k].size():
+ model_dict[k] = v
+
+ for k, v in model_dict.items():
+ if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
+ not_initialized.add(k.split('.')[0])
+ print(sorted(not_initialized))
+ network.load_state_dict(model_dict)
+
+ def update_learning_rate():
+ pass
diff --git a/models/models.py b/models/models.py
new file mode 100755
index 0000000..351483c
--- /dev/null
+++ b/models/models.py
@@ -0,0 +1,14 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch
+
+def create_model(opt):
+ from .pix2pixHD_model import Pix2PixHDModel
+ model = Pix2PixHDModel()
+ model.initialize(opt)
+ print("model [%s] was created" % (model.name()))
+
+ if opt.isTrain and len(opt.gpu_ids):
+ model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
+
+ return model
diff --git a/models/networks.py b/models/networks.py
new file mode 100755
index 0000000..a673a56
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,421 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch
+import torch.nn as nn
+from torch.nn import init
+import functools
+from torch.autograd import Variable
+import numpy as np
+import math
+import torch.nn.functional as F
+import copy
+
+###############################################################################
+# Functions
+###############################################################################
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ m.weight.data.normal_(0.0, 0.02)
+ elif classname.find('BatchNorm2d') != -1:
+ m.weight.data.normal_(1.0, 0.02)
+ m.bias.data.fill_(0)
+
+def get_norm_layer(norm_type='instance'):
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+ return norm_layer
+
+def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
+ n_blocks_local=3, norm='instance', gpu_ids=[]):
+ norm_layer = get_norm_layer(norm_type=norm)
+ if netG == 'global':
+ netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
+ elif netG == 'local':
+ netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
+ n_local_enhancers, n_blocks_local, norm_layer)
+ elif netG == 'encoder':
+ netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
+ else:
+ raise('generator not implemented!')
+ print(netG)
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ netG.cuda(device_id=gpu_ids[0])
+ netG.apply(weights_init)
+ return netG
+
+def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
+ norm_layer = get_norm_layer(norm_type=norm)
+ netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
+ print(netD)
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ netD.cuda(device_id=gpu_ids[0])
+ netD.apply(weights_init)
+ return netD
+
+def print_network(net):
+ if isinstance(net, list):
+ net = net[0]
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ print(net)
+ print('Total number of parameters: %d' % num_params)
+
+##############################################################################
+# Losses
+##############################################################################
+class GANLoss(nn.Module):
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
+ tensor=torch.FloatTensor):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_var = None
+ self.fake_label_var = None
+ self.Tensor = tensor
+ if use_lsgan:
+ self.loss = nn.MSELoss()
+ else:
+ self.loss = nn.BCELoss()
+
+ def get_target_tensor(self, input, target_is_real):
+ target_tensor = None
+ if target_is_real:
+ create_label = ((self.real_label_var is None) or
+ (self.real_label_var.numel() != input.numel()))
+ if create_label:
+ real_tensor = self.Tensor(input.size()).fill_(self.real_label)
+ self.real_label_var = Variable(real_tensor, requires_grad=False)
+ target_tensor = self.real_label_var
+ else:
+ create_label = ((self.fake_label_var is None) or
+ (self.fake_label_var.numel() != input.numel()))
+ if create_label:
+ fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
+ self.fake_label_var = Variable(fake_tensor, requires_grad=False)
+ target_tensor = self.fake_label_var
+ return target_tensor
+
+ def __call__(self, input, target_is_real):
+ if isinstance(input[0], list):
+ loss = 0
+ for input_i in input:
+ pred = input_i[-1]
+ target_tensor = self.get_target_tensor(pred, target_is_real)
+ loss += self.loss(pred, target_tensor)
+ return loss
+ else:
+ target_tensor = self.get_target_tensor(input[-1], target_is_real)
+ return self.loss(input[-1], target_tensor)
+
+class VGGLoss(nn.Module):
+ def __init__(self, gpu_ids):
+ super(VGGLoss, self).__init__()
+ self.vgg = Vgg19().cuda()
+ self.criterion = nn.L1Loss()
+ self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
+
+ def forward(self, x, y):
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+ loss = 0
+ for i in range(len(x_vgg)):
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
+ return loss
+
+##############################################################################
+# Generator
+##############################################################################
+class LocalEnhancer(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
+ n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
+ super(LocalEnhancer, self).__init__()
+ self.n_local_enhancers = n_local_enhancers
+
+ ###### global generator model #####
+ ngf_global = ngf * (2**n_local_enhancers)
+ model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
+ model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
+ self.model = nn.Sequential(*model_global)
+
+ ###### local enhancer layers #####
+ for n in range(1, n_local_enhancers+1):
+ ### downsample
+ ngf_global = ngf * (2**(n_local_enhancers-n))
+ model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
+ norm_layer(ngf_global), nn.ReLU(True),
+ nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
+ norm_layer(ngf_global * 2), nn.ReLU(True)]
+ ### residual blocks
+ model_upsample = []
+ for i in range(n_blocks_local):
+ model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]
+
+ ### upsample
+ model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
+ norm_layer(ngf_global), nn.ReLU(True)]
+
+ ### final convolution
+ if n == n_local_enhancers:
+ model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
+
+ setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
+ setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
+
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
+
+ def forward(self, input):
+ ### create input pyramid
+ input_downsampled = [input]
+ for i in range(self.n_local_enhancers):
+ input_downsampled.append(self.downsample(input_downsampled[-1]))
+
+ ### output at coarest level
+ output_prev = self.model(input_downsampled[-1])
+ ### build up one layer at a time
+ for n_local_enhancers in range(1, self.n_local_enhancers+1):
+ model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
+ model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
+ input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
+ output_prev = model_upsample(model_downsample(input_i) + output_prev)
+ return output_prev
+
+class GlobalGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
+ padding_type='reflect'):
+ assert(n_blocks >= 0)
+ super(GlobalGenerator, self).__init__()
+ activation = nn.ReLU(True)
+
+ model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
+ ### downsample
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
+ norm_layer(ngf * mult * 2), activation]
+
+ ### resnet blocks
+ mult = 2**n_downsampling
+ for i in range(n_blocks):
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
+
+ ### upsample
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
+ norm_layer(int(ngf * mult / 2)), activation]
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ return self.model(input)
+
+# Define a resnet block
+class ResnetBlock(nn.Module):
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
+ super(ResnetBlock, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
+
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
+ conv_block = []
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim),
+ activation]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
+
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ out = x + self.conv_block(x)
+ return out
+
+class Encoder(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
+ super(Encoder, self).__init__()
+ self.output_nc = output_nc
+
+ model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
+ norm_layer(ngf), nn.ReLU(True)]
+ ### downsample
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
+ norm_layer(ngf * mult * 2), nn.ReLU(True)]
+
+ ### upsample
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
+ norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
+
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input, inst):
+ outputs = self.model(input)
+
+ # instance-wise average pooling
+ outputs_mean = outputs.clone()
+ inst_list = np.unique(inst.cpu().numpy().astype(int))
+ for i in inst_list:
+ indices = (inst == i).nonzero() # n x 4
+ for j in range(self.output_nc):
+ output_ins = outputs[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]]
+ mean_feat = torch.mean(output_ins).expand_as(output_ins)
+ outputs_mean[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat
+ return outputs_mean
+
+class MultiscaleDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
+ use_sigmoid=False, num_D=3, getIntermFeat=False):
+ super(MultiscaleDiscriminator, self).__init__()
+ self.num_D = num_D
+ self.n_layers = n_layers
+ self.getIntermFeat = getIntermFeat
+
+ for i in range(num_D):
+ netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
+ if getIntermFeat:
+ for j in range(n_layers+2):
+ setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
+ else:
+ setattr(self, 'layer'+str(i), netD.model)
+
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
+
+ def singleD_forward(self, model, input):
+ if self.getIntermFeat:
+ result = [input]
+ for i in range(len(model)):
+ result.append(model[i](result[-1]))
+ return result[1:]
+ else:
+ return [model(input)]
+
+ def forward(self, input):
+ num_D = self.num_D
+ result = []
+ input_downsampled = input
+ for i in range(num_D):
+ if self.getIntermFeat:
+ model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
+ else:
+ model = getattr(self, 'layer'+str(num_D-1-i))
+ result.append(self.singleD_forward(model, input_downsampled))
+ if i != (num_D-1):
+ input_downsampled = self.downsample(input_downsampled)
+ return result
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
+ super(NLayerDiscriminator, self).__init__()
+ self.getIntermFeat = getIntermFeat
+ self.n_layers = n_layers
+
+ kw = 4
+ padw = int(np.ceil((kw-1.0)/2))
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
+
+ nf = ndf
+ for n in range(1, n_layers):
+ nf_prev = nf
+ nf = min(nf * 2, 512)
+ sequence += [[
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
+ ]]
+
+ nf_prev = nf
+ nf = min(nf * 2, 512)
+ sequence += [[
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+ norm_layer(nf),
+ nn.LeakyReLU(0.2, True)
+ ]]
+
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+ if use_sigmoid:
+ sequence += [nn.Sigmoid()]
+
+ if getIntermFeat:
+ for n in range(len(sequence)):
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
+ else:
+ sequence_stream = []
+ for n in range(len(sequence)):
+ sequence_stream += sequence[n]
+ self.model = nn.Sequential(*sequence_stream)
+
+ def forward(self, input):
+ if self.getIntermFeat:
+ res = [input]
+ for n in range(self.n_layers+2):
+ model = getattr(self, 'model'+str(n))
+ res.append(model(res[-1]))
+ return res[1:]
+ else:
+ return self.model(input)
+
+from torchvision import models
+class Vgg19(torch.nn.Module):
+ def __init__(self, requires_grad=False):
+ super(Vgg19, self).__init__()
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ for x in range(2):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(2, 7):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(7, 12):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(12, 21):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(21, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h_relu1 = self.slice1(X)
+ h_relu2 = self.slice2(h_relu1)
+ h_relu3 = self.slice3(h_relu2)
+ h_relu4 = self.slice4(h_relu3)
+ h_relu5 = self.slice5(h_relu4)
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+ return out
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
new file mode 100755
index 0000000..ba44e53
--- /dev/null
+++ b/models/pix2pixHD_model.py
@@ -0,0 +1,260 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from torch.autograd import Variable
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+
+class Pix2PixHDModel(BaseModel):
+ def name(self):
+ return 'Pix2PixHDModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+ if opt.resize_or_crop != 'none': # when training at full res this causes OOM
+ torch.backends.cudnn.benchmark = True
+ self.isTrain = opt.isTrain
+ self.use_features = opt.instance_feat or opt.label_feat
+ self.gen_features = self.use_features and not self.opt.load_features
+
+ ##### define networks
+ # Generator network
+ netG_input_nc = opt.label_nc
+ if not opt.no_instance:
+ netG_input_nc += 1
+ if self.use_features:
+ netG_input_nc += opt.feat_num
+ self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
+ opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
+ opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
+
+ # Discriminator network
+ if self.isTrain:
+ use_sigmoid = opt.no_lsgan
+ netD_input_nc = opt.label_nc + opt.output_nc
+ if not opt.no_instance:
+ netD_input_nc += 1
+ self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
+ opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
+
+ ### Encoder network
+ if self.gen_features:
+ self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
+ opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
+
+ print('---------- Networks initialized -------------')
+
+ # load networks
+ if not self.isTrain or opt.continue_train or opt.load_pretrain:
+ pretrained_path = '' if not self.isTrain else opt.load_pretrain
+ self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
+ if self.isTrain:
+ self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
+ if self.gen_features:
+ self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
+
+ # set loss functions and optimizers
+ if self.isTrain:
+ if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
+ raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
+ self.fake_pool = ImagePool(opt.pool_size)
+ self.old_lr = opt.lr
+
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+ self.criterionFeat = torch.nn.L1Loss()
+ if not opt.no_vgg_loss:
+ self.criterionVGG = networks.VGGLoss(self.gpu_ids)
+
+ # Names so we can breakout loss
+ self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake']
+
+ # initialize optimizers
+ # optimizer G
+ if opt.niter_fix_global > 0:
+ print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
+ params_dict = dict(self.netG.named_parameters())
+ params = []
+ for key, value in params_dict.items():
+ if key.startswith('model' + str(opt.n_local_enhancers)):
+ params += [{'params':[value],'lr':opt.lr}]
+ else:
+ params += [{'params':[value],'lr':0.0}]
+ else:
+ params = list(self.netG.parameters())
+ if self.gen_features:
+ params += list(self.netE.parameters())
+ self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ # optimizer D
+ params = list(self.netD.parameters())
+ self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
+ # create one-hot vector for label map
+ size = label_map.size()
+ oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
+ input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
+ input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
+
+ # get edges from instance map
+ if not self.opt.no_instance:
+ inst_map = inst_map.data.cuda()
+ edge_map = self.get_edges(inst_map)
+ input_label = torch.cat((input_label, edge_map), dim=1)
+ input_label = Variable(input_label, volatile=infer)
+
+ # real images for training
+ if real_image is not None:
+ real_image = Variable(real_image.data.cuda())
+
+ # instance map for feature encoding
+ if self.use_features:
+ # get precomputed feature maps
+ if self.opt.load_features:
+ feat_map = Variable(feat_map.data.cuda())
+
+ return input_label, inst_map, real_image, feat_map
+
+ def discriminate(self, input_label, test_image, use_pool=False):
+ input_concat = torch.cat((input_label, test_image.detach()), dim=1)
+ if use_pool:
+ fake_query = self.fake_pool.query(input_concat)
+ return self.netD.forward(fake_query)
+ else:
+ return self.netD.forward(input_concat)
+
+ def forward(self, label, inst, image, feat, infer=False):
+ # Encode Inputs
+ input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
+
+ # Fake Generation
+ if self.use_features:
+ if not self.opt.load_features:
+ feat_map = self.netE.forward(real_image, inst_map)
+ input_concat = torch.cat((input_label, feat_map), dim=1)
+ else:
+ input_concat = input_label
+ fake_image = self.netG.forward(input_concat)
+
+ # Fake Detection and Loss
+ pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
+ loss_D_fake = self.criterionGAN(pred_fake_pool, False)
+
+ # Real Detection and Loss
+ pred_real = self.discriminate(input_label, real_image)
+ loss_D_real = self.criterionGAN(pred_real, True)
+
+ # GAN loss (Fake Passability Loss)
+ pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
+ loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+ # GAN feature matching loss
+ loss_G_GAN_Feat = 0
+ if not self.opt.no_ganFeat_loss:
+ feat_weights = 4.0 / (self.opt.n_layers_D + 1)
+ D_weights = 1.0 / self.opt.num_D
+ for i in range(self.opt.num_D):
+ for j in range(len(pred_fake[i])-1):
+ loss_G_GAN_Feat += D_weights * feat_weights * \
+ self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
+
+ # VGG feature matching loss
+ loss_G_VGG = 0
+ if not self.opt.no_vgg_loss:
+ loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
+
+ # Only return the fake_B image if necessary to save BW
+ return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ], None if not infer else fake_image ]
+
+ def inference(self, label, inst):
+ # Encode Inputs
+ input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
+
+ # Fake Generation
+ if self.use_features:
+ # sample clusters from precomputed features
+ feat_map = self.sample_features(inst_map)
+ input_concat = torch.cat((input_label, feat_map), dim=1)
+ else:
+ input_concat = input_label
+ fake_image = self.netG.forward(input_concat)
+ return fake_image
+
+ def sample_features(self, inst):
+ # read precomputed feature clusters
+ cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
+ features_clustered = np.load(cluster_path).item()
+
+ # randomly sample from the feature clusters
+ inst_np = inst.cpu().numpy().astype(int)
+ feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3])
+ for i in np.unique(inst_np):
+ label = i if i < 1000 else i//1000
+ if label in features_clustered:
+ feat = features_clustered[label]
+ cluster_idx = np.random.randint(0, feat.shape[0])
+
+ idx = (inst == i).nonzero()
+ for k in range(self.opt.feat_num):
+ feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+ return feat_map
+
+ def encode_features(self, image, inst):
+ image = Variable(image.cuda(), volatile=True)
+ feat_num = self.opt.feat_num
+ h, w = inst.size()[2], inst.size()[3]
+ block_num = 32
+ feat_map = self.netE.forward(image, inst.cuda())
+ inst_np = inst.cpu().numpy().astype(int)
+ feature = {}
+ for i in range(self.opt.label_nc):
+ feature[i] = np.zeros((0, feat_num+1))
+ for i in np.unique(inst_np):
+ label = i if i < 1000 else i//1000
+ idx = (inst == i).nonzero()
+ num = idx.size()[0]
+ idx = idx[num//2,:]
+ val = np.zeros((1, feat_num+1))
+ for k in range(feat_num):
+ val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
+ val[0, feat_num] = float(num) / (h * w // block_num)
+ feature[label] = np.append(feature[label], val, axis=0)
+ return feature
+
+ def get_edges(self, t):
+ edge = torch.cuda.ByteTensor(t.size()).zero_()
+ edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
+ edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
+ return edge.float()
+
+ def save(self, which_epoch):
+ self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
+ self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
+ if self.gen_features:
+ self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
+
+ def update_fixed_params(self):
+ # after fixing the global generator for a number of iterations, also start finetuning it
+ params = list(self.netG.parameters())
+ if self.gen_features:
+ params += list(self.netE.parameters())
+ self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+ print('------------ Now also finetuning global generator -----------')
+
+ def update_learning_rate(self):
+ lrd = self.opt.lr / self.opt.niter_decay
+ lr = self.old_lr - lrd
+ for param_group in self.optimizer_D.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_G.param_groups:
+ param_group['lr'] = lr
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ self.old_lr = lr