diff options
| author | Ting-Chun Wang <tcwang0509@berkeley.edu> | 2018-05-30 22:39:01 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-05-30 22:39:01 -0700 |
| commit | a2340c3fff9de44c8ef1fea5b90fced756fbbb18 (patch) | |
| tree | 39f3c05c80a94d721ec6fed0f0da65ecbc3bc603 /models/pix2pixHD_model.py | |
| parent | 1b89cd010dce2e6edaa07d23c8edd8dfe146e0e1 (diff) | |
| parent | 25e205604e7eafa83867a15cfda526461fe58455 (diff) | |
Merge pull request #33 from borisfom/fp16
Added data size and ONNX export options, FP16 inference is working
Diffstat (limited to 'models/pix2pixHD_model.py')
| -rwxr-xr-x | models/pix2pixHD_model.py | 35 |
1 file changed, 26 insertions, 9 deletions
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py index b77868a..79ebabd 100755 --- a/models/pix2pixHD_model.py +++ b/models/pix2pixHD_model.py @@ -50,8 +50,8 @@ class Pix2PixHDModel(BaseModel): if self.gen_features: self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) - - print('---------- Networks initialized -------------') + if self.opt.verbose: + print('---------- Networks initialized -------------') # load networks if not self.isTrain or opt.continue_train or opt.load_pretrain: @@ -84,7 +84,8 @@ class Pix2PixHDModel(BaseModel): # initialize optimizers # optimizer G if opt.niter_fix_global > 0: - print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) + if self.opt.verbose: + print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) params_dict = dict(self.netG.named_parameters()) params = [] for key, value in params_dict.items(): @@ -111,13 +112,15 @@ class Pix2PixHDModel(BaseModel): oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + if self.opt.data_type==16: + input_label = input_label.half() # get edges from instance map if not self.opt.no_instance: inst_map = inst_map.data.cuda() edge_map = self.get_edges(inst_map) input_label = torch.cat((input_label, edge_map), dim=1) - input_label = Variable(input_label, volatile=infer) + input_label = Variable(input_label, requires_grad = not infer) # real images for training if real_image is not None: @@ -212,7 +215,9 @@ class Pix2PixHDModel(BaseModel): idx = (inst == i).nonzero() for k in range(self.opt.feat_num): - feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + if self.opt.data_type==16: + feat_map = feat_map.half() return feat_map def encode_features(self, image, inst): @@ -243,7 +248,10 @@ class Pix2PixHDModel(BaseModel): edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) - return edge.float() + if self.opt.data_type==16: + return edge.half() + else: + return edge.float() def save(self, which_epoch): self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) @@ -256,8 +264,9 @@ class Pix2PixHDModel(BaseModel): params = list(self.netG.parameters()) if self.gen_features: params += list(self.netE.parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) - print('------------ Now also finetuning global generator -----------') + self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) + if self.opt.verbose: + print('------------ Now also finetuning global generator -----------') def update_learning_rate(self): lrd = self.opt.lr / self.opt.niter_decay @@ -266,5 +275,13 @@ class Pix2PixHDModel(BaseModel): param_group['lr'] = lr for param_group in self.optimizer_G.param_groups: param_group['lr'] = lr - print('update learning rate: %f -> %f' % (self.old_lr, lr)) + if self.opt.verbose: + print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr + +class InferenceModel(Pix2PixHDModel): + def forward(self, inp): + label, inst = inp + return self.inference(label, inst) + + |
