diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-04-27 01:28:53 -0700 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-04-27 01:28:53 -0700 |
| commit | e5b2fd6d36b4297c4314478e88820cd10943d192 (patch) | |
| tree | aa6bf853e47a858b3e9d8a5ed051377f9b0b6415 /data/unaligned_data_loader.py | |
| parent | af7420fc67f1a69349f8155bb85a3536314e377b (diff) | |
1. Added one_direction_test_model that generates the outputs in only one direction
2. Changed the option naming from ntrain to max_dataset_size
Diffstat (limited to 'data/unaligned_data_loader.py')
| -rw-r--r-- | data/unaligned_data_loader.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 4a29c7c..77f9274 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -7,12 +7,12 @@ from builtins import object from pdb import set_trace as st class PairedData(object): - def __init__(self, data_loader_A, data_loader_B, ntrain): + def __init__(self, data_loader_A, data_loader_B, max_dataset_size): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B self.stop_A = False self.stop_B = False - self.ntrain = ntrain + self.max_dataset_size = max_dataset_size def __iter__(self): self.stop_A = False @@ -41,7 +41,7 @@ class PairedData(object): self.data_loader_B_iter = iter(self.data_loader_B) B, B_paths = next(self.data_loader_B_iter) - if (self.stop_A and self.stop_B) or self.iter > self.ntrain: + if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size: self.stop_A = False self.stop_B = False raise StopIteration() @@ -79,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset_A = dataset_A self.dataset_B = dataset_B - self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.ntrain) + self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size) def name(self): return 'UnalignedDataLoader' @@ -88,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.ntrain) + return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size) |
