summaryrefslogtreecommitdiff
path: root/data/recursive_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/recursive_dataset.py')
-rw-r--r--data/recursive_dataset.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py
new file mode 100644
index 0000000..d85184e
--- /dev/null
+++ b/data/recursive_dataset.py
@@ -0,0 +1,37 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class RecursiveDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.next_image = opt.dataroot
+ #self.dir_A = os.path.join(opt.dataroot)
+ #self.A_paths = make_dataset(self.dir_A)
+ #self.A_paths = sorted(self.A_paths)
+
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ A_path = self.next_image
+ A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ else:
+ input_nc = self.opt.input_nc
+
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ return {'A': A, 'A_paths': A_path}
+
+ def __len__(self):
+ return float("inf")
+
+ def name(self):
+ return 'RecursiveImageDataset'