diff options
| author | tingchunw <tingchunw@nvidia.com> | 2017-12-04 16:52:46 -0800 |
|---|---|---|
| committer | tingchunw <tingchunw@nvidia.com> | 2017-12-04 16:52:46 -0800 |
| commit | 9054cf9b0c327a5077fd0793abe178f400da3315 (patch) | |
| tree | 3c69c07bdcba86c47d8442648fd69c0434e04136 /data/custom_dataset_data_loader.py | |
| parent | f9e9999541d67a908a169cc88407675133130e1f (diff) | |
first commit
Diffstat (limited to 'data/custom_dataset_data_loader.py')
| -rwxr-xr-x | data/custom_dataset_data_loader.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py new file mode 100755 index 0000000..0b98254 --- /dev/null +++ b/data/custom_dataset_data_loader.py @@ -0,0 +1,31 @@ +import torch.utils.data +from data.base_data_loader import BaseDataLoader + + +def CreateDataset(opt): + dataset = None + from data.aligned_dataset import AlignedDataset + dataset = AlignedDataset() + + print("dataset [%s] was created" % (dataset.name())) + dataset.initialize(opt) + return dataset + +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = CreateDataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self.dataloader + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) |
