diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-04-18 13:20:56 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-04-18 13:20:56 +0200 |
| commit | 7f8b0de93cf8a12d5408db200a0ad2459cb0fee5 (patch) | |
| tree | 090b479f854e1f83802b74d21b63e6b8fab5c58a | |
| parent | 146014b8e7ff0e9fd1020e62cabcce471f988728 (diff) | |
k
| -rw-r--r-- | data/__init__.py | 27 | ||||
| -rw-r--r-- | test.py | 3 |
2 files changed, 29 insertions, 1 deletions
diff --git a/data/__init__.py b/data/__init__.py index a69f374..420c1f8 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -8,6 +8,12 @@ def CreateDataLoader(opt): data_loader.initialize(opt) return data_loader +def CreateRecursiveDataLoader(opt): + data_loader = RecursiveDatasetDataLoader() + print(data_loader.name()) + data_loader.initialize(opt) + return data_loader + def CreateDataset(opt): dataset = None @@ -55,3 +61,24 @@ class CustomDatasetDataLoader(BaseDataLoader): if i * self.opt.batchSize >= self.opt.max_dataset_size: break yield data + +class RecursiveDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + #BaseDataLoader.initialize(self, opt) + self.dataset = CreateDataset(opt) + self.dataloader = self.dataset + + def load_data(self): + return self + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + for i, data in enumerate(self.dataloader): + if i * self.opt.batchSize >= self.opt.max_dataset_size: + break + yield data @@ -21,7 +21,8 @@ if __name__ == '__main__': web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) # test - for i, data in enumerate(dataset): + print(dataset.name()) + for i, data in enumerate(data_loader): if i >= opt.how_many: break model.set_input(data) |
