diff options
| -rw-r--r-- | data/__init__.py | 5 | ||||
| -rw-r--r-- | data/recursive_dataset.py | 37 | ||||
| -rw-r--r-- | models/__init__.py | 2 | ||||
| -rw-r--r-- | models/test_model.py | 2 | ||||
| -rw-r--r-- | test.py | 5 | ||||
| -rw-r--r-- | util/visualizer.py | 1 |
6 files changed, 48 insertions, 4 deletions
diff --git a/data/__init__.py b/data/__init__.py index ef581e7..a69f374 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -20,6 +20,9 @@ def CreateDataset(opt): elif opt.dataset_mode == 'single': from data.single_dataset import SingleDataset dataset = SingleDataset() + elif opt.dataset_mode == 'recursive': + from data.recursive_dataset import RecursiveDataset + dataset = RecursiveDataset() else: raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) @@ -49,6 +52,6 @@ class CustomDatasetDataLoader(BaseDataLoader): def __iter__(self): for i, data in enumerate(self.dataloader): - if i >= self.opt.max_dataset_size: + if i * self.opt.batchSize >= self.opt.max_dataset_size: break yield data diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py new file mode 100644 index 0000000..d85184e --- /dev/null +++ b/data/recursive_dataset.py @@ -0,0 +1,37 @@ +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class RecursiveDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.next_image = opt.dataroot + #self.dir_A = os.path.join(opt.dataroot) + #self.A_paths = make_dataset(self.dir_A) + #self.A_paths = sorted(self.A_paths) + + self.transform = get_transform(opt) + + def __getitem__(self, index): + A_path = self.next_image + A_img = Image.open(A_path).convert('RGB') + A = self.transform(A_img) + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + else: + input_nc = self.opt.input_nc + + if input_nc == 1: # RGB to gray + tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 + A = tmp.unsqueeze(0) + + return {'A': A, 'A_paths': A_path} + + def __len__(self): + return float("inf") + + def name(self): + return 'RecursiveImageDataset' diff --git a/models/__init__.py b/models/__init__.py index 681c6de..72a0d2e 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -10,7 +10,7 @@ def create_model(opt): from .pix2pix_model import Pix2PixModel model = Pix2PixModel() elif opt.model == 'test': - assert(opt.dataset_mode == 'single') + assert(opt.dataset_mode == 'single' or opt.dataset_mode == 'recursive') from .test_model import TestModel model = TestModel() else: diff --git a/models/test_model.py b/models/test_model.py index f593c46..5dd4fb9 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -33,7 +33,7 @@ class TestModel(BaseModel): self.image_paths = input['A_paths'] def test(self): - self.real_A = Variable(self.input_A) + self.real_A = Variable(self.input_A, volatile=True) self.fake_B = self.netG(self.real_A) # get image paths @@ -29,6 +29,9 @@ if __name__ == '__main__': visuals = model.get_current_visuals() img_path = model.get_image_paths() print('%04d: process image... %s' % (i, img_path)) - visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio) + ims = visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio) + if dataset.name() == 'RecursiveImageDataset': + # dataset.append(ims) + print ims webpage.save() diff --git a/util/visualizer.py b/util/visualizer.py index a98b512..35ea0be 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -148,3 +148,4 @@ class Visualizer(): txts.append(label) links.append(image_name) webpage.add_images(ims, txts, links, width=self.win_size) + return ims |
