diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/custom_dataset_data_loader.py | 8 | ||||
| -rw-r--r-- | data/unaligned_dataset.py | 5 |
2 files changed, 11 insertions, 2 deletions
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py index 60180e0..787946f 100644 --- a/data/custom_dataset_data_loader.py +++ b/data/custom_dataset_data_loader.py @@ -35,7 +35,13 @@ class CustomDatasetDataLoader(BaseDataLoader): num_workers=int(opt.nThreads)) def load_data(self): - return self.dataloader + return self def __len__(self): return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + for i, data in enumerate(self.dataloader): + if i >= self.opt.max_dataset_size: + break + yield data diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index c5e5460..ad0c11b 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -25,7 +25,10 @@ class UnalignedDataset(BaseDataset): def __getitem__(self, index): A_path = self.A_paths[index % self.A_size] index_A = index % self.A_size - index_B = random.randint(0, self.B_size - 1) + if self.opt.serial_batches: + index_B = index % self.B_size + else: + index_B = random.randint(0, self.B_size - 1) B_path = self.B_paths[index_B] # print('(A, B) = (%d, %d)' % (index_A, index_B)) A_img = Image.open(A_path).convert('RGB') |
