diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-06-12 23:52:56 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-06-12 23:52:56 -0700 |
| commit | e6858e35f0a08c6139c133122d222d0d85e8005d (patch) | |
| tree | 2647ff13a164c684113eab455123394a49a65dad /data/unaligned_dataset.py | |
| parent | 3b72a659c38141e502b74bee65ca08d51dc3eabf (diff) | |
update dataset mode
Diffstat (limited to 'data/unaligned_dataset.py')
| -rw-r--r-- | data/unaligned_dataset.py | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py new file mode 100644 index 0000000..1f75b23 --- /dev/null +++ b/data/unaligned_dataset.py @@ -0,0 +1,56 @@ +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset +from data.image_folder import make_dataset +from PIL import Image +import PIL +from pdb import set_trace as st + + +class UnalignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + self.A_paths = make_dataset(self.dir_A) + self.B_paths = make_dataset(self.dir_B) + + self.A_paths = sorted(self.A_paths) + self.B_paths = sorted(self.B_paths) + + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + if opt.resize_or_crop != 'no_resize': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + self.transform = transforms.Compose(transform_list) + + def __getitem__(self, index): + A_path = self.A_paths[index] + B_path = self.B_paths[index] + + A_img = Image.open(A_path).convert('RGB') + B_img = Image.open(B_path).convert('RGB') + + A_img = self.transform(A_img) + B_img = self.transform(B_img) + + return {'A': A_img, 'B': B_img, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return min(len(self.A_paths), len(self.B_paths)) + + def name(self): + return 'UnalignedDataset' |
