From c9f390499155d9047693dda1abd2de39efd161d3 Mon Sep 17 00:00:00 2001 From: junyanz Date: Sat, 24 Jun 2017 15:12:34 -0700 Subject: fix small issues in the unaligned_dataset --- data/unaligned_dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'data/unaligned_dataset.py') diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index 1f75b23..7333d16 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -19,6 +19,8 @@ class UnalignedDataset(BaseDataset): self.A_paths = sorted(self.A_paths) self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) transform_list = [] if opt.resize_or_crop == 'resize_and_crop': @@ -37,8 +39,8 @@ class UnalignedDataset(BaseDataset): self.transform = transforms.Compose(transform_list) def __getitem__(self, index): - A_path = self.A_paths[index] - B_path = self.B_paths[index] + A_path = self.A_paths[index % self.A_size] + B_path = self.B_paths[index % self.B_size] A_img = Image.open(A_path).convert('RGB') B_img = Image.open(B_path).convert('RGB') @@ -50,7 +52,7 @@ class UnalignedDataset(BaseDataset): 'A_paths': A_path, 'B_paths': B_path} def __len__(self): - return min(len(self.A_paths), len(self.B_paths)) + return max(self.A_size, self.B_size) def name(self): return 'UnalignedDataset' -- cgit v1.2.3-70-g09d2