summaryrefslogtreecommitdiff
path: root/data/single_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/single_dataset.py')
-rw-r--r--data/single_dataset.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/data/single_dataset.py b/data/single_dataset.py
index faf416a..f8b4f1d 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -19,12 +19,18 @@ class SingleDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index]
-
A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ else:
+ input_nc = self.opt.input_nc
- A_img = self.transform(A_img)
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
- return {'A': A_img, 'A_paths': A_path}
+ return {'A': A, 'A_paths': A_path}
def __len__(self):
return len(self.A_paths)