 README.md                 |  8 +++++++-
 data/aligned_dataset.py   | 12 ++++++++----
 models/pix2pixHD_model.py | 20 ++++++++++++--------
 util/util.py              | 16 +++++++++-------
 4 files changed, 36 insertions(+), 20 deletions(-)
diff --git a/README.md b/README.md
index f710b22..c6ff340 100755
--- a/README.md
+++ b/README.md
@@ -107,7 +107,13 @@ Note: this is not tested and we trained our model using single GPU only. Please
 - To train the images at full resolution (2048 x 1024) requires a GPU with 24G memory (`bash ./scripts/train_1024p_24G.sh`).
 If only GPUs with 12G memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which will crop the images during training. Performance is not guaranteed using this script.

-## More Training/test Details
+### Training with your own dataset
+- If you want to train with your own dataset, please generate single-channel label maps whose pixel values correspond to the object labels (i.e., 0, 1, ..., N-1, where N is the number of labels), since we need to generate one-hot vectors from the label maps. Please also specify `--label_nc N` during both training and testing.
+- If your input is not a label map, please just specify `--label_nc 0`, which will directly use the RGB colors as input.
+- If you don't have instance maps or don't want to use them, please specify `--no_instance`.
+- The default preprocessing setting is `scale_width`, which scales the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, change it with the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to width `opt.loadSize` and then does a random crop of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which does nothing other than making sure the image dimensions are divisible by 32.
+
+## More Training/Test Details

 - Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
 - Instance map: we take in both label maps and instance maps as input. If you don't want to use instance maps, please specify the flag `--no_instance`.
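
To make the first bullet above concrete: the label map must be a single-channel image whose raw pixel values are the class IDs. Below is a minimal, hypothetical sketch (the palette, file names, and `to_label_map` helper are illustrative, not part of the repo) that converts a color-coded segmentation into that format:

```python
import numpy as np
from PIL import Image

# Hypothetical palette: each RGB color in the source segmentation
# maps to an integer object label in {0, ..., N-1}. Adjust to your data.
PALETTE = {(0, 0, 0): 0, (128, 0, 0): 1, (0, 128, 0): 2}

def to_label_map(rgb_path, out_path):
    rgb = np.array(Image.open(rgb_path).convert('RGB'))
    label = np.zeros(rgb.shape[:2], dtype=np.uint8)
    for color, idx in PALETTE.items():
        label[(rgb == color).all(axis=-1)] = idx
    # Save single-channel ('L' mode) so pixel values stay raw class IDs.
    Image.fromarray(label, mode='L').save(out_path)

to_label_map('seg_color.png', 'seg_label.png')
```

With N = 3 labels as in this toy palette, training and testing would then pass `--label_nc 3`.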
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index a0c9a0a..a3cdc76 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -33,12 +33,16 @@ class AlignedDataset(BaseDataset):
         self.dataset_size = len(self.label_paths)

     def __getitem__(self, index):
-        ### label maps
+        ### label maps
         label_path = self.label_paths[index]
         label = Image.open(label_path)
-        params = get_params(self.opt, label.size)
-        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
-        label_tensor = transform_label(label) * 255.0
+        params = get_params(self.opt, label.size)
+        if self.opt.label_nc == 0:
+            transform_label = get_transform(self.opt, params)
+            label_tensor = transform_label(label.convert('RGB'))
+        else:
+            transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+            label_tensor = transform_label(label) * 255.0

         image_tensor = inst_tensor = feat_tensor = 0
         ### real images
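
A note on the branch added above: `ToTensor()` scales 8-bit pixels to [0, 1], so the label branch multiplies by 255.0 to recover integer class IDs, and it resizes with `Image.NEAREST` so interpolation cannot invent new label values at region boundaries, while the RGB branch (`label_nc == 0`) is normalized like any ordinary image. A standalone sketch of the two cases (assumes torchvision; the real `get_transform` also implements the `--resize_or_crop` geometry such as `scale_width`, which is omitted here):

```python
import torchvision.transforms as transforms
from PIL import Image

def load_label(path, label_nc, load_size=1024):
    """Mirror of the two label-loading cases (sketch, not repo code)."""
    img = Image.open(path)
    if label_nc == 0:
        # RGB input: bicubic resize + normalize to [-1, 1].
        t = transforms.Compose([
            transforms.Resize(load_size, Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        return t(img.convert('RGB'))
    # Label map: NEAREST keeps values in {0..N-1}; ToTensor() maps
    # uint8 to [0, 1], so * 255.0 restores the integer class IDs.
    t = transforms.Compose([
        transforms.Resize(load_size, Image.NEAREST),
        transforms.ToTensor()])
    return t(img) * 255.0
```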
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index fea9a1d..834fc18 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -19,10 +19,11 @@ class Pix2PixHDModel(BaseModel):
         self.isTrain = opt.isTrain
         self.use_features = opt.instance_feat or opt.label_feat
         self.gen_features = self.use_features and not self.opt.load_features
+        input_nc = opt.label_nc if opt.label_nc != 0 else 3

         ##### define networks
         # Generator network
-        netG_input_nc = opt.label_nc
+        netG_input_nc = input_nc
         if not opt.no_instance:
             netG_input_nc += 1
         if self.use_features:
@@ -34,7 +35,7 @@ class Pix2PixHDModel(BaseModel):
         # Discriminator network
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
-            netD_input_nc = opt.label_nc + opt.output_nc
+            netD_input_nc = input_nc + opt.output_nc
             if not opt.no_instance:
                 netD_input_nc += 1
             self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
@@ -93,12 +94,15 @@ class Pix2PixHDModel(BaseModel):
             params = list(self.netD.parameters())
             self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

-    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
-        # create one-hot vector for label map
-        size = label_map.size()
-        oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
-        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
-        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
+    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
+        if self.opt.label_nc == 0:
+            input_label = label_map.data.cuda()
+        else:
+            # create one-hot vector for label map
+            size = label_map.size()
+            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
+            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
+            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)

         # get edges from instance map
         if not self.opt.no_instance:
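
Two things are worth unpacking in the model changes. First, `input_nc` now falls back to 3 when `--label_nc 0`, because the "label" tensor is then a plain RGB image rather than an N-channel one-hot map. Second, the `scatter_` call is what builds that one-hot map; a CPU-only toy sketch of the same operation:

```python
import torch

# Toy label map: batch 1, one channel, 2x3 pixels, labels in {0, 1, 2}.
label_map = torch.tensor([[[[0, 1, 2],
                            [2, 1, 0]]]])          # shape (N, 1, H, W)
label_nc = 3

size = label_map.size()
one_hot = torch.zeros(size[0], label_nc, size[2], size[3])
# Along dim=1, write 1.0 at the channel index given by each pixel's
# label value, producing a (N, label_nc, H, W) one-hot encoding.
one_hot.scatter_(1, label_map.long(), 1.0)

print(one_hot[0, :, 0, 0])  # tensor([1., 0., 0.]) -> pixel (0,0) is class 0
print(one_hot[0, :, 0, 2])  # tensor([0., 0., 1.]) -> pixel (0,2) is class 2
```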
diff --git a/util/util.py b/util/util.py
index 95d1315..f5ed60a 100755
--- a/util/util.py
+++ b/util/util.py
@@ -24,13 +24,15 @@ def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
     return image_numpy.astype(imtype)

 # Converts a one-hot tensor into a colorful label map
-def tensor2label(output, n_label, imtype=np.uint8):
-    output = output.cpu().float()
-    if output.size()[0] > 1:
-        output = output.max(0, keepdim=True)[1]
-    output = Colorize(n_label)(output)
-    output = np.transpose(output.numpy(), (1, 2, 0))
-    return output.astype(imtype)
+def tensor2label(label_tensor, n_label, imtype=np.uint8):
+    if n_label == 0:
+        return tensor2im(label_tensor, imtype)
+    label_tensor = label_tensor.cpu().float()
+    if label_tensor.size()[0] > 1:
+        label_tensor = label_tensor.max(0, keepdim=True)[1]
+    label_tensor = Colorize(n_label)(label_tensor)
+    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
+    return label_numpy.astype(imtype)

 def save_image(image_numpy, image_path):
     image_pil = Image.fromarray(image_numpy)
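
Finally, the rewritten `tensor2label` falls back to `tensor2im` when `n_label == 0`, since an RGB "label" has no discrete classes to colorize. For the label case, the `max(0, keepdim=True)[1]` line is an argmax over channels that collapses a one-hot (or softmax) tensor back to integer IDs, which `Colorize` then maps to a color table; a toy sketch of just that step:

```python
import torch

# Fake 3-class one-hot tensor for a 2x2 image, shape (C, H, W).
one_hot = torch.tensor([[[1., 0.], [0., 0.]],
                        [[0., 1.], [0., 0.]],
                        [[0., 0.], [1., 1.]]])

# max over dim 0 returns (values, indices); the indices are the
# per-pixel label IDs that Colorize subsequently turns into colors.
labels = one_hot.max(0, keepdim=True)[1]
print(labels.squeeze())
# tensor([[0, 1],
#         [2, 2]])
```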