summaryrefslogtreecommitdiff
path: root/data/unaligned_data_loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/unaligned_data_loader.py')
-rw-r--r--data/unaligned_data_loader.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 4a06510..77f9274 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -2,21 +2,24 @@ import torch.utils.data
import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
+# pip install future --upgrade
from builtins import object
from pdb import set_trace as st
class PairedData(object):
- def __init__(self, data_loader_A, data_loader_B):
+ def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
self.stop_A = False
self.stop_B = False
+ self.max_dataset_size = max_dataset_size
def __iter__(self):
self.stop_A = False
self.stop_B = False
self.data_loader_A_iter = iter(self.data_loader_A)
self.data_loader_B_iter = iter(self.data_loader_B)
+ self.iter = 0
return self
def __next__(self):
@@ -29,20 +32,21 @@ class PairedData(object):
self.stop_A = True
self.data_loader_A_iter = iter(self.data_loader_A)
A, A_paths = next(self.data_loader_A_iter)
+
try:
B, B_paths = next(self.data_loader_B_iter)
-
except StopIteration:
if B is None or B_paths is None:
self.stop_B = True
self.data_loader_B_iter = iter(self.data_loader_B)
B, B_paths = next(self.data_loader_B_iter)
- if self.stop_A and self.stop_B:
+ if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
self.stop_A = False
self.stop_B = False
raise StopIteration()
else:
+ self.iter += 1
return {'A': A, 'A_paths': A_paths,
'B': B, 'B_paths': B_paths}
@@ -51,7 +55,7 @@ class UnalignedDataLoader(BaseDataLoader):
BaseDataLoader.initialize(self, opt)
transform = transforms.Compose([
transforms.Scale(opt.loadSize),
- transforms.CenterCrop(opt.fineSize),
+ transforms.RandomCrop(opt.fineSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))])
@@ -75,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset_A = dataset_A
self.dataset_B = dataset_B
- self.paired_data = PairedData(data_loader_A, data_loader_B)
+ self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size)
def name(self):
return 'UnalignedDataLoader'
@@ -84,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return max(len(self.dataset_A), len(self.dataset_B))
+ return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size)