summary | refs | log | tree | commit | diff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rwxr-xr-x  models/base_model.py      |  8
-rwxr-xr-x  models/models.py          | 10
-rwxr-xr-x  models/pix2pixHD_model.py | 35
3 files changed, 38 insertions(+), 15 deletions(-)
diff --git a/models/base_model.py b/models/base_model.py
index 88e0587..2cda12f 100755
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -68,7 +68,8 @@ class BaseModel(torch.nn.Module):
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
- print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
+ if self.opt.verbose:
+ print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
if sys.version_info >= (3,0):
@@ -82,8 +83,9 @@ class BaseModel(torch.nn.Module):
for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
- not_initialized.add(k.split('.')[0])
- print(sorted(not_initialized))
+ not_initialized.add(k.split('.')[0])
+ if self.opt.verbose:
+ print(sorted(not_initialized))
network.load_state_dict(model_dict)
def update_learning_rate():
diff --git a/models/models.py b/models/models.py
index 0ba442f..8e72e46 100755
--- a/models/models.py
+++ b/models/models.py
@@ -4,13 +4,17 @@ import torch
def create_model(opt):
if opt.model == 'pix2pixHD':
- from .pix2pixHD_model import Pix2PixHDModel
- model = Pix2PixHDModel()
+ from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
+ if opt.isTrain:
+ model = Pix2PixHDModel()
+ else:
+ model = InferenceModel()
else:
from .ui_model import UIModel
model = UIModel()
model.initialize(opt)
- print("model [%s] was created" % (model.name()))
+ if opt.verbose:
+ print("model [%s] was created" % (model.name()))
if opt.isTrain and len(opt.gpu_ids):
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index b77868a..79ebabd 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -50,8 +50,8 @@ class Pix2PixHDModel(BaseModel):
if self.gen_features:
self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
-
- print('---------- Networks initialized -------------')
+ if self.opt.verbose:
+ print('---------- Networks initialized -------------')
# load networks
if not self.isTrain or opt.continue_train or opt.load_pretrain:
@@ -84,7 +84,8 @@ class Pix2PixHDModel(BaseModel):
# initialize optimizers
# optimizer G
if opt.niter_fix_global > 0:
- print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
+ if self.opt.verbose:
+ print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
params_dict = dict(self.netG.named_parameters())
params = []
for key, value in params_dict.items():
@@ -111,13 +112,15 @@ class Pix2PixHDModel(BaseModel):
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
+ if self.opt.data_type==16:
+ input_label = input_label.half()
# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
- input_label = Variable(input_label, volatile=infer)
+ input_label = Variable(input_label, requires_grad = not infer)
# real images for training
if real_image is not None:
@@ -212,7 +215,9 @@ class Pix2PixHDModel(BaseModel):
idx = (inst == i).nonzero()
for k in range(self.opt.feat_num):
- feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+ feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
+ if self.opt.data_type==16:
+ feat_map = feat_map.half()
return feat_map
def encode_features(self, image, inst):
@@ -243,7 +248,10 @@ class Pix2PixHDModel(BaseModel):
edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
- return edge.float()
+ if self.opt.data_type==16:
+ return edge.half()
+ else:
+ return edge.float()
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
@@ -256,8 +264,9 @@ class Pix2PixHDModel(BaseModel):
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
- self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
- print('------------ Now also finetuning global generator -----------')
+ self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+ if self.opt.verbose:
+ print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
lrd = self.opt.lr / self.opt.niter_decay
@@ -266,5 +275,13 @@ class Pix2PixHDModel(BaseModel):
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
- print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ if self.opt.verbose:
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
+
+class InferenceModel(Pix2PixHDModel):
+ def forward(self, inp):
+ label, inst = inp
+ return self.inference(label, inst)
+
+