summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/custom_dataset_data_loader.py8
-rw-r--r--data/unaligned_dataset.py5
2 files changed, 11 insertions, 2 deletions
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
index 60180e0..787946f 100644
--- a/data/custom_dataset_data_loader.py
+++ b/data/custom_dataset_data_loader.py
@@ -35,7 +35,13 @@ class CustomDatasetDataLoader(BaseDataLoader):
num_workers=int(opt.nThreads))
def load_data(self):
- return self.dataloader
+ return self
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ for i, data in enumerate(self.dataloader):
+ if i >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index c5e5460..ad0c11b 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -25,7 +25,10 @@ class UnalignedDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index % self.A_size]
index_A = index % self.A_size
- index_B = random.randint(0, self.B_size - 1)
+ if self.opt.serial_batches:
+ index_B = index % self.B_size
+ else:
+ index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
# print('(A, B) = (%d, %d)' % (index_A, index_B))
A_img = Image.open(A_path).convert('RGB')