diff options
Diffstat (limited to 'models')
| -rwxr-xr-x | models/base_model.py | 8 | ||||
| -rwxr-xr-x | models/models.py | 3 | ||||
| -rwxr-xr-x | models/pix2pixHD_model.py | 28 |
3 files changed, 26 insertions, 13 deletions
diff --git a/models/base_model.py b/models/base_model.py index 88e0587..2cda12f 100755 --- a/models/base_model.py +++ b/models/base_model.py @@ -68,7 +68,8 @@ class BaseModel(torch.nn.Module): try: pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} network.load_state_dict(pretrained_dict) - print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) + if self.opt.verbose: + print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) except: print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) if sys.version_info >= (3,0): @@ -82,8 +83,9 @@ class BaseModel(torch.nn.Module): for k, v in model_dict.items(): if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): - not_initialized.add(k.split('.')[0]) - print(sorted(not_initialized)) + not_initialized.add(k.split('.')[0]) + if self.opt.verbose: + print(sorted(not_initialized)) network.load_state_dict(model_dict) def update_learning_rate(): diff --git a/models/models.py b/models/models.py index 0ba442f..805696f 100755 --- a/models/models.py +++ b/models/models.py @@ -10,7 +10,8 @@ def create_model(opt): from .ui_model import UIModel model = UIModel() model.initialize(opt) - print("model [%s] was created" % (model.name())) + if opt.verbose: + print("model [%s] was created" % (model.name())) if opt.isTrain and len(opt.gpu_ids): model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py index 834fc18..de594ab 100755 --- a/models/pix2pixHD_model.py +++ b/models/pix2pixHD_model.py @@ -45,8 +45,8 @@ class Pix2PixHDModel(BaseModel): if self.gen_features: self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) - - print('---------- Networks initialized -------------') + if self.opt.verbose: + print('---------- Networks initialized -------------')
# load networks if not self.isTrain or opt.continue_train or opt.load_pretrain: @@ -76,7 +76,8 @@ class Pix2PixHDModel(BaseModel): # initialize optimizers # optimizer G if opt.niter_fix_global > 0: - print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) + if self.opt.verbose: + print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) params_dict = dict(self.netG.named_parameters()) params = [] for key, value in params_dict.items(): @@ -103,13 +104,15 @@ class Pix2PixHDModel(BaseModel): oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + if self.opt.data_type==16: + input_label = input_label.half() # get edges from instance map if not self.opt.no_instance: inst_map = inst_map.data.cuda() edge_map = self.get_edges(inst_map) input_label = torch.cat((input_label, edge_map), dim=1) - input_label = Variable(input_label, volatile=infer) + input_label = Variable(input_label, requires_grad = not infer) # real images for training if real_image is not None: @@ -204,7 +207,9 @@ class Pix2PixHDModel(BaseModel): idx = (inst == i).nonzero() for k in range(self.opt.feat_num): - feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + if self.opt.data_type==16: + feat_map = feat_map.half() return feat_map def encode_features(self, image, inst): @@ -235,7 +240,10 @@ class Pix2PixHDModel(BaseModel): edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) - return edge.float() + if self.opt.data_type==16: + return edge.half()
+ else: + return edge.float() def save(self, which_epoch): self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) @@ -248,8 +256,9 @@ class Pix2PixHDModel(BaseModel): params = list(self.netG.parameters()) if self.gen_features: params += list(self.netE.parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) - print('------------ Now also finetuning global generator -----------') + self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) + if self.opt.verbose: + print('------------ Now also finetuning global generator -----------') def update_learning_rate(self): lrd = self.opt.lr / self.opt.niter_decay @@ -258,5 +267,6 @@ class Pix2PixHDModel(BaseModel): param_group['lr'] = lr for param_group in self.optimizer_G.param_groups: param_group['lr'] = lr - print('update learning rate: %f -> %f' % (self.old_lr, lr)) + if self.opt.verbose: + print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr |
