diff options
Diffstat (limited to 'data/unaligned_data_loader.py')
| -rw-r--r-- | data/unaligned_data_loader.py | 32 |
1 files changed, 26 insertions, 6 deletions
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 95d9ac7..4f82dbe 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -3,12 +3,14 @@ import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader from data.image_folder import ImageFolder from builtins import object - +from pdb import set_trace as st class PairedData(object): def __init__(self, data_loader_A, data_loader_B): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B + self.stop_A = False + self.stop_B = False def __iter__(self): self.data_loader_A_iter = iter(self.data_loader_A) @@ -16,11 +18,29 @@ class PairedData(object): return self def __next__(self): - A, A_paths = next(self.data_loader_A_iter) - B, B_paths = next(self.data_loader_B_iter) - return {'A': A, 'A_paths': A_paths, - 'B': B, 'B_paths': B_paths} + A, A_paths = None, None + B, B_paths = None, None + try: + A, A_paths = next(self.data_loader_A_iter) + except StopIteration: + if A is None or A_paths is None: + self.stop_A = True + self.data_loader_A_iter = iter(self.data_loader_A) + A, A_paths = next(self.data_loader_A_iter) + try: + B, B_paths = next(self.data_loader_B_iter) + + except StopIteration: + if B is None or B_paths is None: + self.stop_B = True + self.data_loader_B_iter = iter(self.data_loader_B) + B, B_paths = next(self.data_loader_B_iter) + if self.stop_A and self.stop_B: + raise StopIteration() + else: + return {'A': A, 'A_paths': A_paths, + 'B': B, 'B_paths': B_paths} class UnalignedDataLoader(BaseDataLoader): def initialize(self, opt): @@ -60,4 +80,4 @@ class UnalignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return len(self.dataset_A) + return max(len(self.dataset_A), len(self.dataset_B)) |
