summaryrefslogtreecommitdiff
path: root/data/unaligned_data_loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/unaligned_data_loader.py')
-rw-r--r--data/unaligned_data_loader.py32
1 files changed, 26 insertions, 6 deletions
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 95d9ac7..4f82dbe 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -3,12 +3,14 @@ import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
from builtins import object
-
+from pdb import set_trace as st
class PairedData(object):
def __init__(self, data_loader_A, data_loader_B):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
+ self.stop_A = False
+ self.stop_B = False
def __iter__(self):
self.data_loader_A_iter = iter(self.data_loader_A)
@@ -16,11 +18,29 @@ class PairedData(object):
return self
def __next__(self):
- A, A_paths = next(self.data_loader_A_iter)
- B, B_paths = next(self.data_loader_B_iter)
- return {'A': A, 'A_paths': A_paths,
- 'B': B, 'B_paths': B_paths}
+ A, A_paths = None, None
+ B, B_paths = None, None
+ try:
+ A, A_paths = next(self.data_loader_A_iter)
+ except StopIteration:
+ if A is None or A_paths is None:
+ self.stop_A = True
+ self.data_loader_A_iter = iter(self.data_loader_A)
+ A, A_paths = next(self.data_loader_A_iter)
+ try:
+ B, B_paths = next(self.data_loader_B_iter)
+
+ except StopIteration:
+ if B is None or B_paths is None:
+ self.stop_B = True
+ self.data_loader_B_iter = iter(self.data_loader_B)
+ B, B_paths = next(self.data_loader_B_iter)
+ if self.stop_A and self.stop_B:
+ raise StopIteration()
+ else:
+ return {'A': A, 'A_paths': A_paths,
+ 'B': B, 'B_paths': B_paths}
class UnalignedDataLoader(BaseDataLoader):
def initialize(self, opt):
@@ -60,4 +80,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset_A)
+ return max(len(self.dataset_A), len(self.dataset_B))