diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-04-18 12:35:39 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-04-18 12:35:39 +0200 |
| commit | d3044d50514586b0cf9702bc2c16a6486af249f6 (patch) | |
| tree | 2c05c6d37efe7b23abc93c5419d6b2cfb037e6bb /data/recursive_dataset.py | |
| parent | e3726af25b83134d6240b926386fa0243f6a6a96 (diff) | |
setting up recursive dataset
Diffstat (limited to 'data/recursive_dataset.py')
| -rw-r--r-- | data/recursive_dataset.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py new file mode 100644 index 0000000..d85184e --- /dev/null +++ b/data/recursive_dataset.py @@ -0,0 +1,37 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class RecursiveDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.next_image = opt.dataroot + #self.dir_A = os.path.join(opt.dataroot) + #self.A_paths = make_dataset(self.dir_A) + #self.A_paths = sorted(self.A_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.next_image + A_img = Image.open(A_path).convert('RGB') + A = self.transform(A_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + else: + input_nc = self.opt.input_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + return {'A': A, 'A_paths': A_path} + + def __len__(self): + return float("inf") + + def name(self): + return 'RecursiveImageDataset' |
