From af7420fc67f1a69349f8155bb85a3536314e377b Mon Sep 17 00:00:00 2001 From: Taesung Park Date: Wed, 26 Apr 2017 15:50:54 -0700 Subject: Added support for the option ntrain --- data/unaligned_data_loader.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'data/unaligned_data_loader.py') diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 6e926d3..4a29c7c 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -7,17 +7,19 @@ from builtins import object from pdb import set_trace as st class PairedData(object): - def __init__(self, data_loader_A, data_loader_B): + def __init__(self, data_loader_A, data_loader_B, ntrain): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B self.stop_A = False self.stop_B = False + self.ntrain = ntrain def __iter__(self): self.stop_A = False self.stop_B = False self.data_loader_A_iter = iter(self.data_loader_A) self.data_loader_B_iter = iter(self.data_loader_B) + self.iter = 0 return self def __next__(self): @@ -30,20 +32,21 @@ class PairedData(object): self.stop_A = True self.data_loader_A_iter = iter(self.data_loader_A) A, A_paths = next(self.data_loader_A_iter) + try: B, B_paths = next(self.data_loader_B_iter) - except StopIteration: if B is None or B_paths is None: self.stop_B = True self.data_loader_B_iter = iter(self.data_loader_B) B, B_paths = next(self.data_loader_B_iter) - if self.stop_A and self.stop_B: + if (self.stop_A and self.stop_B) or self.iter > self.ntrain: self.stop_A = False self.stop_B = False raise StopIteration() else: + self.iter += 1 return {'A': A, 'A_paths': A_paths, 'B': B, 'B_paths': B_paths} @@ -76,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset_A = dataset_A self.dataset_B = dataset_B - self.paired_data = PairedData(data_loader_A, data_loader_B) + self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.ntrain) def name(self): return 'UnalignedDataLoader' @@ -85,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return max(len(self.dataset_A), len(self.dataset_B)) + return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.ntrain) -- cgit v1.2.3-70-g09d2