summaryrefslogtreecommitdiff
path: root/data/aligned_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/aligned_dataset.py')
-rw-r--r--data/aligned_dataset.py15
1 files changed, 5 insertions, 10 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index bccd6fc..8899cb2 100644
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -12,22 +12,14 @@ class AlignedDataset(BaseDataset):
self.opt = opt
self.root = opt.dataroot
self.dir_AB = os.path.join(opt.dataroot, opt.phase)
-
self.AB_paths = sorted(make_dataset(self.dir_AB))
-
assert(opt.resize_or_crop == 'resize_and_crop')
- transform_list = [transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5),
- (0.5, 0.5, 0.5))]
-
- self.transform = transforms.Compose(transform_list)
-
def __getitem__(self, index):
AB_path = self.AB_paths[index]
AB = Image.open(AB_path).convert('RGB')
AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
- AB = self.transform(AB)
+ AB = transforms.ToTensor()(AB)
w_total = AB.size(2)
w = int(w_total / 2)
@@ -40,6 +32,9 @@ class AlignedDataset(BaseDataset):
B = AB[:, h_offset:h_offset + self.opt.fineSize,
w + w_offset:w + w_offset + self.opt.fineSize]
+ A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
+ B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
+
if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
output_nc = self.opt.input_nc
@@ -60,7 +55,7 @@ class AlignedDataset(BaseDataset):
if output_nc == 1: # RGB to gray
tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
B = tmp.unsqueeze(0)
-
+
return {'A': A, 'B': B,
'A_paths': AB_path, 'B_paths': AB_path}