summaryrefslogtreecommitdiff
path: root/data/aligned_dataset.py
diff options
context:
space:
mode:
authortingchunw <tingchunw@nvidia.com>2017-12-09 01:02:39 +0000
committertingchunw <tingchunw@nvidia.com>2017-12-09 01:02:39 +0000
commitdb4a24df4a8482089bcb88d0e79c9b484307fc75 (patch)
tree50f05acb700ff3b71e94d937ccd1b57da165cea0 /data/aligned_dataset.py
parent99d031b469478434ea185e1da07f12b7b007c6b6 (diff)
add explanation for training with new dataset
Diffstat (limited to 'data/aligned_dataset.py')
-rwxr-xr-xdata/aligned_dataset.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index a0c9a0a..a3cdc76 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -33,12 +33,16 @@ class AlignedDataset(BaseDataset):
self.dataset_size = len(self.label_paths)
def __getitem__(self, index):
- ### label maps
+ ### label maps
label_path = self.label_paths[index]
label = Image.open(label_path)
- params = get_params(self.opt, label.size)
- transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
- label_tensor = transform_label(label) * 255.0
+ params = get_params(self.opt, label.size)
+ if self.opt.label_nc == 0:
+ transform_label = get_transform(self.opt, params)
+ label_tensor = transform_label(label.convert('RGB'))
+ else:
+ transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+ label_tensor = transform_label(label) * 255.0
image_tensor = inst_tensor = feat_tensor = 0
### real images