summaryrefslogtreecommitdiff
path: root/data/unaligned_dataset.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-06-24 15:12:34 -0700
committerjunyanz <junyanz@berkeley.edu>2017-06-24 15:12:34 -0700
commitc9f390499155d9047693dda1abd2de39efd161d3 (patch)
tree2d8100910621abf2d9beda75a241411b872762a7 /data/unaligned_dataset.py
parentff172d0799373b3d53ecdd2a0ca6a1776b1c1063 (diff)
fix small issues in the unaligned_dataset
Diffstat (limited to 'data/unaligned_dataset.py')
-rw-r--r--data/unaligned_dataset.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index 1f75b23..7333d16 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -19,6 +19,8 @@ class UnalignedDataset(BaseDataset):
self.A_paths = sorted(self.A_paths)
self.B_paths = sorted(self.B_paths)
+ self.A_size = len(self.A_paths)
+ self.B_size = len(self.B_paths)
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
@@ -37,8 +39,8 @@ class UnalignedDataset(BaseDataset):
self.transform = transforms.Compose(transform_list)
def __getitem__(self, index):
- A_path = self.A_paths[index]
- B_path = self.B_paths[index]
+ A_path = self.A_paths[index % self.A_size]
+ B_path = self.B_paths[index % self.B_size]
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
@@ -50,7 +52,7 @@ class UnalignedDataset(BaseDataset):
'A_paths': A_path, 'B_paths': B_path}
def __len__(self):
- return min(len(self.A_paths), len(self.B_paths))
+ return max(self.A_size, self.B_size)
def name(self):
return 'UnalignedDataset'