diff options
Diffstat (limited to 'models/pix2pixHD_model.py')
| -rwxr-xr-x | models/pix2pixHD_model.py | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py index fea9a1d..834fc18 100755 --- a/models/pix2pixHD_model.py +++ b/models/pix2pixHD_model.py @@ -19,10 +19,11 @@ class Pix2PixHDModel(BaseModel): self.isTrain = opt.isTrain self.use_features = opt.instance_feat or opt.label_feat self.gen_features = self.use_features and not self.opt.load_features + input_nc = opt.label_nc if opt.label_nc != 0 else 3 ##### define networks # Generator network - netG_input_nc = opt.label_nc + netG_input_nc = input_nc if not opt.no_instance: netG_input_nc += 1 if self.use_features: @@ -34,7 +35,7 @@ class Pix2PixHDModel(BaseModel): # Discriminator network if self.isTrain: use_sigmoid = opt.no_lsgan - netD_input_nc = opt.label_nc + opt.output_nc + netD_input_nc = input_nc + opt.output_nc if not opt.no_instance: netD_input_nc += 1 self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, @@ -93,12 +94,15 @@ class Pix2PixHDModel(BaseModel): params = list(self.netD.parameters()) self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) - def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): - # create one-hot vector for label map - size = label_map.size() - oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) - input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() - input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + if self.opt.label_nc == 0: + input_label = label_map.data.cuda() + else: + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) # get edges from instance map if not self.opt.no_instance: |
