From e5b2fd6d36b4297c4314478e88820cd10943d192 Mon Sep 17 00:00:00 2001 From: Taesung Park Date: Thu, 27 Apr 2017 01:28:53 -0700 Subject: 1. Added one_direction_test_model that generates the outputs in only one direction 2. Changed the option naming from ntrain to max_dataset_size --- data/aligned_data_loader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'data/aligned_data_loader.py') diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py index b7a228b..a1efde8 100644 --- a/data/aligned_data_loader.py +++ b/data/aligned_data_loader.py @@ -8,10 +8,10 @@ from pdb import set_trace as st from builtins import object class PairedData(object): - def __init__(self, data_loader, fineSize, ntrain): + def __init__(self, data_loader, fineSize, max_dataset_size): self.data_loader = data_loader self.fineSize = fineSize - self.ntrain = ntrain + self.max_dataset_size = max_dataset_size # st() def __iter__(self): @@ -21,7 +21,7 @@ class PairedData(object): def __next__(self): self.iter += 1 - if self.iter > self.ntrain: + if self.iter > self.max_dataset_size: raise StopIteration AB, AB_paths = next(self.data_loader_iter) @@ -60,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset = dataset - self.paired_data = PairedData(data_loader, opt.fineSize, opt.ntrain) + self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size) def name(self): return 'AlignedDataLoader' @@ -69,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return min(len(self.dataset), self.opt.ntrain) + return min(len(self.dataset), self.opt.max_dataset_size) -- cgit v1.2.3-70-g09d2