summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-04-18 12:35:39 +0200
committerJules Laplace <julescarbon@gmail.com>2018-04-18 12:35:39 +0200
commitd3044d50514586b0cf9702bc2c16a6486af249f6 (patch)
tree2c05c6d37efe7b23abc93c5419d6b2cfb037e6bb /data
parente3726af25b83134d6240b926386fa0243f6a6a96 (diff)
setting up recursive dataset
Diffstat (limited to 'data')
-rw-r--r--data/__init__.py3
-rw-r--r--data/recursive_dataset.py37
2 files changed, 40 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py
index 341281d..a69f374 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -20,6 +20,9 @@ def CreateDataset(opt):
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
+ elif opt.dataset_mode == 'recursive':
+ from data.recursive_dataset import RecursiveDataset
+ dataset = RecursiveDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py
new file mode 100644
index 0000000..d85184e
--- /dev/null
+++ b/data/recursive_dataset.py
@@ -0,0 +1,37 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class RecursiveDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.next_image = opt.dataroot
+ #self.dir_A = os.path.join(opt.dataroot)
+ #self.A_paths = make_dataset(self.dir_A)
+ #self.A_paths = sorted(self.A_paths)
+
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ A_path = self.next_image
+ A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ else:
+ input_nc = self.opt.input_nc
+
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ return {'A': A, 'A_paths': A_path}
+
+ def __len__(self):
+ return float("inf")
+
+ def name(self):
+ return 'RecursiveImageDataset'