summaryrefslogtreecommitdiff
path: root/data/aligned_data_loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/aligned_data_loader.py')
-rw-r--r--data/aligned_data_loader.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index 8171bc2..b7a228b 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -8,16 +8,22 @@ from pdb import set_trace as st
from builtins import object
class PairedData(object):
- def __init__(self, data_loader, fineSize):
+ def __init__(self, data_loader, fineSize, ntrain):
self.data_loader = data_loader
self.fineSize = fineSize
+ self.ntrain = ntrain
# st()
def __iter__(self):
self.data_loader_iter = iter(self.data_loader)
+ self.iter = 0
return self
def __next__(self):
+ self.iter += 1
+ if self.iter > self.ntrain:
+ raise StopIteration
+
AB, AB_paths = next(self.data_loader_iter)
w_total = AB.size(3)
w = int(w_total / 2)
@@ -54,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset = dataset
- self.paired_data = PairedData(data_loader, opt.fineSize)
+ self.paired_data = PairedData(data_loader, opt.fineSize, opt.ntrain)
def name(self):
return 'AlignedDataLoader'
@@ -63,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset)
+ return min(len(self.dataset), self.opt.ntrain)