summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data/__init__.py27
-rw-r--r--test.py3
2 files changed, 29 insertions, 1 deletions
diff --git a/data/__init__.py b/data/__init__.py
index a69f374..420c1f8 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -8,6 +8,12 @@ def CreateDataLoader(opt):
data_loader.initialize(opt)
return data_loader
+def CreateRecursiveDataLoader(opt):
+ data_loader = RecursiveDatasetDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
+
def CreateDataset(opt):
dataset = None
@@ -55,3 +61,24 @@ class CustomDatasetDataLoader(BaseDataLoader):
if i * self.opt.batchSize >= self.opt.max_dataset_size:
break
yield data
+
+class RecursiveDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ #BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = self.dataset
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batchSize >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/test.py b/test.py
index 78d0c98..1ddb894 100644
--- a/test.py
+++ b/test.py
@@ -21,7 +21,8 @@ if __name__ == '__main__':
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
- for i, data in enumerate(dataset):
+ print(dataset.name())
+ for i, data in enumerate(data_loader):
if i >= opt.how_many:
break
model.set_input(data)