diff options
| author | SsnL <tongzhou.wang.1994@gmail.com> | 2017-07-06 22:19:53 -0500 |
|---|---|---|
| committer | SsnL <tongzhou.wang.1994@gmail.com> | 2017-07-06 22:31:24 -0500 |
| commit | 25124b8389f80d7a509b2d98ef69589cab597c9a (patch) | |
| tree | 185d876bb0fed0e681f163e79ad810e597c8dd8c /data | |
| parent | ee0a8292e2b87449c325bdb9439f90f911a0c0a1 (diff) | |
resize_or_crop and better display single image
Diffstat (limited to 'data')
| -rw-r--r-- | data/base_dataset.py | 37 | ||||
| -rw-r--r-- | data/single_dataset.py | 18 | ||||
| -rw-r--r-- | data/unaligned_dataset.py | 19 |
3 files changed, 39 insertions, 35 deletions
diff --git a/data/base_dataset.py b/data/base_dataset.py index 49b9d98..a061a05 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -1,12 +1,45 @@ import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() - + def name(self): return 'BaseDataset' - + def initialize(self, opt): pass +def get_transform(opt): + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'crop': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'scale_width': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.fineSize))) + elif opt.resize_or_crop == 'scale_width_and_crop': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.loadSize))) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def __scale_width(img, target_width): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), Image.BICUBIC) diff --git a/data/single_dataset.py b/data/single_dataset.py index 106bea3..faf416a 100644 --- a/data/single_dataset.py +++ b/data/single_dataset.py @@ -1,6 +1,6 @@ import os.path import torchvision.transforms as transforms -from data.base_dataset import BaseDataset +from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image @@ -15,21 +15,7 @@ class SingleDataset(BaseDataset): self.A_paths = sorted(self.A_paths) - transform_list = [] - if opt.resize_or_crop == 'resize_and_crop': - transform_list.append(transforms.Scale(opt.loadSize)) - - if opt.isTrain and not opt.no_flip: - transform_list.append(transforms.RandomHorizontalFlip()) - - if opt.resize_or_crop != 'no_resize': - transform_list.append(transforms.RandomCrop(opt.fineSize)) - - transform_list += [transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - - self.transform = transforms.Compose(transform_list) + self.transform = get_transform(opt) def __getitem__(self, index): A_path = self.A_paths[index] diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index 7333d16..3864bf3 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -1,6 +1,6 @@ import os.path import torchvision.transforms as transforms -from data.base_dataset import BaseDataset +from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image import PIL @@ -21,22 +21,7 @@ class UnalignedDataset(BaseDataset): self.B_paths = sorted(self.B_paths) self.A_size = len(self.A_paths) self.B_size = len(self.B_paths) - - transform_list = [] - if opt.resize_or_crop == 'resize_and_crop': - osize = [opt.loadSize, opt.loadSize] - transform_list.append(transforms.Scale(osize, Image.BICUBIC)) - - if opt.isTrain and not opt.no_flip: - transform_list.append(transforms.RandomHorizontalFlip()) - - if opt.resize_or_crop != 'no_resize': - transform_list.append(transforms.RandomCrop(opt.fineSize)) - - transform_list += [transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - self.transform = transforms.Compose(transform_list) + self.transform = get_transform(opt) def __getitem__(self, index): A_path = self.A_paths[index % self.A_size] |
