diff options
| -rwxr-xr-x | README.md | 8 | ||||
| -rwxr-xr-x | data/aligned_dataset.py | 12 | ||||
| -rwxr-xr-x | models/pix2pixHD_model.py | 20 | ||||
| -rwxr-xr-x | util/util.py | 16 |
4 files changed, 36 insertions, 20 deletions
@@ -107,7 +107,13 @@ Note: this is not tested and we trained our model using single GPU only. Please - To train the images at full resolution (2048 x 1024) requires a GPU with 24G memory (`bash ./scripts/train_1024p_24G.sh`).
If only GPUs with 12G memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which will crop the images during training. Performance is not guaranteed using this script.
-## More Training/test Details
+### Training with your own dataset
+- If you want to train with your own dataset, please generate one-channel label maps whose pixel values correspond to the object labels (i.e. 0,1,...,N-1, where N is the number of labels). This is because we need to generate one-hot vectors from the label maps. Please also specify `--label_nc N` during both training and testing.
+- If your input is not a label map, please just specify `--label_nc 0` which will directly use the RGB colors as input.
+- If you don't have instance maps or don't want to use them, please specify `--no_instance`.
+- The default setting for preprocessing is `scale_width`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it by using the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image dimensions are divisible by 32.
+
+## More Training/Test Details
- Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
- Instance map: we take in both label maps and instance maps as input. If you don't want to use instance maps, please specify the flag `--no_instance`.
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py index a0c9a0a..a3cdc76 100755 --- a/data/aligned_dataset.py +++ b/data/aligned_dataset.py @@ -33,12 +33,16 @@ class AlignedDataset(BaseDataset): self.dataset_size = len(self.label_paths) def __getitem__(self, index): - ### label maps + ### label maps label_path = self.label_paths[index] label = Image.open(label_path) - params = get_params(self.opt, label.size) - transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) - label_tensor = transform_label(label) * 255.0 + params = get_params(self.opt, label.size) + if self.opt.label_nc == 0: + transform_label = get_transform(self.opt, params) + label_tensor = transform_label(label.convert('RGB')) + else: + transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) + label_tensor = transform_label(label) * 255.0 image_tensor = inst_tensor = feat_tensor = 0 ### real images diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py index fea9a1d..834fc18 100755 --- a/models/pix2pixHD_model.py +++ b/models/pix2pixHD_model.py @@ -19,10 +19,11 @@ class Pix2PixHDModel(BaseModel): self.isTrain = opt.isTrain self.use_features = opt.instance_feat or opt.label_feat self.gen_features = self.use_features and not self.opt.load_features + input_nc = opt.label_nc if opt.label_nc != 0 else 3 ##### define networks # Generator network - netG_input_nc = opt.label_nc + netG_input_nc = input_nc if not opt.no_instance: netG_input_nc += 1 if self.use_features: @@ -34,7 +35,7 @@ class Pix2PixHDModel(BaseModel): # Discriminator network if self.isTrain: use_sigmoid = opt.no_lsgan - netD_input_nc = opt.label_nc + opt.output_nc + netD_input_nc = input_nc + opt.output_nc if not opt.no_instance: netD_input_nc += 1 self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, @@ -93,12 +94,15 @@ class Pix2PixHDModel(BaseModel): params = list(self.netD.parameters()) self.optimizer_D = 
torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) - def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): - # create one-hot vector for label map - size = label_map.size() - oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) - input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() - input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + if self.opt.label_nc == 0: + input_label = label_map.data.cuda() + else: + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) # get edges from instance map if not self.opt.no_instance: diff --git a/util/util.py b/util/util.py index 95d1315..f5ed60a 100755 --- a/util/util.py +++ b/util/util.py @@ -24,13 +24,15 @@ def tensor2im(image_tensor, imtype=np.uint8, normalize=True): return image_numpy.astype(imtype) # Converts a one-hot tensor into a colorful label map -def tensor2label(output, n_label, imtype=np.uint8): - output = output.cpu().float() - if output.size()[0] > 1: - output = output.max(0, keepdim=True)[1] - output = Colorize(n_label)(output) - output = np.transpose(output.numpy(), (1, 2, 0)) - return output.astype(imtype) +def tensor2label(label_tensor, n_label, imtype=np.uint8): + if n_label == 0: + return tensor2im(label_tensor, imtype) + label_tensor = label_tensor.cpu().float() + if label_tensor.size()[0] > 1: + label_tensor = label_tensor.max(0, keepdim=True)[1] + label_tensor = Colorize(n_label)(label_tensor) + label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) + return label_numpy.astype(imtype) def save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) |
