diff options
| author | SsnL <tongzhou.wang.1994@gmail.com> | 2017-07-06 22:19:53 -0500 |
|---|---|---|
| committer | SsnL <tongzhou.wang.1994@gmail.com> | 2017-07-06 22:31:24 -0500 |
| commit | 25124b8389f80d7a509b2d98ef69589cab597c9a (patch) | |
| tree | 185d876bb0fed0e681f163e79ad810e597c8dd8c /data/base_dataset.py | |
| parent | ee0a8292e2b87449c325bdb9439f90f911a0c0a1 (diff) | |
resize_or_crop and better display single image
Diffstat (limited to 'data/base_dataset.py')
| -rw-r--r-- | data/base_dataset.py | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/data/base_dataset.py b/data/base_dataset.py index 49b9d98..a061a05 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -1,12 +1,45 @@ import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() - + def name(self): return 'BaseDataset' - + def initialize(self, opt): pass +def get_transform(opt): + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'crop': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'scale_width': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.fineSize))) + elif opt.resize_or_crop == 'scale_width_and_crop': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.loadSize))) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def __scale_width(img, target_width): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), Image.BICUBIC) |
