summaryrefslogtreecommitdiff
path: root/data/aligned_dataset.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-06-12 23:52:56 -0700
committerjunyanz <junyanz@berkeley.edu>2017-06-12 23:52:56 -0700
commite6858e35f0a08c6139c133122d222d0d85e8005d (patch)
tree2647ff13a164c684113eab455123394a49a65dad /data/aligned_dataset.py
parent3b72a659c38141e502b74bee65ca08d51dc3eabf (diff)
update dataset mode
Diffstat (limited to 'data/aligned_dataset.py')
-rw-r--r--data/aligned_dataset.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
new file mode 100644
index 0000000..0f45c40
--- /dev/null
+++ b/data/aligned_dataset.py
@@ -0,0 +1,56 @@
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class AlignedDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_AB = os.path.join(opt.dataroot, opt.phase)
+
+ self.AB_paths = sorted(make_dataset(self.dir_AB))
+
+ assert(opt.resize_or_crop == 'resize_and_crop')
+
+ transform_list = [transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+
+ self.transform = transforms.Compose(transform_list)
+
+ def __getitem__(self, index):
+ AB_path = self.AB_paths[index]
+ AB = Image.open(AB_path).convert('RGB')
+ AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
+ AB = self.transform(AB)
+
+ w_total = AB.size(2)
+ w = int(w_total / 2)
+ h = AB.size(1)
+ w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+ h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+ A = AB[:, h_offset:h_offset + self.opt.fineSize,
+ w_offset:w_offset + self.opt.fineSize]
+ B = AB[:, h_offset:h_offset + self.opt.fineSize,
+ w + w_offset:w + w_offset + self.opt.fineSize]
+
+ if (not self.opt.no_flip) and random.random() < 0.5:
+ idx = [i for i in range(A.size(2) - 1, -1, -1)]
+ idx = torch.LongTensor(idx)
+ A = A.index_select(2, idx)
+ B = B.index_select(2, idx)
+
+ return {'A': A, 'B': B,
+ 'A_paths': AB_path, 'B_paths': AB_path}
+
+ def __len__(self):
+ return len(self.AB_paths)
+
+ def name(self):
+ return 'AlignedDataset'