summaryrefslogtreecommitdiff
path: root/data/unaligned_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/unaligned_dataset.py')
-rw-r--r--data/unaligned_dataset.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
new file mode 100644
index 0000000..1f75b23
--- /dev/null
+++ b/data/unaligned_dataset.py
@@ -0,0 +1,56 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+import PIL
+from pdb import set_trace as st
+
+
+class UnalignedDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+ self.A_paths = make_dataset(self.dir_A)
+ self.B_paths = make_dataset(self.dir_B)
+
+ self.A_paths = sorted(self.A_paths)
+ self.B_paths = sorted(self.B_paths)
+
+ transform_list = []
+ if opt.resize_or_crop == 'resize_and_crop':
+ osize = [opt.loadSize, opt.loadSize]
+ transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.RandomHorizontalFlip())
+
+ if opt.resize_or_crop != 'no_resize':
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+ transform_list += [transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ self.transform = transforms.Compose(transform_list)
+
+ def __getitem__(self, index):
+ A_path = self.A_paths[index]
+ B_path = self.B_paths[index]
+
+ A_img = Image.open(A_path).convert('RGB')
+ B_img = Image.open(B_path).convert('RGB')
+
+ A_img = self.transform(A_img)
+ B_img = self.transform(B_img)
+
+ return {'A': A_img, 'B': B_img,
+ 'A_paths': A_path, 'B_paths': B_path}
+
+ def __len__(self):
+ return min(len(self.A_paths), len(self.B_paths))
+
+ def name(self):
+ return 'UnalignedDataset'