summaryrefslogtreecommitdiff
path: root/data/base_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/base_dataset.py')
-rwxr-xr-xdata/base_dataset.py92
1 files changed, 92 insertions, 0 deletions
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100755
index 0000000..038d3d2
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,92 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+import numpy as np
+import random
+
+class BaseDataset(data.Dataset):
+ def __init__(self):
+ super(BaseDataset, self).__init__()
+
+ def name(self):
+ return 'BaseDataset'
+
+ def initialize(self, opt):
+ pass
+
+def get_params(opt, size):
+ w, h = size
+ new_h = h
+ new_w = w
+ if opt.resize_or_crop == 'resize_and_crop':
+ new_h = new_w = opt.loadSize
+ elif opt.resize_or_crop == 'scale_width_and_crop':
+ new_w = opt.loadSize
+ new_h = opt.loadSize * h // w
+
+ x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
+ y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
+
+ flip = random.random() > 0.5
+ return {'crop_pos': (x, y), 'flip': flip}
+
+def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
+ transform_list = []
+ if 'resize' in opt.resize_or_crop:
+ osize = [opt.loadSize, opt.loadSize]
+ transform_list.append(transforms.Scale(osize, method))
+ elif 'scale_width' in opt.resize_or_crop:
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
+
+ if 'crop' in opt.resize_or_crop:
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
+
+ if opt.resize_or_crop == 'none':
+ base = float(2 ** opt.n_downsample_global)
+ if opt.netG == 'local':
+ base *= (2 ** opt.n_local_enhancers)
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+ transform_list += [transforms.ToTensor()]
+
+ if normalize:
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ return transforms.Compose(transform_list)
+
+def normalize():
+ return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+
+def __make_power_2(img, base, method=Image.BICUBIC):
+ ow, oh = img.size
+ h = int(round(oh / base) * base)
+ w = int(round(ow / base) * base)
+ if (h == oh) and (w == ow):
+ return img
+ return img.resize((w, h), method)
+
+def __scale_width(img, target_width, method=Image.BICUBIC):
+ ow, oh = img.size
+ if (ow == target_width):
+ return img
+ w = target_width
+ h = int(target_width * oh / ow)
+ return img.resize((w, h), method)
+
+def __crop(img, pos, size):
+ ow, oh = img.size
+ x1, y1 = pos
+ tw = th = size
+ if (ow > tw or oh > th):
+ return img.crop((x1, y1, x1 + tw, y1 + th))
+ return img
+
+def __flip(img, flip):
+ if flip:
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
+ return img