diff options
Diffstat (limited to 'data/base_dataset.py')
| -rwxr-xr-x | data/base_dataset.py | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/data/base_dataset.py b/data/base_dataset.py new file mode 100755 index 0000000..038d3d2 --- /dev/null +++ b/data/base_dataset.py @@ -0,0 +1,92 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +import numpy as np +import random + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + def initialize(self, opt): + pass + +def get_params(opt, size): + w, h = size + new_h = h + new_w = w + if opt.resize_or_crop == 'resize_and_crop': + new_h = new_w = opt.loadSize + elif opt.resize_or_crop == 'scale_width_and_crop': + new_w = opt.loadSize + new_h = opt.loadSize * h // w + + x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) + y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) + + flip = random.random() > 0.5 + return {'crop_pos': (x, y), 'flip': flip} + +def get_transform(opt, params, method=Image.BICUBIC, normalize=True): + transform_list = [] + if 'resize' in opt.resize_or_crop: + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Scale(osize, method)) + elif 'scale_width' in opt.resize_or_crop: + transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) + + if 'crop' in opt.resize_or_crop: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) + + if opt.resize_or_crop == 'none': + base = float(2 ** opt.n_downsample_global) + if opt.netG == 'local': + base *= (2 ** opt.n_local_enhancers) + transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) + + transform_list += [transforms.ToTensor()] + + if normalize: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def normalize(): + return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + +def __make_power_2(img, base, method=Image.BICUBIC): + ow, oh = img.size + h = int(round(oh / base) * base) + w = int(round(ow / base) * base) + if (h == oh) and (w == ow): + return img + return img.resize((w, h), method) + +def __scale_width(img, target_width, method=Image.BICUBIC): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), method) + +def __crop(img, pos, size): + ow, oh = img.size + x1, y1 = pos + tw = th = size + if (ow > tw or oh > th): + return img.crop((x1, y1, x1 + tw, y1 + th)) + return img + +def __flip(img, flip): + if flip: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img |
