summaryrefslogtreecommitdiff
path: root/models/pix2pixHD_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/pix2pixHD_model.py')
-rwxr-xr-xmodels/pix2pixHD_model.py20
1 files changed, 12 insertions, 8 deletions
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index fea9a1d..834fc18 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -19,10 +19,11 @@ class Pix2PixHDModel(BaseModel):
self.isTrain = opt.isTrain
self.use_features = opt.instance_feat or opt.label_feat
self.gen_features = self.use_features and not self.opt.load_features
+ input_nc = opt.label_nc if opt.label_nc != 0 else 3
##### define networks
# Generator network
- netG_input_nc = opt.label_nc
+ netG_input_nc = input_nc
if not opt.no_instance:
netG_input_nc += 1
if self.use_features:
@@ -34,7 +35,7 @@ class Pix2PixHDModel(BaseModel):
# Discriminator network
if self.isTrain:
use_sigmoid = opt.no_lsgan
- netD_input_nc = opt.label_nc + opt.output_nc
+ netD_input_nc = input_nc + opt.output_nc
if not opt.no_instance:
netD_input_nc += 1
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
@@ -93,12 +94,15 @@ class Pix2PixHDModel(BaseModel):
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
- def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
- # create one-hot vector for label map
- size = label_map.size()
- oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
- input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
- input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
+ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
+ if self.opt.label_nc == 0:
+ input_label = label_map.data.cuda()
+ else:
+ # create one-hot vector for label map
+ size = label_map.size()
+ oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
+ input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
+ input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
# get edges from instance map
if not self.opt.no_instance: