summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/__init__.py54
-rw-r--r--data/custom_dataset_data_loader.py47
-rw-r--r--data/data_loader.py6
3 files changed, 54 insertions, 53 deletions
diff --git a/data/__init__.py b/data/__init__.py
index e69de29..ef581e7 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -0,0 +1,54 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataLoader(opt):
+ data_loader = CustomDatasetDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
+
+
+def CreateDataset(opt):
+ dataset = None
+ if opt.dataset_mode == 'aligned':
+ from data.aligned_dataset import AlignedDataset
+ dataset = AlignedDataset()
+ elif opt.dataset_mode == 'unaligned':
+ from data.unaligned_dataset import UnalignedDataset
+ dataset = UnalignedDataset()
+ elif opt.dataset_mode == 'single':
+ from data.single_dataset import SingleDataset
+ dataset = SingleDataset()
+ else:
+ raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
+
+ print("dataset [%s] was created" % (dataset.name()))
+ dataset.initialize(opt)
+ return dataset
+
+
+class CustomDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads))
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ for i, data in enumerate(self.dataloader):
+ if i >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
deleted file mode 100644
index 787946f..0000000
--- a/data/custom_dataset_data_loader.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch.utils.data
-from data.base_data_loader import BaseDataLoader
-
-
-def CreateDataset(opt):
- dataset = None
- if opt.dataset_mode == 'aligned':
- from data.aligned_dataset import AlignedDataset
- dataset = AlignedDataset()
- elif opt.dataset_mode == 'unaligned':
- from data.unaligned_dataset import UnalignedDataset
- dataset = UnalignedDataset()
- elif opt.dataset_mode == 'single':
- from data.single_dataset import SingleDataset
- dataset = SingleDataset()
- else:
- raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
-
- print("dataset [%s] was created" % (dataset.name()))
- dataset.initialize(opt)
- return dataset
-
-
-class CustomDatasetDataLoader(BaseDataLoader):
- def name(self):
- return 'CustomDatasetDataLoader'
-
- def initialize(self, opt):
- BaseDataLoader.initialize(self, opt)
- self.dataset = CreateDataset(opt)
- self.dataloader = torch.utils.data.DataLoader(
- self.dataset,
- batch_size=opt.batchSize,
- shuffle=not opt.serial_batches,
- num_workers=int(opt.nThreads))
-
- def load_data(self):
- return self
-
- def __len__(self):
- return min(len(self.dataset), self.opt.max_dataset_size)
-
- def __iter__(self):
- for i, data in enumerate(self.dataloader):
- if i >= self.opt.max_dataset_size:
- break
- yield data
diff --git a/data/data_loader.py b/data/data_loader.py
deleted file mode 100644
index 22b6a8f..0000000
--- a/data/data_loader.py
+++ /dev/null
@@ -1,6 +0,0 @@
-def CreateDataLoader(opt):
- from data.custom_dataset_data_loader import CustomDatasetDataLoader
- data_loader = CustomDatasetDataLoader()
- print(data_loader.name())
- data_loader.initialize(opt)
- return data_loader