diff options
Diffstat (limited to 'data')
| -rwxr-xr-x | data/__init__.py | 0 | ||||
| -rwxr-xr-x | data/aligned_dataset.py | 76 | ||||
| -rwxr-xr-x | data/base_data_loader.py | 14 | ||||
| -rwxr-xr-x | data/base_dataset.py | 92 | ||||
| -rwxr-xr-x | data/custom_dataset_data_loader.py | 31 | ||||
| -rwxr-xr-x | data/data_loader.py | 7 | ||||
| -rwxr-xr-x | data/image_folder.py | 68 |
7 files changed, 288 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py new file mode 100755 index 0000000..e69de29 --- /dev/null +++ b/data/__init__.py diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py new file mode 100755 index 0000000..50390f3 --- /dev/null +++ b/data/aligned_dataset.py @@ -0,0 +1,76 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os.path +import random +import torchvision.transforms as transforms +import torch +from data.base_dataset import BaseDataset, get_params, get_transform, normalize +from data.image_folder import make_dataset +from PIL import Image +import numpy as np + +class AlignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + + ### label maps + self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label') + self.label_paths = sorted(make_dataset(self.dir_label)) + + ### real images + if opt.isTrain: + self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img') + self.image_paths = sorted(make_dataset(self.dir_image)) + + ### instance maps + if not opt.no_instance: + self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst') + self.inst_paths = sorted(make_dataset(self.dir_inst)) + + ### load precomputed instance-wise encoded features + if opt.load_features: + self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat') + print('----------- loading features from %s ----------' % self.dir_feat) + self.feat_paths = sorted(make_dataset(self.dir_feat)) + + self.dataset_size = len(self.label_paths) + + def __getitem__(self, index): + ### label maps + label_path = self.label_paths[index] + label = Image.open(label_path) + params = get_params(self.opt, label.size) + transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) + label_tensor = transform_label(label) * 255.0 + + image_tensor = inst_tensor = feat_tensor = 0 + ### real images + if self.opt.isTrain: + image_path = self.image_paths[index] + image = Image.open(image_path).convert('RGB') + transform_image = get_transform(self.opt, params) + image_tensor = transform_image(image) + + ### if using instance maps + if not self.opt.no_instance: + inst_path = self.inst_paths[index] + inst = Image.open(inst_path) + inst_tensor = transform_label(inst) + + if self.opt.load_features: + feat_path = self.feat_paths[index] + feat = Image.open(feat_path).convert('RGB') + norm = normalize() + feat_tensor = norm(transform_label(feat)) + + input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor, + 'feat': feat_tensor, 'path': label_path} + + return input_dict + + def __len__(self): + return len(self.label_paths) + + def name(self): + return 'AlignedDataset'
\ No newline at end of file diff --git a/data/base_data_loader.py b/data/base_data_loader.py new file mode 100755 index 0000000..0e1deb5 --- /dev/null +++ b/data/base_data_loader.py @@ -0,0 +1,14 @@ + +class BaseDataLoader(): + def __init__(self): + pass + + def initialize(self, opt): + self.opt = opt + pass + + def load_data(): + return None + + + diff --git a/data/base_dataset.py b/data/base_dataset.py new file mode 100755 index 0000000..038d3d2 --- /dev/null +++ b/data/base_dataset.py @@ -0,0 +1,92 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +import numpy as np +import random + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + def initialize(self, opt): + pass + +def get_params(opt, size): + w, h = size + new_h = h + new_w = w + if opt.resize_or_crop == 'resize_and_crop': + new_h = new_w = opt.loadSize + elif opt.resize_or_crop == 'scale_width_and_crop': + new_w = opt.loadSize + new_h = opt.loadSize * h // w + + x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) + y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) + + flip = random.random() > 0.5 + return {'crop_pos': (x, y), 'flip': flip} + +def get_transform(opt, params, method=Image.BICUBIC, normalize=True): + transform_list = [] + if 'resize' in opt.resize_or_crop: + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Scale(osize, method)) + elif 'scale_width' in opt.resize_or_crop: + transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) + + if 'crop' in opt.resize_or_crop: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) + + if opt.resize_or_crop == 'none': + base = float(2 ** opt.n_downsample_global) + if opt.netG == 'local': + base *= (2 ** opt.n_local_enhancers) + transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) + + transform_list += [transforms.ToTensor()] + + if normalize: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def normalize(): + return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + +def __make_power_2(img, base, method=Image.BICUBIC): + ow, oh = img.size + h = int(round(oh / base) * base) + w = int(round(ow / base) * base) + if (h == oh) and (w == ow): + return img + return img.resize((w, h), method) + +def __scale_width(img, target_width, method=Image.BICUBIC): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), method) + +def __crop(img, pos, size): + ow, oh = img.size + x1, y1 = pos + tw = th = size + if (ow > tw or oh > th): + return img.crop((x1, y1, x1 + tw, y1 + th)) + return img + +def __flip(img, flip): + if flip: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py new file mode 100755 index 0000000..0b98254 --- /dev/null +++ b/data/custom_dataset_data_loader.py @@ -0,0 +1,31 @@ +import torch.utils.data +from data.base_data_loader import BaseDataLoader + + +def CreateDataset(opt): + dataset = None + from data.aligned_dataset import AlignedDataset + dataset = AlignedDataset() + + print("dataset [%s] was created" % (dataset.name())) + dataset.initialize(opt) + return dataset + +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = CreateDataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self.dataloader + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/data/data_loader.py b/data/data_loader.py new file mode 100755 index 0000000..2a4433a --- /dev/null +++ b/data/data_loader.py @@ -0,0 +1,7 @@ + +def CreateDataLoader(opt): + from data.custom_dataset_data_loader import CustomDatasetDataLoader + data_loader = CustomDatasetDataLoader() + print(data_loader.name()) + data_loader.initialize(opt) + return data_loader diff --git a/data/image_folder.py b/data/image_folder.py new file mode 100755 index 0000000..16a447c --- /dev/null +++ b/data/image_folder.py @@ -0,0 +1,68 @@ +############################################################################### +# Code from +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py +# Modified the original code so that it also loads images from the current +# directory as well as the subdirectories +############################################################################### + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) |
