summaryrefslogtreecommitdiff
path: root/data/aligned_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/aligned_dataset.py')
-rwxr-xr-xdata/aligned_dataset.py54
1 files changed, 28 insertions, 26 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index a3cdc76..41468d2 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -10,14 +10,16 @@ class AlignedDataset(BaseDataset):
self.opt = opt
self.root = opt.dataroot
- ### label maps
- self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
- self.label_paths = sorted(make_dataset(self.dir_label))
+ ### input A (label maps)
+ dir_A = '_A' if self.opt.label_nc == 0 else '_label'
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
+ self.A_paths = sorted(make_dataset(self.dir_A))
- ### real images
+ ### input B (real images)
if opt.isTrain:
- self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
- self.image_paths = sorted(make_dataset(self.dir_image))
+ dir_B = '_B' if self.opt.label_nc == 0 else '_img'
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
+ self.B_paths = sorted(make_dataset(self.dir_B))
### instance maps
if not opt.no_instance:
@@ -30,47 +32,47 @@ class AlignedDataset(BaseDataset):
print('----------- loading features from %s ----------' % self.dir_feat)
self.feat_paths = sorted(make_dataset(self.dir_feat))
- self.dataset_size = len(self.label_paths)
+ self.dataset_size = len(self.A_paths)
def __getitem__(self, index):
- ### label maps
- label_path = self.label_paths[index]
- label = Image.open(label_path)
- params = get_params(self.opt, label.size)
+ ### input A (label maps)
+ A_path = self.A_paths[index]
+ A = Image.open(A_path)
+ params = get_params(self.opt, A.size)
if self.opt.label_nc == 0:
- transform_label = get_transform(self.opt, params)
- label_tensor = transform_label(label.convert('RGB'))
+ transform_A = get_transform(self.opt, params)
+ A_tensor = transform_A(A.convert('RGB'))
else:
- transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
- label_tensor = transform_label(label) * 255.0
+ transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+ A_tensor = transform_A(A) * 255.0
- image_tensor = inst_tensor = feat_tensor = 0
- ### real images
+ B_tensor = inst_tensor = feat_tensor = 0
+ ### input B (real images)
if self.opt.isTrain:
- image_path = self.image_paths[index]
- image = Image.open(image_path).convert('RGB')
- transform_image = get_transform(self.opt, params)
- image_tensor = transform_image(image)
+ B_path = self.B_paths[index]
+ B = Image.open(B_path).convert('RGB')
+ transform_B = get_transform(self.opt, params)
+ B_tensor = transform_B(B)
### if using instance maps
if not self.opt.no_instance:
inst_path = self.inst_paths[index]
inst = Image.open(inst_path)
- inst_tensor = transform_label(inst)
+ inst_tensor = transform_A(inst)
if self.opt.load_features:
feat_path = self.feat_paths[index]
feat = Image.open(feat_path).convert('RGB')
norm = normalize()
- feat_tensor = norm(transform_label(feat))
+ feat_tensor = norm(transform_A(feat))
- input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
- 'feat': feat_tensor, 'path': label_path}
+ input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
+ 'feat': feat_tensor, 'path': A_path}
return input_dict
def __len__(self):
- return len(self.label_paths)
+ return len(self.A_paths)
def name(self):
return 'AlignedDataset' \ No newline at end of file