diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/aligned_data_loader.py | 87 | ||||
| -rw-r--r-- | data/aligned_dataset.py | 56 | ||||
| -rw-r--r-- | data/base_data_loader.py | 14 | ||||
| -rw-r--r-- | data/base_dataset.py | 12 | ||||
| -rw-r--r-- | data/custom_dataset_data_loader.py | 41 | ||||
| -rw-r--r-- | data/data_loader.py | 9 | ||||
| -rw-r--r-- | data/image_folder.py | 7 | ||||
| -rw-r--r-- | data/single_dataset.py | 47 | ||||
| -rw-r--r-- | data/unaligned_data_loader.py | 100 | ||||
| -rw-r--r-- | data/unaligned_dataset.py | 56 |
10 files changed, 218 insertions, 211 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py deleted file mode 100644 index d1d4572..0000000 --- a/data/aligned_data_loader.py +++ /dev/null @@ -1,87 +0,0 @@ -import random -import numpy as np -import torch.utils.data -import torchvision.transforms as transforms -from data.base_data_loader import BaseDataLoader -from data.image_folder import ImageFolder -from pdb import set_trace as st -# pip install future --upgrade -from builtins import object - -class PairedData(object): - def __init__(self, data_loader, fineSize, max_dataset_size, flip): - self.data_loader = data_loader - self.fineSize = fineSize - self.max_dataset_size = max_dataset_size - self.flip = flip - # st() - - def __iter__(self): - self.data_loader_iter = iter(self.data_loader) - self.iter = 0 - return self - - def __next__(self): - self.iter += 1 - if self.iter > self.max_dataset_size: - raise StopIteration - - AB, AB_paths = next(self.data_loader_iter) - w_total = AB.size(3) - w = int(w_total / 2) - h = AB.size(2) - - w_offset = random.randint(0, max(0, w - self.fineSize - 1)) - h_offset = random.randint(0, max(0, h - self.fineSize - 1)) - A = AB[:, :, h_offset:h_offset + self.fineSize, - w_offset:w_offset + self.fineSize] - B = AB[:, :, h_offset:h_offset + self.fineSize, - w + w_offset:w + w_offset + self.fineSize] - - if self.flip and random.random() < 0.5: - idx = [i for i in range(A.size(3) - 1, -1, -1)] - idx = torch.LongTensor(idx) - A = A.index_select(3, idx) - B = B.index_select(3, idx) - - - - return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths} - - -class AlignedDataLoader(BaseDataLoader): - def initialize(self, opt): - BaseDataLoader.initialize(self, opt) - self.fineSize = opt.fineSize - - transformations = [ - # TODO: Scale - transforms.Scale(opt.loadSize), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - transform = transforms.Compose(transformations) - - # Dataset A - dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase, - transform=transform, return_paths=True) - data_loader = torch.utils.data.DataLoader( - dataset, - batch_size=self.opt.batchSize, - shuffle=not self.opt.serial_batches, - num_workers=int(self.opt.nThreads)) - - self.dataset = dataset - - flip = opt.isTrain and not opt.no_flip - self.paired_data = PairedData(data_loader, opt.fineSize, - opt.max_dataset_size, flip) - - def name(self): - return 'AlignedDataLoader' - - def load_data(self): - return self.paired_data - - def __len__(self): - return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py new file mode 100644 index 0000000..0f45c40 --- /dev/null +++ b/data/aligned_dataset.py @@ -0,0 +1,56 @@ +import os.path +import random +import torchvision.transforms as transforms +import torch +from data.base_dataset import BaseDataset +from data.image_folder import make_dataset +from PIL import Image + + +class AlignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_AB = os.path.join(opt.dataroot, opt.phase) + + self.AB_paths = sorted(make_dataset(self.dir_AB)) + + assert(opt.resize_or_crop == 'resize_and_crop') + + transform_list = [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + + self.transform = transforms.Compose(transform_list) + + def __getitem__(self, index): + AB_path = self.AB_paths[index] + AB = Image.open(AB_path).convert('RGB') + AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC) + AB = self.transform(AB) + + w_total = AB.size(2) + w = int(w_total / 2) + h = AB.size(1) + w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) + h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) + + A = AB[:, h_offset:h_offset + self.opt.fineSize, + w_offset:w_offset + self.opt.fineSize] + B = AB[:, h_offset:h_offset + self.opt.fineSize, + w + w_offset:w + w_offset + self.opt.fineSize] + + if (not self.opt.no_flip) and random.random() < 0.5: + idx = [i for i in range(A.size(2) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A = A.index_select(2, idx) + B = B.index_select(2, idx) + + return {'A': A, 'B': B, + 'A_paths': AB_path, 'B_paths': AB_path} + + def __len__(self): + return len(self.AB_paths) + + def name(self): + return 'AlignedDataset' diff --git a/data/base_data_loader.py b/data/base_data_loader.py deleted file mode 100644 index 0e1deb5..0000000 --- a/data/base_data_loader.py +++ /dev/null @@ -1,14 +0,0 @@ - -class BaseDataLoader(): - def __init__(self): - pass - - def initialize(self, opt): - self.opt = opt - pass - - def load_data(): - return None - - - diff --git a/data/base_dataset.py b/data/base_dataset.py new file mode 100644 index 0000000..49b9d98 --- /dev/null +++ b/data/base_dataset.py @@ -0,0 +1,12 @@ +import torch.utils.data as data + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + def initialize(self, opt): + pass + diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py new file mode 100644 index 0000000..60180e0 --- /dev/null +++ b/data/custom_dataset_data_loader.py @@ -0,0 +1,41 @@ +import torch.utils.data +from data.base_data_loader import BaseDataLoader + + +def CreateDataset(opt): + dataset = None + if opt.dataset_mode == 'aligned': + from data.aligned_dataset import AlignedDataset + dataset = AlignedDataset() + elif opt.dataset_mode == 'unaligned': + from data.unaligned_dataset import UnalignedDataset + dataset = UnalignedDataset() + elif opt.dataset_mode == 'single': + from data.single_dataset import SingleDataset + dataset = SingleDataset() + else: + raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) + + print("dataset [%s] was created" % (dataset.name())) + dataset.initialize(opt) + return dataset + + +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = CreateDataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self.dataloader + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/data/data_loader.py b/data/data_loader.py index 69035ea..2a4433a 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -1,12 +1,7 @@ def CreateDataLoader(opt): - data_loader = None - if opt.align_data > 0: - from data.aligned_data_loader import AlignedDataLoader - data_loader = AlignedDataLoader() - else: - from data.unaligned_data_loader import UnalignedDataLoader - data_loader = UnalignedDataLoader() + from data.custom_dataset_data_loader import CustomDatasetDataLoader + data_loader = CustomDatasetDataLoader() print(data_loader.name()) data_loader.initialize(opt) return data_loader diff --git a/data/image_folder.py b/data/image_folder.py index 44e15cb..898200b 100644 --- a/data/image_folder.py +++ b/data/image_folder.py @@ -1,9 +1,9 @@ -################################################################################ +############################################################################### # Code from # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py # Modified the original code so that it also loads images from the current # directory as well as the subdirectories -################################################################################ +############################################################################### import torch.utils.data as data @@ -45,7 +45,8 @@ class ImageFolder(data.Dataset): imgs = make_dataset(root) if len(imgs) == 0: raise(RuntimeError("Found 0 images in: " + root + "\n" - "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs diff --git a/data/single_dataset.py b/data/single_dataset.py new file mode 100644 index 0000000..106bea3 --- /dev/null +++ b/data/single_dataset.py @@ -0,0 +1,47 @@ +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset +from data.image_folder import make_dataset +from PIL import Image + + +class SingleDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot) + + self.A_paths = make_dataset(self.dir_A) + + self.A_paths = sorted(self.A_paths) + + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + transform_list.append(transforms.Scale(opt.loadSize)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + if opt.resize_or_crop != 'no_resize': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + + self.transform = transforms.Compose(transform_list) + + def __getitem__(self, index): + A_path = self.A_paths[index] + + A_img = Image.open(A_path).convert('RGB') + + A_img = self.transform(A_img) + + return {'A': A_img, 'A_paths': A_path} + + def __len__(self): + return len(self.A_paths) + + def name(self): + return 'SingleImageDataset' diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py deleted file mode 100644 index bd0ea75..0000000 --- a/data/unaligned_data_loader.py +++ /dev/null @@ -1,100 +0,0 @@ -import random -import torch.utils.data -import torchvision.transforms as transforms -from data.base_data_loader import BaseDataLoader -from data.image_folder import ImageFolder -# pip install future --upgrade -from builtins import object -from pdb import set_trace as st - -class PairedData(object): - def __init__(self, data_loader_A, data_loader_B, max_dataset_size, flip): - self.data_loader_A = data_loader_A - self.data_loader_B = data_loader_B - self.stop_A = False - self.stop_B = False - self.max_dataset_size = max_dataset_size - self.flip = flip - - def __iter__(self): - self.stop_A = False - self.stop_B = False - self.data_loader_A_iter = iter(self.data_loader_A) - self.data_loader_B_iter = iter(self.data_loader_B) - self.iter = 0 - return self - - def __next__(self): - A, A_paths = None, None - B, B_paths = None, None - try: - A, A_paths = next(self.data_loader_A_iter) - except StopIteration: - if A is None or A_paths is None: - self.stop_A = True - self.data_loader_A_iter = iter(self.data_loader_A) - A, A_paths = next(self.data_loader_A_iter) - - try: - B, B_paths = next(self.data_loader_B_iter) - except StopIteration: - if B is None or B_paths is None: - self.stop_B = True - self.data_loader_B_iter = iter(self.data_loader_B) - B, B_paths = next(self.data_loader_B_iter) - - if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size: - self.stop_A = False - self.stop_B = False - raise StopIteration() - else: - self.iter += 1 - if self.flip and random.random() < 0.5: - idx = [i for i in range(A.size(3) - 1, -1, -1)] - idx = torch.LongTensor(idx) - A = A.index_select(3, idx) - B = B.index_select(3, idx) - return {'A': A, 'A_paths': A_paths, - 'B': B, 'B_paths': B_paths} - -class UnalignedDataLoader(BaseDataLoader): - def initialize(self, opt): - BaseDataLoader.initialize(self, opt) - transformations = [transforms.Scale(opt.loadSize), - transforms.RandomCrop(opt.fineSize), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - transform = transforms.Compose(transformations) - - # Dataset A - dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A', - transform=transform, return_paths=True) - data_loader_A = torch.utils.data.DataLoader( - dataset_A, - batch_size=self.opt.batchSize, - shuffle=not self.opt.serial_batches, - num_workers=int(self.opt.nThreads)) - - # Dataset B - dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B', - transform=transform, return_paths=True) - data_loader_B = torch.utils.data.DataLoader( - dataset_B, - batch_size=self.opt.batchSize, - shuffle=not self.opt.serial_batches, - num_workers=int(self.opt.nThreads)) - self.dataset_A = dataset_A - self.dataset_B = dataset_B - flip = opt.isTrain and not opt.no_flip - self.paired_data = PairedData(data_loader_A, data_loader_B, - self.opt.max_dataset_size, flip) - - def name(self): - return 'UnalignedDataLoader' - - def load_data(self): - return self.paired_data - - def __len__(self): - return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size) diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py new file mode 100644 index 0000000..1f75b23 --- /dev/null +++ b/data/unaligned_dataset.py @@ -0,0 +1,56 @@ +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset +from data.image_folder import make_dataset +from PIL import Image +import PIL +from pdb import set_trace as st + + +class UnalignedDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') + + self.A_paths = make_dataset(self.dir_A) + self.B_paths = make_dataset(self.dir_B) + + self.A_paths = sorted(self.A_paths) + self.B_paths = sorted(self.B_paths) + + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + if opt.resize_or_crop != 'no_resize': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + self.transform = transforms.Compose(transform_list) + + def __getitem__(self, index): + A_path = self.A_paths[index] + B_path = self.B_paths[index] + + A_img = Image.open(A_path).convert('RGB') + B_img = Image.open(B_path).convert('RGB') + + A_img = self.transform(A_img) + B_img = self.transform(B_img) + + return {'A': A_img, 'B': B_img, + 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + return min(len(self.A_paths), len(self.B_paths)) + + def name(self): + return 'UnalignedDataset' |
