summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/aligned_data_loader.py4
-rw-r--r--data/unaligned_data_loader.py32
2 files changed, 27 insertions, 9 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index 01dbf89..bea3531 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -17,9 +17,7 @@ class PairedData(object):
return self
def __next__(self):
- # st()
AB, AB_paths = next(self.data_loader_iter)
- # st()
w_total = AB.size(3)
w = int(w_total / 2)
h = AB.size(2)
@@ -55,7 +53,7 @@ class AlignedDataLoader(BaseDataLoader):
batch_size=self.opt.batchSize,
shuffle=not self.opt.serial_batches,
num_workers=int(self.opt.nThreads))
-
+
self.dataset = dataset
self.paired_data = PairedData(data_loader, opt.fineSize)
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 95d9ac7..4f82dbe 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -3,12 +3,14 @@ import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
from builtins import object
-
+from pdb import set_trace as st
class PairedData(object):
def __init__(self, data_loader_A, data_loader_B):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
+ self.stop_A = False
+ self.stop_B = False
def __iter__(self):
self.data_loader_A_iter = iter(self.data_loader_A)
@@ -16,11 +18,29 @@ class PairedData(object):
return self
def __next__(self):
- A, A_paths = next(self.data_loader_A_iter)
- B, B_paths = next(self.data_loader_B_iter)
- return {'A': A, 'A_paths': A_paths,
- 'B': B, 'B_paths': B_paths}
+ A, A_paths = None, None
+ B, B_paths = None, None
+ try:
+ A, A_paths = next(self.data_loader_A_iter)
+ except StopIteration:
+ if A is None or A_paths is None:
+ self.stop_A = True
+ self.data_loader_A_iter = iter(self.data_loader_A)
+ A, A_paths = next(self.data_loader_A_iter)
+ try:
+ B, B_paths = next(self.data_loader_B_iter)
+
+ except StopIteration:
+ if B is None or B_paths is None:
+ self.stop_B = True
+ self.data_loader_B_iter = iter(self.data_loader_B)
+ B, B_paths = next(self.data_loader_B_iter)
+ if self.stop_A and self.stop_B:
+ raise StopIteration()
+ else:
+ return {'A': A, 'A_paths': A_paths,
+ 'B': B, 'B_paths': B_paths}
class UnalignedDataLoader(BaseDataLoader):
def initialize(self, opt):
@@ -60,4 +80,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset_A)
+ return max(len(self.dataset_A), len(self.dataset_B))