summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-04-26 15:50:54 -0700
committerTaesung Park <taesung_park@berkeley.edu>2017-04-26 15:50:54 -0700
commitaf7420fc67f1a69349f8155bb85a3536314e377b (patch)
tree00e1ffa058f86958e8e21e413ec470f9c4db6e01
parentd5a38496cf624cb110cb292d18e0822b826b61ab (diff)
Added support for the option ntrain
-rw-r--r--data/aligned_data_loader.py12
-rw-r--r--data/unaligned_data_loader.py13
2 files changed, 17 insertions, 8 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index 8171bc2..b7a228b 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -8,16 +8,22 @@ from pdb import set_trace as st
from builtins import object
class PairedData(object):
- def __init__(self, data_loader, fineSize):
+ def __init__(self, data_loader, fineSize, ntrain):
self.data_loader = data_loader
self.fineSize = fineSize
+ self.ntrain = ntrain
# st()
def __iter__(self):
self.data_loader_iter = iter(self.data_loader)
+ self.iter = 0
return self
def __next__(self):
+ self.iter += 1
+ if self.iter > self.ntrain:
+ raise StopIteration
+
AB, AB_paths = next(self.data_loader_iter)
w_total = AB.size(3)
w = int(w_total / 2)
@@ -54,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset = dataset
- self.paired_data = PairedData(data_loader, opt.fineSize)
+ self.paired_data = PairedData(data_loader, opt.fineSize, opt.ntrain)
def name(self):
return 'AlignedDataLoader'
@@ -63,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset)
+ return min(len(self.dataset), self.opt.ntrain)
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 6e926d3..4a29c7c 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -7,17 +7,19 @@ from builtins import object
from pdb import set_trace as st
class PairedData(object):
- def __init__(self, data_loader_A, data_loader_B):
+ def __init__(self, data_loader_A, data_loader_B, ntrain):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
self.stop_A = False
self.stop_B = False
+ self.ntrain = ntrain
def __iter__(self):
self.stop_A = False
self.stop_B = False
self.data_loader_A_iter = iter(self.data_loader_A)
self.data_loader_B_iter = iter(self.data_loader_B)
+ self.iter = 0
return self
def __next__(self):
@@ -30,20 +32,21 @@ class PairedData(object):
self.stop_A = True
self.data_loader_A_iter = iter(self.data_loader_A)
A, A_paths = next(self.data_loader_A_iter)
+
try:
B, B_paths = next(self.data_loader_B_iter)
-
except StopIteration:
if B is None or B_paths is None:
self.stop_B = True
self.data_loader_B_iter = iter(self.data_loader_B)
B, B_paths = next(self.data_loader_B_iter)
- if self.stop_A and self.stop_B:
+ if (self.stop_A and self.stop_B) or self.iter > self.ntrain:
self.stop_A = False
self.stop_B = False
raise StopIteration()
else:
+ self.iter += 1
return {'A': A, 'A_paths': A_paths,
'B': B, 'B_paths': B_paths}
@@ -76,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset_A = dataset_A
self.dataset_B = dataset_B
- self.paired_data = PairedData(data_loader_A, data_loader_B)
+ self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.ntrain)
def name(self):
return 'UnalignedDataLoader'
@@ -85,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return max(len(self.dataset_A), len(self.dataset_B))
+ return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.ntrain)