summaryrefslogtreecommitdiff
path: root/data/unaligned_data_loader.py
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-04-20 03:39:31 -0700
committerTaesung Park <taesung_park@berkeley.edu>2017-04-20 03:39:31 -0700
commitd4206aad119326c57fefeb97176f8fbda6cd8d1e (patch)
treeab5aed69e069c1d446918f13c28049d156c6be90 /data/unaligned_data_loader.py
parent03d01ea7723015b29aac078daa2d2797e042923a (diff)
parent443bc13554769d6a18eefdbac779cf385fb6dbb3 (diff)
merged conflicts
Diffstat (limited to 'data/unaligned_data_loader.py')
-rw-r--r--data/unaligned_data_loader.py32
1 files changed, 26 insertions, 6 deletions
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 95d9ac7..4f82dbe 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -3,12 +3,14 @@ import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
from builtins import object
-
+from pdb import set_trace as st
class PairedData(object):
def __init__(self, data_loader_A, data_loader_B):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
+ self.stop_A = False
+ self.stop_B = False
def __iter__(self):
self.data_loader_A_iter = iter(self.data_loader_A)
@@ -16,11 +18,29 @@ class PairedData(object):
return self
def __next__(self):
- A, A_paths = next(self.data_loader_A_iter)
- B, B_paths = next(self.data_loader_B_iter)
- return {'A': A, 'A_paths': A_paths,
- 'B': B, 'B_paths': B_paths}
+ A, A_paths = None, None
+ B, B_paths = None, None
+ try:
+ A, A_paths = next(self.data_loader_A_iter)
+ except StopIteration:
+ if A is None or A_paths is None:
+ self.stop_A = True
+ self.data_loader_A_iter = iter(self.data_loader_A)
+ A, A_paths = next(self.data_loader_A_iter)
+ try:
+ B, B_paths = next(self.data_loader_B_iter)
+
+ except StopIteration:
+ if B is None or B_paths is None:
+ self.stop_B = True
+ self.data_loader_B_iter = iter(self.data_loader_B)
+ B, B_paths = next(self.data_loader_B_iter)
+ if self.stop_A and self.stop_B:
+ raise StopIteration()
+ else:
+ return {'A': A, 'A_paths': A_paths,
+ 'B': B, 'B_paths': B_paths}
class UnalignedDataLoader(BaseDataLoader):
def initialize(self, opt):
@@ -60,4 +80,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset_A)
+ return max(len(self.dataset_A), len(self.dataset_B))