summaryrefslogtreecommitdiff
path: root/data/aligned_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/aligned_dataset.py')
-rwxr-xr-xdata/aligned_dataset.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index a0c9a0a..a3cdc76 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -33,12 +33,16 @@ class AlignedDataset(BaseDataset):
self.dataset_size = len(self.label_paths)
def __getitem__(self, index):
- ### label maps
+ ### label maps
label_path = self.label_paths[index]
label = Image.open(label_path)
- params = get_params(self.opt, label.size)
- transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
- label_tensor = transform_label(label) * 255.0
+ params = get_params(self.opt, label.size)
+ if self.opt.label_nc == 0:
+ transform_label = get_transform(self.opt, params)
+ label_tensor = transform_label(label.convert('RGB'))
+ else:
+ transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+ label_tensor = transform_label(label) * 255.0
image_tensor = inst_tensor = feat_tensor = 0
### real images