summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-04-18 13:20:56 +0200
committerJules Laplace <julescarbon@gmail.com>2018-04-18 13:20:56 +0200
commit7f8b0de93cf8a12d5408db200a0ad2459cb0fee5 (patch)
tree090b479f854e1f83802b74d21b63e6b8fab5c58a /data
parent146014b8e7ff0e9fd1020e62cabcce471f988728 (diff)
k
Diffstat (limited to 'data')
-rw-r--r--data/__init__.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py
index a69f374..420c1f8 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -8,6 +8,12 @@ def CreateDataLoader(opt):
data_loader.initialize(opt)
return data_loader
+def CreateRecursiveDataLoader(opt):
+ data_loader = RecursiveDatasetDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
+
def CreateDataset(opt):
dataset = None
@@ -55,3 +61,24 @@ class CustomDatasetDataLoader(BaseDataLoader):
if i * self.opt.batchSize >= self.opt.max_dataset_size:
break
yield data
+
+class RecursiveDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ #BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = self.dataset
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batchSize >= self.opt.max_dataset_size:
+ break
+ yield data