diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-06-24 15:12:34 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-06-24 15:12:34 -0700 |
| commit | c9f390499155d9047693dda1abd2de39efd161d3 (patch) | |
| tree | 2d8100910621abf2d9beda75a241411b872762a7 /data/unaligned_dataset.py | |
| parent | ff172d0799373b3d53ecdd2a0ca6a1776b1c1063 (diff) | |
fix small issues in the unaligned_dataset
Diffstat (limited to 'data/unaligned_dataset.py')
| -rw-r--r-- | data/unaligned_dataset.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index 1f75b23..7333d16 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -19,6 +19,8 @@ class UnalignedDataset(BaseDataset): self.A_paths = sorted(self.A_paths) self.B_paths = sorted(self.B_paths) + self.A_size = len(self.A_paths) + self.B_size = len(self.B_paths) transform_list = [] if opt.resize_or_crop == 'resize_and_crop': @@ -37,8 +39,8 @@ class UnalignedDataset(BaseDataset): self.transform = transforms.Compose(transform_list) def __getitem__(self, index): - A_path = self.A_paths[index] - B_path = self.B_paths[index] + A_path = self.A_paths[index % self.A_size] + B_path = self.B_paths[index % self.B_size] A_img = Image.open(A_path).convert('RGB') B_img = Image.open(B_path).convert('RGB') @@ -50,7 +52,7 @@ class UnalignedDataset(BaseDataset): 'A_paths': A_path, 'B_paths': B_path} def __len__(self): - return min(len(self.A_paths), len(self.B_paths)) + return max(self.A_size, self.B_size) def name(self): return 'UnalignedDataset' |
