diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/aligned_data_loader.py | 4 | ||||
| -rw-r--r-- | data/unaligned_data_loader.py | 32 |
2 files changed, 27 insertions, 9 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py index 01dbf89..bea3531 100644 --- a/data/aligned_data_loader.py +++ b/data/aligned_data_loader.py @@ -17,9 +17,7 @@ class PairedData(object): return self def __next__(self): - # st() AB, AB_paths = next(self.data_loader_iter) - # st() w_total = AB.size(3) w = int(w_total / 2) h = AB.size(2) @@ -55,7 +53,7 @@ class AlignedDataLoader(BaseDataLoader): batch_size=self.opt.batchSize, shuffle=not self.opt.serial_batches, num_workers=int(self.opt.nThreads)) - + self.dataset = dataset self.paired_data = PairedData(data_loader, opt.fineSize) diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 95d9ac7..4f82dbe 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -3,12 +3,14 @@ import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader from data.image_folder import ImageFolder from builtins import object - +from pdb import set_trace as st class PairedData(object): def __init__(self, data_loader_A, data_loader_B): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B + self.stop_A = False + self.stop_B = False def __iter__(self): self.data_loader_A_iter = iter(self.data_loader_A) @@ -16,11 +18,29 @@ class PairedData(object): return self def __next__(self): - A, A_paths = next(self.data_loader_A_iter) - B, B_paths = next(self.data_loader_B_iter) - return {'A': A, 'A_paths': A_paths, - 'B': B, 'B_paths': B_paths} + A, A_paths = None, None + B, B_paths = None, None + try: + A, A_paths = next(self.data_loader_A_iter) + except StopIteration: + if A is None or A_paths is None: + self.stop_A = True + self.data_loader_A_iter = iter(self.data_loader_A) + A, A_paths = next(self.data_loader_A_iter) + try: + B, B_paths = next(self.data_loader_B_iter) + + except StopIteration: + if B is None or B_paths is None: + self.stop_B = True + self.data_loader_B_iter = iter(self.data_loader_B) + B, B_paths = next(self.data_loader_B_iter) + if self.stop_A and self.stop_B: + raise StopIteration() + else: + return {'A': A, 'A_paths': A_paths, + 'B': B, 'B_paths': B_paths} class UnalignedDataLoader(BaseDataLoader): def initialize(self, opt): @@ -60,4 +80,4 @@ class UnalignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return len(self.dataset_A) + return max(len(self.dataset_A), len(self.dataset_B)) |
