Diffstat (limited to 'data')
-rw-r--r--   data/aligned_dataset.py     15
-rw-r--r--   data/single_dataset.py      12
-rw-r--r--   data/unaligned_dataset.py   20
3 files changed, 40 insertions, 7 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 0f45c40..bccd6fc 100644
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -40,12 +40,27 @@ class AlignedDataset(BaseDataset):
         B = AB[:, h_offset:h_offset + self.opt.fineSize,
                w + w_offset:w + w_offset + self.opt.fineSize]
 
+        if self.opt.which_direction == 'BtoA':
+            input_nc = self.opt.output_nc
+            output_nc = self.opt.input_nc
+        else:
+            input_nc = self.opt.input_nc
+            output_nc = self.opt.output_nc
+
         if (not self.opt.no_flip) and random.random() < 0.5:
             idx = [i for i in range(A.size(2) - 1, -1, -1)]
             idx = torch.LongTensor(idx)
             A = A.index_select(2, idx)
             B = B.index_select(2, idx)
 
+        if input_nc == 1:  # RGB to gray
+            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+            A = tmp.unsqueeze(0)
+
+        if output_nc == 1:  # RGB to gray
+            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+            B = tmp.unsqueeze(0)
+
         return {'A': A, 'B': B,
                 'A_paths': AB_path, 'B_paths': AB_path}
 
diff --git a/data/single_dataset.py b/data/single_dataset.py
index faf416a..f8b4f1d 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -19,12 +19,18 @@ class SingleDataset(BaseDataset):
     def __getitem__(self, index):
         A_path = self.A_paths[index]
-
         A_img = Image.open(A_path).convert('RGB')
+        A = self.transform(A_img)
 
+        if self.opt.which_direction == 'BtoA':
+            input_nc = self.opt.output_nc
+        else:
+            input_nc = self.opt.input_nc
 
-        A_img = self.transform(A_img)
+        if input_nc == 1:  # RGB to gray
+            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+            A = tmp.unsqueeze(0)
 
-        return {'A': A_img, 'A_paths': A_path}
+        return {'A': A, 'A_paths': A_path}
 
     def __len__(self):
         return len(self.A_paths)
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index d31eb05..c5e5460 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -4,7 +4,6 @@ from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
 import PIL
-from pdb import set_trace as st
 import random
 
 class UnalignedDataset(BaseDataset):
@@ -32,10 +31,23 @@
         A_img = Image.open(A_path).convert('RGB')
         B_img = Image.open(B_path).convert('RGB')
 
-        A_img = self.transform(A_img)
-        B_img = self.transform(B_img)
+        A = self.transform(A_img)
+        B = self.transform(B_img)
+        if self.opt.which_direction == 'BtoA':
+            input_nc = self.opt.output_nc
+            output_nc = self.opt.input_nc
+        else:
+            input_nc = self.opt.input_nc
+            output_nc = self.opt.output_nc
 
-        return {'A': A_img, 'B': B_img,
+        if input_nc == 1:  # RGB to gray
+            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+            A = tmp.unsqueeze(0)
+
+        if output_nc == 1:  # RGB to gray
+            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+            B = tmp.unsqueeze(0)
+        return {'A': A, 'B': B,
                 'A_paths': A_path, 'B_paths': B_path}
 
     def __len__(self):
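Note (not part of the commit): every added block above does the same thing, collapsing a 3xHxW RGB tensor to a single gray channel with the BT.601 luma weights 0.299/0.587/0.114 whenever the requested channel count (input_nc or output_nc) is 1. The snippet below is a minimal standalone sketch of that conversion on a dummy tensor; the helper name rgb_to_gray and the 256x256 size are illustrative only and do not appear in the repository.

    # Standalone sketch of the RGB-to-gray step added in the diff above.
    # rgb_to_gray is a hypothetical helper name, not a function from the repo.
    import torch

    def rgb_to_gray(img):
        """Collapse a 3xHxW RGB tensor to 1xHxW using BT.601 luma weights,
        mirroring the per-channel weighted sum used in the datasets."""
        tmp = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
        return tmp.unsqueeze(0)  # keep an explicit channel dimension

    if __name__ == '__main__':
        a = torch.rand(3, 256, 256)   # stand-in for a transformed image tensor
        print(rgb_to_gray(a).shape)   # torch.Size([1, 256, 256])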
