summaryrefslogtreecommitdiff
path: root/data/unaligned_data_loader.py
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-04-27 01:28:53 -0700
committerTaesung Park <taesung_park@berkeley.edu>2017-04-27 01:28:53 -0700
commite5b2fd6d36b4297c4314478e88820cd10943d192 (patch)
treeaa6bf853e47a858b3e9d8a5ed051377f9b0b6415 /data/unaligned_data_loader.py
parentaf7420fc67f1a69349f8155bb85a3536314e377b (diff)
1. Added one_direction_test_model that generates the outputs in only one direction
2. Changed the option naming from ntrain to max_dataset_size
Diffstat (limited to 'data/unaligned_data_loader.py')
-rw-r--r--data/unaligned_data_loader.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 4a29c7c..77f9274 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -7,12 +7,12 @@ from builtins import object
from pdb import set_trace as st
class PairedData(object):
- def __init__(self, data_loader_A, data_loader_B, ntrain):
+ def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
self.stop_A = False
self.stop_B = False
- self.ntrain = ntrain
+ self.max_dataset_size = max_dataset_size
def __iter__(self):
self.stop_A = False
@@ -41,7 +41,7 @@ class PairedData(object):
self.data_loader_B_iter = iter(self.data_loader_B)
B, B_paths = next(self.data_loader_B_iter)
- if (self.stop_A and self.stop_B) or self.iter > self.ntrain:
+ if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
self.stop_A = False
self.stop_B = False
raise StopIteration()
@@ -79,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset_A = dataset_A
self.dataset_B = dataset_B
- self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.ntrain)
+ self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size)
def name(self):
return 'UnalignedDataLoader'
@@ -88,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.ntrain)
+ return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size)