Diffstat (limited to 'data/aligned_dataset.py')
| -rwxr-xr-x | data/aligned_dataset.py | 54 |
1 file changed, 28 insertions, 26 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index a3cdc76..41468d2 100755
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -10,14 +10,16 @@ class AlignedDataset(BaseDataset):
         self.opt = opt
         self.root = opt.dataroot
 
-        ### label maps
-        self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
-        self.label_paths = sorted(make_dataset(self.dir_label))
+        ### input A (label maps)
+        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
+        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
+        self.A_paths = sorted(make_dataset(self.dir_A))
 
-        ### real images
+        ### input B (real images)
         if opt.isTrain:
-            self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
-            self.image_paths = sorted(make_dataset(self.dir_image))
+            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
+            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
+            self.B_paths = sorted(make_dataset(self.dir_B))
 
         ### instance maps
         if not opt.no_instance:
@@ -30,47 +32,47 @@ class AlignedDataset(BaseDataset):
             print('----------- loading features from %s ----------' % self.dir_feat)
             self.feat_paths = sorted(make_dataset(self.dir_feat))
 
-        self.dataset_size = len(self.label_paths)
+        self.dataset_size = len(self.A_paths)
 
     def __getitem__(self, index):
-        ### label maps
-        label_path = self.label_paths[index]
-        label = Image.open(label_path)
-        params = get_params(self.opt, label.size)
+        ### input A (label maps)
+        A_path = self.A_paths[index]
+        A = Image.open(A_path)
+        params = get_params(self.opt, A.size)
         if self.opt.label_nc == 0:
-            transform_label = get_transform(self.opt, params)
-            label_tensor = transform_label(label.convert('RGB'))
+            transform_A = get_transform(self.opt, params)
+            A_tensor = transform_A(A.convert('RGB'))
         else:
-            transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
-            label_tensor = transform_label(label) * 255.0
+            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+            A_tensor = transform_A(A) * 255.0
 
-        image_tensor = inst_tensor = feat_tensor = 0
-        ### real images
+        B_tensor = inst_tensor = feat_tensor = 0
+        ### input B (real images)
         if self.opt.isTrain:
-            image_path = self.image_paths[index]
-            image = Image.open(image_path).convert('RGB')
-            transform_image = get_transform(self.opt, params)
-            image_tensor = transform_image(image)
+            B_path = self.B_paths[index]
+            B = Image.open(B_path).convert('RGB')
+            transform_B = get_transform(self.opt, params)
+            B_tensor = transform_B(B)
 
         ### if using instance maps
         if not self.opt.no_instance:
             inst_path = self.inst_paths[index]
             inst = Image.open(inst_path)
-            inst_tensor = transform_label(inst)
+            inst_tensor = transform_A(inst)
 
             if self.opt.load_features:
                 feat_path = self.feat_paths[index]
                 feat = Image.open(feat_path).convert('RGB')
                 norm = normalize()
-                feat_tensor = norm(transform_label(feat))
+                feat_tensor = norm(transform_A(feat))
 
-        input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
-                      'feat': feat_tensor, 'path': label_path}
+        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
+                      'feat': feat_tensor, 'path': A_path}
 
         return input_dict
 
     def __len__(self):
-        return len(self.label_paths)
+        return len(self.A_paths)
 
     def name(self):
         return 'AlignedDataset'
\ No newline at end of file
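
The commit keys the dataset folder names off opt.label_nc: with label_nc == 0 the loader expects generic paired folders (phase + '_A' / '_B'), otherwise the original label-map layout (phase + '_label' / '_img'). Below is a minimal sketch of that suffix selection only, not code from the commit; the SimpleNamespace stands in for the real pix2pixHD options object, and the './datasets/cityscapes' root is an assumed example path.

# Sketch (assumed example), mirroring the dir_A/dir_B logic in the diff above.
import os
from types import SimpleNamespace

# Hypothetical options; the field names (dataroot, phase, label_nc) match those in the diff.
opt = SimpleNamespace(dataroot='./datasets/cityscapes', phase='train', label_nc=35)

dir_A = '_A' if opt.label_nc == 0 else '_label'   # input A: paired image A or label map
dir_B = '_B' if opt.label_nc == 0 else '_img'     # input B: paired image B or real image

print(os.path.join(opt.dataroot, opt.phase + dir_A))  # ./datasets/cityscapes/train_label
print(os.path.join(opt.dataroot, opt.phase + dir_B))  # ./datasets/cityscapes/train_img

Setting label_nc=0 instead resolves to train_A and train_B, the fallback this change adds for plain paired A/B datasets.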
