diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-04-18 12:35:39 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-04-18 12:35:39 +0200 |
| commit | d3044d50514586b0cf9702bc2c16a6486af249f6 (patch) | |
| tree | 2c05c6d37efe7b23abc93c5419d6b2cfb037e6bb | |
| parent | e3726af25b83134d6240b926386fa0243f6a6a96 (diff) | |
setting up recursive dataset
| -rw-r--r-- | data/__init__.py | 3 | ||||
| -rw-r--r-- | data/recursive_dataset.py | 37 | ||||
| -rw-r--r-- | models/__init__.py | 2 | ||||
| -rw-r--r-- | test.py | 2 |
4 files changed, 43 insertions, 1 deletions
diff --git a/data/__init__.py b/data/__init__.py index 341281d..a69f374 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -20,6 +20,9 @@ def CreateDataset(opt): elif opt.dataset_mode == 'single': from data.single_dataset import SingleDataset dataset = SingleDataset() + elif opt.dataset_mode == 'recursive': + from data.recursive_dataset import RecursiveDataset + dataset = RecursiveDataset() else: raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py new file mode 100644 index 0000000..d85184e --- /dev/null +++ b/data/recursive_dataset.py @@ -0,0 +1,37 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class RecursiveDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.next_image = opt.dataroot + #self.dir_A = os.path.join(opt.dataroot) + #self.A_paths = make_dataset(self.dir_A) + #self.A_paths = sorted(self.A_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.next_image + A_img = Image.open(A_path).convert('RGB') + A = self.transform(A_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + else: + input_nc = self.opt.input_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + return {'A': A, 'A_paths': A_path} + + def __len__(self): + return float("inf") + + def name(self): + return 'RecursiveImageDataset' diff --git a/models/__init__.py b/models/__init__.py index 681c6de..72a0d2e 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -10,7 +10,7 @@ def create_model(opt): from .pix2pix_model import Pix2PixModel model = Pix2PixModel() elif opt.model == 'test': - assert(opt.dataset_mode == 'single') + assert(opt.dataset_mode == 'single' or opt.dataset_mode == 'recursive') from .test_model import TestModel model = TestModel() else: @@ -30,5 +30,7 @@ if __name__ == '__main__': img_path = model.get_image_paths() print('%04d: process image... %s' % (i, img_path)) visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio) + if dataset.name() == 'RecursiveImageDataset': + dataset.append() webpage.save() |
