From d3044d50514586b0cf9702bc2c16a6486af249f6 Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Wed, 18 Apr 2018 12:35:39 +0200 Subject: setting up recursive dataset --- data/__init__.py | 3 +++ data/recursive_dataset.py | 37 +++++++++++++++++++++++++++++++++++++ models/__init__.py | 2 +- test.py | 2 ++ 4 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 data/recursive_dataset.py diff --git a/data/__init__.py b/data/__init__.py index 341281d..a69f374 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -20,6 +20,9 @@ def CreateDataset(opt): elif opt.dataset_mode == 'single': from data.single_dataset import SingleDataset dataset = SingleDataset() + elif opt.dataset_mode == 'recursive': + from data.recursive_dataset import RecursiveDataset + dataset = RecursiveDataset() else: raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py new file mode 100644 index 0000000..d85184e --- /dev/null +++ b/data/recursive_dataset.py @@ -0,0 +1,37 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class RecursiveDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.next_image = opt.dataroot + #self.dir_A = os.path.join(opt.dataroot) + #self.A_paths = make_dataset(self.dir_A) + #self.A_paths = sorted(self.A_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.next_image + A_img = Image.open(A_path).convert('RGB') + A = self.transform(A_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + else: + input_nc = self.opt.input_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + return {'A': A, 'A_paths': A_path} + + def __len__(self): + return float("inf") + + def name(self): + return 'RecursiveImageDataset' diff --git a/models/__init__.py b/models/__init__.py index 681c6de..72a0d2e 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -10,7 +10,7 @@ def create_model(opt): from .pix2pix_model import Pix2PixModel model = Pix2PixModel() elif opt.model == 'test': - assert(opt.dataset_mode == 'single') + assert(opt.dataset_mode == 'single' or opt.dataset_mode == 'recursive') from .test_model import TestModel model = TestModel() else: diff --git a/test.py b/test.py index 8444bd9..0a2fb9f 100644 --- a/test.py +++ b/test.py @@ -30,5 +30,7 @@ if __name__ == '__main__': img_path = model.get_image_paths() print('%04d: process image... %s' % (i, img_path)) visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio) + if dataset.name() == 'RecursiveImageDataset': + dataset.append() webpage.save() -- cgit v1.2.3-70-g09d2