diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/__init__.py | 54 | ||||
| -rw-r--r-- | data/custom_dataset_data_loader.py | 47 | ||||
| -rw-r--r-- | data/data_loader.py | 6 |
3 files changed, 54 insertions, 53 deletions
diff --git a/data/__init__.py b/data/__init__.py index e69de29..ef581e7 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -0,0 +1,54 @@ +import torch.utils.data +from data.base_data_loader import BaseDataLoader + + +def CreateDataLoader(opt): + data_loader = CustomDatasetDataLoader() + print(data_loader.name()) + data_loader.initialize(opt) + return data_loader + + +def CreateDataset(opt): + dataset = None + if opt.dataset_mode == 'aligned': + from data.aligned_dataset import AlignedDataset + dataset = AlignedDataset() + elif opt.dataset_mode == 'unaligned': + from data.unaligned_dataset import UnalignedDataset + dataset = UnalignedDataset() + elif opt.dataset_mode == 'single': + from data.single_dataset import SingleDataset + dataset = SingleDataset() + else: + raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) + + print("dataset [%s] was created" % (dataset.name())) + dataset.initialize(opt) + return dataset + + +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = CreateDataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + for i, data in enumerate(self.dataloader): + if i >= self.opt.max_dataset_size: + break + yield data diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py deleted file mode 100644 index 787946f..0000000 --- a/data/custom_dataset_data_loader.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch.utils.data -from data.base_data_loader import BaseDataLoader - - -def CreateDataset(opt): - dataset = None - if opt.dataset_mode == 'aligned': - from data.aligned_dataset import AlignedDataset - dataset = AlignedDataset() - elif opt.dataset_mode == 'unaligned': - from data.unaligned_dataset import UnalignedDataset - dataset = UnalignedDataset() - elif opt.dataset_mode == 'single': - from data.single_dataset import SingleDataset - dataset = SingleDataset() - else: - raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) - - print("dataset [%s] was created" % (dataset.name())) - dataset.initialize(opt) - return dataset - - -class CustomDatasetDataLoader(BaseDataLoader): - def name(self): - return 'CustomDatasetDataLoader' - - def initialize(self, opt): - BaseDataLoader.initialize(self, opt) - self.dataset = CreateDataset(opt) - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=opt.batchSize, - shuffle=not opt.serial_batches, - num_workers=int(opt.nThreads)) - - def load_data(self): - return self - - def __len__(self): - return min(len(self.dataset), self.opt.max_dataset_size) - - def __iter__(self): - for i, data in enumerate(self.dataloader): - if i >= self.opt.max_dataset_size: - break - yield data diff --git a/data/data_loader.py b/data/data_loader.py deleted file mode 100644 index 22b6a8f..0000000 --- a/data/data_loader.py +++ /dev/null @@ -1,6 +0,0 @@ -def CreateDataLoader(opt): - from data.custom_dataset_data_loader import CustomDatasetDataLoader - data_loader = CustomDatasetDataLoader() - print(data_loader.name()) - data_loader.initialize(opt) - return data_loader |
