summaryrefslogtreecommitdiff
path: root/data/custom_dataset_data_loader.py
diff options
context:
space:
mode:
authortingchunw <tingchunw@nvidia.com>2017-12-04 16:52:46 -0800
committertingchunw <tingchunw@nvidia.com>2017-12-04 16:52:46 -0800
commit9054cf9b0c327a5077fd0793abe178f400da3315 (patch)
tree3c69c07bdcba86c47d8442648fd69c0434e04136 /data/custom_dataset_data_loader.py
parentf9e9999541d67a908a169cc88407675133130e1f (diff)
first commit
Diffstat (limited to 'data/custom_dataset_data_loader.py')
-rwxr-xr-xdata/custom_dataset_data_loader.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
new file mode 100755
index 0000000..0b98254
--- /dev/null
+++ b/data/custom_dataset_data_loader.py
@@ -0,0 +1,31 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataset(opt):
+ dataset = None
+ from data.aligned_dataset import AlignedDataset
+ dataset = AlignedDataset()
+
+ print("dataset [%s] was created" % (dataset.name()))
+ dataset.initialize(opt)
+ return dataset
+
+class CustomDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads))
+
+ def load_data(self):
+ return self.dataloader
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)