summaryrefslogtreecommitdiff
path: root/data/aligned_data_loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/aligned_data_loader.py')
-rw-r--r--data/aligned_data_loader.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index bea3531..a1efde8 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -4,19 +4,26 @@ import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
from pdb import set_trace as st
+# pip install future --upgrade
from builtins import object
class PairedData(object):
- def __init__(self, data_loader, fineSize):
+ def __init__(self, data_loader, fineSize, max_dataset_size):
self.data_loader = data_loader
self.fineSize = fineSize
+ self.max_dataset_size = max_dataset_size
# st()
def __iter__(self):
self.data_loader_iter = iter(self.data_loader)
+ self.iter = 0
return self
def __next__(self):
+ self.iter += 1
+ if self.iter > self.max_dataset_size:
+ raise StopIteration
+
AB, AB_paths = next(self.data_loader_iter)
w_total = AB.size(3)
w = int(w_total / 2)
@@ -24,7 +31,6 @@ class PairedData(object):
w_offset = random.randint(0, max(0, w - self.fineSize - 1))
h_offset = random.randint(0, max(0, h - self.fineSize - 1))
-
A = AB[:, :, h_offset:h_offset + self.fineSize,
w_offset:w_offset + self.fineSize]
B = AB[:, :, h_offset:h_offset + self.fineSize,
@@ -39,8 +45,7 @@ class AlignedDataLoader(BaseDataLoader):
self.fineSize = opt.fineSize
transform = transforms.Compose([
# TODO: Scale
- #transforms.Scale((opt.loadSize * 2, opt.loadSize)),
- #transforms.CenterCrop(opt.fineSize),
+ transforms.Scale(opt.loadSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))])
@@ -55,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset = dataset
- self.paired_data = PairedData(data_loader, opt.fineSize)
+ self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size)
def name(self):
return 'AlignedDataLoader'
@@ -64,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset)
+ return min(len(self.dataset), self.opt.max_dataset_size)