summaryrefslogtreecommitdiff
path: root/data/aligned_data_loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/aligned_data_loader.py')
-rw-r--r--data/aligned_data_loader.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index b7a228b..a1efde8 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -8,10 +8,10 @@ from pdb import set_trace as st
from builtins import object
class PairedData(object):
- def __init__(self, data_loader, fineSize, ntrain):
+ def __init__(self, data_loader, fineSize, max_dataset_size):
self.data_loader = data_loader
self.fineSize = fineSize
- self.ntrain = ntrain
+ self.max_dataset_size = max_dataset_size
# st()
def __iter__(self):
@@ -21,7 +21,7 @@ class PairedData(object):
def __next__(self):
self.iter += 1
- if self.iter > self.ntrain:
+ if self.iter > self.max_dataset_size:
raise StopIteration
AB, AB_paths = next(self.data_loader_iter)
@@ -60,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset = dataset
- self.paired_data = PairedData(data_loader, opt.fineSize, opt.ntrain)
+ self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size)
def name(self):
return 'AlignedDataLoader'
@@ -69,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return min(len(self.dataset), self.opt.ntrain)
+ return min(len(self.dataset), self.opt.max_dataset_size)