diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-04-26 15:50:54 -0700 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-04-26 15:50:54 -0700 |
| commit | af7420fc67f1a69349f8155bb85a3536314e377b (patch) | |
| tree | 00e1ffa058f86958e8e21e413ec470f9c4db6e01 /data/aligned_data_loader.py | |
| parent | d5a38496cf624cb110cb292d18e0822b826b61ab (diff) | |
Added support for the option ntrain
Diffstat (limited to 'data/aligned_data_loader.py')
| -rw-r--r-- | data/aligned_data_loader.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py index 8171bc2..b7a228b 100644 --- a/data/aligned_data_loader.py +++ b/data/aligned_data_loader.py @@ -8,16 +8,22 @@ from pdb import set_trace as st from builtins import object class PairedData(object): - def __init__(self, data_loader, fineSize): + def __init__(self, data_loader, fineSize, ntrain): self.data_loader = data_loader self.fineSize = fineSize + self.ntrain = ntrain # st() def __iter__(self): self.data_loader_iter = iter(self.data_loader) + self.iter = 0 return self def __next__(self): + self.iter += 1 + if self.iter > self.ntrain: + raise StopIteration + AB, AB_paths = next(self.data_loader_iter) w_total = AB.size(3) w = int(w_total / 2) @@ -54,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset = dataset - self.paired_data = PairedData(data_loader, opt.fineSize) + self.paired_data = PairedData(data_loader, opt.fineSize, opt.ntrain) def name(self): return 'AlignedDataLoader' @@ -63,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return len(self.dataset) + return min(len(self.dataset), self.opt.ntrain) |
