summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-04-18 12:35:39 +0200
committerJules Laplace <julescarbon@gmail.com>2018-04-18 12:35:39 +0200
commitd3044d50514586b0cf9702bc2c16a6486af249f6 (patch)
tree2c05c6d37efe7b23abc93c5419d6b2cfb037e6bb
parente3726af25b83134d6240b926386fa0243f6a6a96 (diff)
setting up recursive dataset
-rw-r--r--data/__init__.py3
-rw-r--r--data/recursive_dataset.py37
-rw-r--r--models/__init__.py2
-rw-r--r--test.py2
4 files changed, 43 insertions, 1 deletions
diff --git a/data/__init__.py b/data/__init__.py
index 341281d..a69f374 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -20,6 +20,9 @@ def CreateDataset(opt):
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
+ elif opt.dataset_mode == 'recursive':
+ from data.recursive_dataset import RecursiveDataset
+ dataset = RecursiveDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py
new file mode 100644
index 0000000..d85184e
--- /dev/null
+++ b/data/recursive_dataset.py
@@ -0,0 +1,37 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class RecursiveDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.next_image = opt.dataroot
+ #self.dir_A = os.path.join(opt.dataroot)
+ #self.A_paths = make_dataset(self.dir_A)
+ #self.A_paths = sorted(self.A_paths)
+
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ A_path = self.next_image
+ A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ else:
+ input_nc = self.opt.input_nc
+
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ return {'A': A, 'A_paths': A_path}
+
+ def __len__(self):
+ return float("inf")
+
+ def name(self):
+ return 'RecursiveImageDataset'
diff --git a/models/__init__.py b/models/__init__.py
index 681c6de..72a0d2e 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -10,7 +10,7 @@ def create_model(opt):
from .pix2pix_model import Pix2PixModel
model = Pix2PixModel()
elif opt.model == 'test':
- assert(opt.dataset_mode == 'single')
+ assert(opt.dataset_mode == 'single' or opt.dataset_mode == 'recursive')
from .test_model import TestModel
model = TestModel()
else:
diff --git a/test.py b/test.py
index 8444bd9..0a2fb9f 100644
--- a/test.py
+++ b/test.py
@@ -30,5 +30,7 @@ if __name__ == '__main__':
img_path = model.get_image_paths()
print('%04d: process image... %s' % (i, img_path))
visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
+ if dataset.name() == 'RecursiveImageDataset':
+ dataset.append()
webpage.save()