summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data/__init__.py5
-rw-r--r--data/recursive_dataset.py37
-rw-r--r--models/__init__.py2
-rw-r--r--models/test_model.py2
-rw-r--r--test.py5
-rw-r--r--util/visualizer.py1
6 files changed, 48 insertions, 4 deletions
diff --git a/data/__init__.py b/data/__init__.py
index ef581e7..a69f374 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -20,6 +20,9 @@ def CreateDataset(opt):
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
+ elif opt.dataset_mode == 'recursive':
+ from data.recursive_dataset import RecursiveDataset
+ dataset = RecursiveDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
@@ -49,6 +52,6 @@ class CustomDatasetDataLoader(BaseDataLoader):
def __iter__(self):
for i, data in enumerate(self.dataloader):
- if i >= self.opt.max_dataset_size:
+ if i * self.opt.batchSize >= self.opt.max_dataset_size:
break
yield data
diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py
new file mode 100644
index 0000000..d85184e
--- /dev/null
+++ b/data/recursive_dataset.py
@@ -0,0 +1,37 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class RecursiveDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.next_image = opt.dataroot
+ #self.dir_A = os.path.join(opt.dataroot)
+ #self.A_paths = make_dataset(self.dir_A)
+ #self.A_paths = sorted(self.A_paths)
+
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ A_path = self.next_image
+ A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ if self.opt.which_direction == 'BtoA':
+ input_nc = self.opt.output_nc
+ else:
+ input_nc = self.opt.input_nc
+
+ if input_nc == 1: # RGB to gray
+ tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+ A = tmp.unsqueeze(0)
+
+ return {'A': A, 'A_paths': A_path}
+
+ def __len__(self):
+ return float("inf")
+
+ def name(self):
+ return 'RecursiveImageDataset'
diff --git a/models/__init__.py b/models/__init__.py
index 681c6de..72a0d2e 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -10,7 +10,7 @@ def create_model(opt):
from .pix2pix_model import Pix2PixModel
model = Pix2PixModel()
elif opt.model == 'test':
- assert(opt.dataset_mode == 'single')
+ assert(opt.dataset_mode == 'single' or opt.dataset_mode == 'recursive')
from .test_model import TestModel
model = TestModel()
else:
diff --git a/models/test_model.py b/models/test_model.py
index f593c46..5dd4fb9 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -33,7 +33,7 @@ class TestModel(BaseModel):
self.image_paths = input['A_paths']
def test(self):
- self.real_A = Variable(self.input_A)
+ self.real_A = Variable(self.input_A, volatile=True)
self.fake_B = self.netG(self.real_A)
# get image paths
diff --git a/test.py b/test.py
index 8444bd9..651a164 100644
--- a/test.py
+++ b/test.py
@@ -29,6 +29,9 @@ if __name__ == '__main__':
visuals = model.get_current_visuals()
img_path = model.get_image_paths()
print('%04d: process image... %s' % (i, img_path))
- visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
+ ims = visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)
+ if dataset.name() == 'RecursiveImageDataset':
+ # dataset.append(ims)
+ print ims
webpage.save()
diff --git a/util/visualizer.py b/util/visualizer.py
index a98b512..35ea0be 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -148,3 +148,4 @@ class Visualizer():
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=self.win_size)
+ return ims