summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rwxr-xr-xdata/__init__.py0
-rwxr-xr-xdata/aligned_dataset.py76
-rwxr-xr-xdata/base_data_loader.py14
-rwxr-xr-xdata/base_dataset.py92
-rwxr-xr-xdata/custom_dataset_data_loader.py31
-rwxr-xr-xdata/data_loader.py7
-rwxr-xr-xdata/image_folder.py68
7 files changed, 288 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/data/__init__.py
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
new file mode 100755
index 0000000..50390f3
--- /dev/null
+++ b/data/aligned_dataset.py
@@ -0,0 +1,76 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset, get_params, get_transform, normalize
+from data.image_folder import make_dataset
+from PIL import Image
+import numpy as np
+
+class AlignedDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+
+ ### label maps
+ self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
+ self.label_paths = sorted(make_dataset(self.dir_label))
+
+ ### real images
+ if opt.isTrain:
+ self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
+ self.image_paths = sorted(make_dataset(self.dir_image))
+
+ ### instance maps
+ if not opt.no_instance:
+ self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
+ self.inst_paths = sorted(make_dataset(self.dir_inst))
+
+ ### load precomputed instance-wise encoded features
+ if opt.load_features:
+ self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
+ print('----------- loading features from %s ----------' % self.dir_feat)
+ self.feat_paths = sorted(make_dataset(self.dir_feat))
+
+ self.dataset_size = len(self.label_paths)
+
+ def __getitem__(self, index):
+ ### label maps
+ label_path = self.label_paths[index]
+ label = Image.open(label_path)
+ params = get_params(self.opt, label.size)
+ transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+ label_tensor = transform_label(label) * 255.0
+
+ image_tensor = inst_tensor = feat_tensor = 0
+ ### real images
+ if self.opt.isTrain:
+ image_path = self.image_paths[index]
+ image = Image.open(image_path).convert('RGB')
+ transform_image = get_transform(self.opt, params)
+ image_tensor = transform_image(image)
+
+ ### if using instance maps
+ if not self.opt.no_instance:
+ inst_path = self.inst_paths[index]
+ inst = Image.open(inst_path)
+ inst_tensor = transform_label(inst)
+
+ if self.opt.load_features:
+ feat_path = self.feat_paths[index]
+ feat = Image.open(feat_path).convert('RGB')
+ norm = normalize()
+ feat_tensor = norm(transform_label(feat))
+
+ input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
+ 'feat': feat_tensor, 'path': label_path}
+
+ return input_dict
+
+ def __len__(self):
+ return len(self.label_paths)
+
+ def name(self):
+ return 'AlignedDataset' \ No newline at end of file
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
new file mode 100755
index 0000000..0e1deb5
--- /dev/null
+++ b/data/base_data_loader.py
@@ -0,0 +1,14 @@
+
+class BaseDataLoader():
+ def __init__(self):
+ pass
+
+ def initialize(self, opt):
+ self.opt = opt
+ pass
+
+ def load_data():
+ return None
+
+
+
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100755
index 0000000..038d3d2
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,92 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+import numpy as np
+import random
+
+class BaseDataset(data.Dataset):
+ def __init__(self):
+ super(BaseDataset, self).__init__()
+
+ def name(self):
+ return 'BaseDataset'
+
+ def initialize(self, opt):
+ pass
+
+def get_params(opt, size):
+ w, h = size
+ new_h = h
+ new_w = w
+ if opt.resize_or_crop == 'resize_and_crop':
+ new_h = new_w = opt.loadSize
+ elif opt.resize_or_crop == 'scale_width_and_crop':
+ new_w = opt.loadSize
+ new_h = opt.loadSize * h // w
+
+ x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
+ y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
+
+ flip = random.random() > 0.5
+ return {'crop_pos': (x, y), 'flip': flip}
+
+def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
+ transform_list = []
+ if 'resize' in opt.resize_or_crop:
+ osize = [opt.loadSize, opt.loadSize]
+ transform_list.append(transforms.Scale(osize, method))
+ elif 'scale_width' in opt.resize_or_crop:
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
+
+ if 'crop' in opt.resize_or_crop:
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
+
+ if opt.resize_or_crop == 'none':
+ base = float(2 ** opt.n_downsample_global)
+ if opt.netG == 'local':
+ base *= (2 ** opt.n_local_enhancers)
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+ transform_list += [transforms.ToTensor()]
+
+ if normalize:
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ return transforms.Compose(transform_list)
+
+def normalize():
+ return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+
+def __make_power_2(img, base, method=Image.BICUBIC):
+ ow, oh = img.size
+ h = int(round(oh / base) * base)
+ w = int(round(ow / base) * base)
+ if (h == oh) and (w == ow):
+ return img
+ return img.resize((w, h), method)
+
+def __scale_width(img, target_width, method=Image.BICUBIC):
+ ow, oh = img.size
+ if (ow == target_width):
+ return img
+ w = target_width
+ h = int(target_width * oh / ow)
+ return img.resize((w, h), method)
+
+def __crop(img, pos, size):
+ ow, oh = img.size
+ x1, y1 = pos
+ tw = th = size
+ if (ow > tw or oh > th):
+ return img.crop((x1, y1, x1 + tw, y1 + th))
+ return img
+
+def __flip(img, flip):
+ if flip:
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
+ return img
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
new file mode 100755
index 0000000..0b98254
--- /dev/null
+++ b/data/custom_dataset_data_loader.py
@@ -0,0 +1,31 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataset(opt):
+ dataset = None
+ from data.aligned_dataset import AlignedDataset
+ dataset = AlignedDataset()
+
+ print("dataset [%s] was created" % (dataset.name()))
+ dataset.initialize(opt)
+ return dataset
+
+class CustomDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads))
+
+ def load_data(self):
+ return self.dataloader
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
diff --git a/data/data_loader.py b/data/data_loader.py
new file mode 100755
index 0000000..2a4433a
--- /dev/null
+++ b/data/data_loader.py
@@ -0,0 +1,7 @@
+
+def CreateDataLoader(opt):
+ from data.custom_dataset_data_loader import CustomDatasetDataLoader
+ data_loader = CustomDatasetDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
diff --git a/data/image_folder.py b/data/image_folder.py
new file mode 100755
index 0000000..16a447c
--- /dev/null
+++ b/data/image_folder.py
@@ -0,0 +1,68 @@
+###############################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+###############################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+ images = []
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+
+ return images
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " +
+ ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)