summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/aligned_data_loader.py87
-rw-r--r--data/aligned_dataset.py56
-rw-r--r--data/base_data_loader.py14
-rw-r--r--data/base_dataset.py12
-rw-r--r--data/custom_dataset_data_loader.py41
-rw-r--r--data/data_loader.py9
-rw-r--r--data/image_folder.py7
-rw-r--r--data/single_dataset.py47
-rw-r--r--data/unaligned_data_loader.py100
-rw-r--r--data/unaligned_dataset.py56
10 files changed, 218 insertions, 211 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
deleted file mode 100644
index d1d4572..0000000
--- a/data/aligned_data_loader.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import random
-import numpy as np
-import torch.utils.data
-import torchvision.transforms as transforms
-from data.base_data_loader import BaseDataLoader
-from data.image_folder import ImageFolder
-from pdb import set_trace as st
-# pip install future --upgrade
-from builtins import object
-
-class PairedData(object):
- def __init__(self, data_loader, fineSize, max_dataset_size, flip):
- self.data_loader = data_loader
- self.fineSize = fineSize
- self.max_dataset_size = max_dataset_size
- self.flip = flip
- # st()
-
- def __iter__(self):
- self.data_loader_iter = iter(self.data_loader)
- self.iter = 0
- return self
-
- def __next__(self):
- self.iter += 1
- if self.iter > self.max_dataset_size:
- raise StopIteration
-
- AB, AB_paths = next(self.data_loader_iter)
- w_total = AB.size(3)
- w = int(w_total / 2)
- h = AB.size(2)
-
- w_offset = random.randint(0, max(0, w - self.fineSize - 1))
- h_offset = random.randint(0, max(0, h - self.fineSize - 1))
- A = AB[:, :, h_offset:h_offset + self.fineSize,
- w_offset:w_offset + self.fineSize]
- B = AB[:, :, h_offset:h_offset + self.fineSize,
- w + w_offset:w + w_offset + self.fineSize]
-
- if self.flip and random.random() < 0.5:
- idx = [i for i in range(A.size(3) - 1, -1, -1)]
- idx = torch.LongTensor(idx)
- A = A.index_select(3, idx)
- B = B.index_select(3, idx)
-
-
-
- return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths}
-
-
-class AlignedDataLoader(BaseDataLoader):
- def initialize(self, opt):
- BaseDataLoader.initialize(self, opt)
- self.fineSize = opt.fineSize
-
- transformations = [
- # TODO: Scale
- transforms.Scale(opt.loadSize),
- transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5),
- (0.5, 0.5, 0.5))]
- transform = transforms.Compose(transformations)
-
- # Dataset A
- dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
- transform=transform, return_paths=True)
- data_loader = torch.utils.data.DataLoader(
- dataset,
- batch_size=self.opt.batchSize,
- shuffle=not self.opt.serial_batches,
- num_workers=int(self.opt.nThreads))
-
- self.dataset = dataset
-
- flip = opt.isTrain and not opt.no_flip
- self.paired_data = PairedData(data_loader, opt.fineSize,
- opt.max_dataset_size, flip)
-
- def name(self):
- return 'AlignedDataLoader'
-
- def load_data(self):
- return self.paired_data
-
- def __len__(self):
- return min(len(self.dataset), self.opt.max_dataset_size)
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
new file mode 100644
index 0000000..0f45c40
--- /dev/null
+++ b/data/aligned_dataset.py
@@ -0,0 +1,56 @@
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class AlignedDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_AB = os.path.join(opt.dataroot, opt.phase)
+
+ self.AB_paths = sorted(make_dataset(self.dir_AB))
+
+ assert(opt.resize_or_crop == 'resize_and_crop')
+
+ transform_list = [transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+
+ self.transform = transforms.Compose(transform_list)
+
+ def __getitem__(self, index):
+ AB_path = self.AB_paths[index]
+ AB = Image.open(AB_path).convert('RGB')
+ AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
+ AB = self.transform(AB)
+
+ w_total = AB.size(2)
+ w = int(w_total / 2)
+ h = AB.size(1)
+ w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+ h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+ A = AB[:, h_offset:h_offset + self.opt.fineSize,
+ w_offset:w_offset + self.opt.fineSize]
+ B = AB[:, h_offset:h_offset + self.opt.fineSize,
+ w + w_offset:w + w_offset + self.opt.fineSize]
+
+ if (not self.opt.no_flip) and random.random() < 0.5:
+ idx = [i for i in range(A.size(2) - 1, -1, -1)]
+ idx = torch.LongTensor(idx)
+ A = A.index_select(2, idx)
+ B = B.index_select(2, idx)
+
+ return {'A': A, 'B': B,
+ 'A_paths': AB_path, 'B_paths': AB_path}
+
+ def __len__(self):
+ return len(self.AB_paths)
+
+ def name(self):
+ return 'AlignedDataset'
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
deleted file mode 100644
index 0e1deb5..0000000
--- a/data/base_data_loader.py
+++ /dev/null
@@ -1,14 +0,0 @@
-
-class BaseDataLoader():
- def __init__(self):
- pass
-
- def initialize(self, opt):
- self.opt = opt
- pass
-
- def load_data():
- return None
-
-
-
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100644
index 0000000..49b9d98
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,12 @@
+import torch.utils.data as data
+
+class BaseDataset(data.Dataset):
+ def __init__(self):
+ super(BaseDataset, self).__init__()
+
+ def name(self):
+ return 'BaseDataset'
+
+ def initialize(self, opt):
+ pass
+
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
new file mode 100644
index 0000000..60180e0
--- /dev/null
+++ b/data/custom_dataset_data_loader.py
@@ -0,0 +1,41 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataset(opt):
+ dataset = None
+ if opt.dataset_mode == 'aligned':
+ from data.aligned_dataset import AlignedDataset
+ dataset = AlignedDataset()
+ elif opt.dataset_mode == 'unaligned':
+ from data.unaligned_dataset import UnalignedDataset
+ dataset = UnalignedDataset()
+ elif opt.dataset_mode == 'single':
+ from data.single_dataset import SingleDataset
+ dataset = SingleDataset()
+ else:
+ raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
+
+ print("dataset [%s] was created" % (dataset.name()))
+ dataset.initialize(opt)
+ return dataset
+
+
+class CustomDatasetDataLoader(BaseDataLoader):
+ def name(self):
+ return 'CustomDatasetDataLoader'
+
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.dataset = CreateDataset(opt)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads))
+
+ def load_data(self):
+ return self.dataloader
+
+ def __len__(self):
+ return min(len(self.dataset), self.opt.max_dataset_size)
diff --git a/data/data_loader.py b/data/data_loader.py
index 69035ea..2a4433a 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,12 +1,7 @@
def CreateDataLoader(opt):
- data_loader = None
- if opt.align_data > 0:
- from data.aligned_data_loader import AlignedDataLoader
- data_loader = AlignedDataLoader()
- else:
- from data.unaligned_data_loader import UnalignedDataLoader
- data_loader = UnalignedDataLoader()
+ from data.custom_dataset_data_loader import CustomDatasetDataLoader
+ data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
diff --git a/data/image_folder.py b/data/image_folder.py
index 44e15cb..898200b 100644
--- a/data/image_folder.py
+++ b/data/image_folder.py
@@ -1,9 +1,9 @@
-################################################################################
+###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
-################################################################################
+###############################################################################
import torch.utils.data as data
@@ -45,7 +45,8 @@ class ImageFolder(data.Dataset):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
- "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+ "Supported image extensions are: " +
+ ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
diff --git a/data/single_dataset.py b/data/single_dataset.py
new file mode 100644
index 0000000..106bea3
--- /dev/null
+++ b/data/single_dataset.py
@@ -0,0 +1,47 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class SingleDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_A = os.path.join(opt.dataroot)
+
+ self.A_paths = make_dataset(self.dir_A)
+
+ self.A_paths = sorted(self.A_paths)
+
+ transform_list = []
+ if opt.resize_or_crop == 'resize_and_crop':
+ transform_list.append(transforms.Scale(opt.loadSize))
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.RandomHorizontalFlip())
+
+ if opt.resize_or_crop != 'no_resize':
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+ transform_list += [transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+
+ self.transform = transforms.Compose(transform_list)
+
+ def __getitem__(self, index):
+ A_path = self.A_paths[index]
+
+ A_img = Image.open(A_path).convert('RGB')
+
+ A_img = self.transform(A_img)
+
+ return {'A': A_img, 'A_paths': A_path}
+
+ def __len__(self):
+ return len(self.A_paths)
+
+ def name(self):
+ return 'SingleImageDataset'
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
deleted file mode 100644
index bd0ea75..0000000
--- a/data/unaligned_data_loader.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import random
-import torch.utils.data
-import torchvision.transforms as transforms
-from data.base_data_loader import BaseDataLoader
-from data.image_folder import ImageFolder
-# pip install future --upgrade
-from builtins import object
-from pdb import set_trace as st
-
-class PairedData(object):
- def __init__(self, data_loader_A, data_loader_B, max_dataset_size, flip):
- self.data_loader_A = data_loader_A
- self.data_loader_B = data_loader_B
- self.stop_A = False
- self.stop_B = False
- self.max_dataset_size = max_dataset_size
- self.flip = flip
-
- def __iter__(self):
- self.stop_A = False
- self.stop_B = False
- self.data_loader_A_iter = iter(self.data_loader_A)
- self.data_loader_B_iter = iter(self.data_loader_B)
- self.iter = 0
- return self
-
- def __next__(self):
- A, A_paths = None, None
- B, B_paths = None, None
- try:
- A, A_paths = next(self.data_loader_A_iter)
- except StopIteration:
- if A is None or A_paths is None:
- self.stop_A = True
- self.data_loader_A_iter = iter(self.data_loader_A)
- A, A_paths = next(self.data_loader_A_iter)
-
- try:
- B, B_paths = next(self.data_loader_B_iter)
- except StopIteration:
- if B is None or B_paths is None:
- self.stop_B = True
- self.data_loader_B_iter = iter(self.data_loader_B)
- B, B_paths = next(self.data_loader_B_iter)
-
- if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
- self.stop_A = False
- self.stop_B = False
- raise StopIteration()
- else:
- self.iter += 1
- if self.flip and random.random() < 0.5:
- idx = [i for i in range(A.size(3) - 1, -1, -1)]
- idx = torch.LongTensor(idx)
- A = A.index_select(3, idx)
- B = B.index_select(3, idx)
- return {'A': A, 'A_paths': A_paths,
- 'B': B, 'B_paths': B_paths}
-
-class UnalignedDataLoader(BaseDataLoader):
- def initialize(self, opt):
- BaseDataLoader.initialize(self, opt)
- transformations = [transforms.Scale(opt.loadSize),
- transforms.RandomCrop(opt.fineSize),
- transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5),
- (0.5, 0.5, 0.5))]
- transform = transforms.Compose(transformations)
-
- # Dataset A
- dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A',
- transform=transform, return_paths=True)
- data_loader_A = torch.utils.data.DataLoader(
- dataset_A,
- batch_size=self.opt.batchSize,
- shuffle=not self.opt.serial_batches,
- num_workers=int(self.opt.nThreads))
-
- # Dataset B
- dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B',
- transform=transform, return_paths=True)
- data_loader_B = torch.utils.data.DataLoader(
- dataset_B,
- batch_size=self.opt.batchSize,
- shuffle=not self.opt.serial_batches,
- num_workers=int(self.opt.nThreads))
- self.dataset_A = dataset_A
- self.dataset_B = dataset_B
- flip = opt.isTrain and not opt.no_flip
- self.paired_data = PairedData(data_loader_A, data_loader_B,
- self.opt.max_dataset_size, flip)
-
- def name(self):
- return 'UnalignedDataLoader'
-
- def load_data(self):
- return self.paired_data
-
- def __len__(self):
- return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size)
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
new file mode 100644
index 0000000..1f75b23
--- /dev/null
+++ b/data/unaligned_dataset.py
@@ -0,0 +1,56 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+import PIL
+from pdb import set_trace as st
+
+
+class UnalignedDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+ self.A_paths = make_dataset(self.dir_A)
+ self.B_paths = make_dataset(self.dir_B)
+
+ self.A_paths = sorted(self.A_paths)
+ self.B_paths = sorted(self.B_paths)
+
+ transform_list = []
+ if opt.resize_or_crop == 'resize_and_crop':
+ osize = [opt.loadSize, opt.loadSize]
+ transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+
+ if opt.isTrain and not opt.no_flip:
+ transform_list.append(transforms.RandomHorizontalFlip())
+
+ if opt.resize_or_crop != 'no_resize':
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+ transform_list += [transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ self.transform = transforms.Compose(transform_list)
+
+ def __getitem__(self, index):
+ A_path = self.A_paths[index]
+ B_path = self.B_paths[index]
+
+ A_img = Image.open(A_path).convert('RGB')
+ B_img = Image.open(B_path).convert('RGB')
+
+ A_img = self.transform(A_img)
+ B_img = self.transform(B_img)
+
+ return {'A': A_img, 'B': B_img,
+ 'A_paths': A_path, 'B_paths': B_path}
+
+ def __len__(self):
+ return min(len(self.A_paths), len(self.B_paths))
+
+ def name(self):
+ return 'UnalignedDataset'