summaryrefslogtreecommitdiff
path: root/data/single_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/single_dataset.py')
-rw-r--r--data/single_dataset.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/data/single_dataset.py b/data/single_dataset.py
new file mode 100644
index 0000000..106bea3
--- /dev/null
+++ b/data/single_dataset.py
@@ -0,0 +1,47 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class SingleDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_A = os.path.join(opt.dataroot)
+
+ self.A_paths = make_dataset(self.dir_A)
+
+ self.A_paths = sorted(self.A_paths)
+
+ transform_list = []
+ if opt.resize_or_crop == 'resize_and_crop':
+ transform_list.append(transforms.Scale(opt.loadSize))
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.RandomHorizontalFlip())
+
+ if opt.resize_or_crop != 'no_resize':
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+ transform_list += [transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+
+ self.transform = transforms.Compose(transform_list)
+
+ def __getitem__(self, index):
+ A_path = self.A_paths[index]
+
+ A_img = Image.open(A_path).convert('RGB')
+
+ A_img = self.transform(A_img)
+
+ return {'A': A_img, 'A_paths': A_path}
+
+ def __len__(self):
+ return len(self.A_paths)
+
+ def name(self):
+ return 'SingleImageDataset'