diff options
Diffstat (limited to 'data/unaligned_dataset.py')
| -rw-r--r-- | data/unaligned_dataset.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index 1f75b23..7333d16 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -19,6 +19,8 @@ class UnalignedDataset(BaseDataset): self.A_paths = sorted(self.A_paths) self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) transform_list = [] if opt.resize_or_crop == 'resize_and_crop': @@ -37,8 +39,8 @@ class UnalignedDataset(BaseDataset): self.transform = transforms.Compose(transform_list) def __getitem__(self, index): - A_path = self.A_paths[index] - B_path = self.B_paths[index] + A_path = self.A_paths[index % self.A_size] + B_path = self.B_paths[index % self.B_size] A_img = Image.open(A_path).convert('RGB') B_img = Image.open(B_path).convert('RGB') @@ -50,7 +52,7 @@ class UnalignedDataset(BaseDataset): 'A_paths': A_path, 'B_paths': B_path} def __len__(self): - return min(len(self.A_paths), len(self.B_paths)) + return max(self.A_size, self.B_size) def name(self): return 'UnalignedDataset' |
