diff options
Diffstat (limited to 'data/aligned_dataset.py')
| -rw-r--r-- | data/aligned_dataset.py | 15 |
1 files changed, 5 insertions, 10 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py index bccd6fc..8899cb2 100644 --- a/data/aligned_dataset.py +++ b/data/aligned_dataset.py @@ -12,22 +12,14 @@ class AlignedDataset(BaseDataset): self.opt = opt self.root = opt.dataroot self.dir_AB = os.path.join(opt.dataroot, opt.phase) - self.AB_paths = sorted(make_dataset(self.dir_AB)) - assert(opt.resize_or_crop == 'resize_and_crop') - transform_list = [transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - - self.transform = transforms.Compose(transform_list) - def __getitem__(self, index): AB_path = self.AB_paths[index] AB = Image.open(AB_path).convert('RGB') AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC) - AB = self.transform(AB) + AB = transforms.ToTensor()(AB) w_total = AB.size(2) w = int(w_total / 2) @@ -40,6 +32,9 @@ class AlignedDataset(BaseDataset): B = AB[:, h_offset:h_offset + self.opt.fineSize, w + w_offset:w + w_offset + self.opt.fineSize] + A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) + B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) + if self.opt.which_direction == 'BtoA': input_nc = self.opt.output_nc output_nc = self.opt.input_nc @@ -60,7 +55,7 @@ class AlignedDataset(BaseDataset): if output_nc == 1: # RGB to gray tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114 B = tmp.unsqueeze(0) - + return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} |
