summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/aligned_dataset.py15
-rw-r--r--data/single_dataset.py12
-rw-r--r--data/unaligned_dataset.py20
3 files changed, 40 insertions, 7 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 0f45c40..bccd6fc 100644
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -40,12 +40,27 @@ class AlignedDataset(BaseDataset):
B = AB[:, h_offset:h_offset + self.opt.fineSize,
w + w_offset:w + w_offset + self.opt.fineSize]
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ output_nc = self.opt.input_nc
+ else:
+ input_nc = self.opt.input_nc
+ output_nc = self.opt.output_nc
+
if (not self.opt.no_flip) and random.random() < 0.5:
idx = [i for i in range(A.size(2) - 1, -1, -1)]
idx = torch.LongTensor(idx)
A = A.index_select(2, idx)
B = B.index_select(2, idx)
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ if output_nc == 1: # RGB to gray
+ tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+ B = tmp.unsqueeze(0)
+
return {'A': A, 'B': B,
'A_paths': AB_path, 'B_paths': AB_path}
diff --git a/data/single_dataset.py b/data/single_dataset.py
index faf416a..f8b4f1d 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -19,12 +19,18 @@ class SingleDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index]
-
A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ else:
+ input_nc = self.opt.input_nc
- A_img = self.transform(A_img)
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
- return {'A': A_img, 'A_paths': A_path}
+ return {'A': A, 'A_paths': A_path}
def __len__(self):
return len(self.A_paths)
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index d31eb05..c5e5460 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -4,7 +4,6 @@ from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import PIL
-from pdb import set_trace as st
import random
class UnalignedDataset(BaseDataset):
@@ -32,10 +31,23 @@ class UnalignedDataset(BaseDataset):
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
- A_img = self.transform(A_img)
- B_img = self.transform(B_img)
+ A = self.transform(A_img)
+ B = self.transform(B_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ output_nc = self.opt.input_nc
+ else:
+ input_nc = self.opt.input_nc
+ output_nc = self.opt.output_nc
- return {'A': A_img, 'B': B_img,
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ if output_nc == 1: # RGB to gray
+ tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+ B = tmp.unsqueeze(0)
+ return {'A': A, 'B': B,
'A_paths': A_path, 'B_paths': B_path}
def __len__(self):