summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
committerjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
commitc99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch)
treeba99dfd56a47036d9c1f18620abf4efc248839ab /data
first commit
Diffstat (limited to 'data')
-rw-r--r--data/__init__.py0
-rw-r--r--data/aligned_data_loader.py69
-rw-r--r--data/base_data_loader.py14
-rw-r--r--data/data_loader.py12
-rw-r--r--data/image_folder.py67
-rw-r--r--data/unaligned_data_loader.py63
6 files changed, 225 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/data/__init__.py
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
new file mode 100644
index 0000000..01dbf89
--- /dev/null
+++ b/data/aligned_data_loader.py
@@ -0,0 +1,69 @@
+import random
+import torch.utils.data
+import torchvision.transforms as transforms
+from data.base_data_loader import BaseDataLoader
+from data.image_folder import ImageFolder
+from pdb import set_trace as st
+from builtins import object
+
+class PairedData(object):
+ def __init__(self, data_loader, fineSize):
+ self.data_loader = data_loader
+ self.fineSize = fineSize
+ # st()
+
+ def __iter__(self):
+ self.data_loader_iter = iter(self.data_loader)
+ return self
+
+ def __next__(self):
+ # st()
+ AB, AB_paths = next(self.data_loader_iter)
+ # st()
+ w_total = AB.size(3)
+ w = int(w_total / 2)
+ h = AB.size(2)
+
+ w_offset = random.randint(0, max(0, w - self.fineSize - 1))
+ h_offset = random.randint(0, max(0, h - self.fineSize - 1))
+
+ A = AB[:, :, h_offset:h_offset + self.fineSize,
+ w_offset:w_offset + self.fineSize]
+ B = AB[:, :, h_offset:h_offset + self.fineSize,
+ w + w_offset:w + w_offset + self.fineSize]
+
+ return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths}
+
+
+class AlignedDataLoader(BaseDataLoader):
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ self.fineSize = opt.fineSize
+ transform = transforms.Compose([
+ # TODO: Scale
+ #transforms.Scale((opt.loadSize * 2, opt.loadSize)),
+ #transforms.CenterCrop(opt.fineSize),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))])
+
+ # Dataset A
+ dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
+ transform=transform, return_paths=True)
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=self.opt.batchSize,
+ shuffle=not self.opt.serial_batches,
+ num_workers=int(self.opt.nThreads))
+
+ self.dataset = dataset
+ self.paired_data = PairedData(data_loader, opt.fineSize)
+
+ def name(self):
+ return 'AlignedDataLoader'
+
+ def load_data(self):
+ return self.paired_data
+
+ def __len__(self):
+ return len(self.dataset)
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
new file mode 100644
index 0000000..0e1deb5
--- /dev/null
+++ b/data/base_data_loader.py
@@ -0,0 +1,14 @@
+
+class BaseDataLoader():
+ def __init__(self):
+ pass
+
+ def initialize(self, opt):
+ self.opt = opt
+ pass
+
+ def load_data():
+ return None
+
+
+
diff --git a/data/data_loader.py b/data/data_loader.py
new file mode 100644
index 0000000..69035ea
--- /dev/null
+++ b/data/data_loader.py
@@ -0,0 +1,12 @@
+
+def CreateDataLoader(opt):
+ data_loader = None
+ if opt.align_data > 0:
+ from data.aligned_data_loader import AlignedDataLoader
+ data_loader = AlignedDataLoader()
+ else:
+ from data.unaligned_data_loader import UnalignedDataLoader
+ data_loader = UnalignedDataLoader()
+ print(data_loader.name())
+ data_loader.initialize(opt)
+ return data_loader
diff --git a/data/image_folder.py b/data/image_folder.py
new file mode 100644
index 0000000..44e15cb
--- /dev/null
+++ b/data/image_folder.py
@@ -0,0 +1,67 @@
+################################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+################################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+ images = []
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+
+ return images
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
new file mode 100644
index 0000000..95d9ac7
--- /dev/null
+++ b/data/unaligned_data_loader.py
@@ -0,0 +1,63 @@
+import torch.utils.data
+import torchvision.transforms as transforms
+from data.base_data_loader import BaseDataLoader
+from data.image_folder import ImageFolder
+from builtins import object
+
+
+class PairedData(object):
+ def __init__(self, data_loader_A, data_loader_B):
+ self.data_loader_A = data_loader_A
+ self.data_loader_B = data_loader_B
+
+ def __iter__(self):
+ self.data_loader_A_iter = iter(self.data_loader_A)
+ self.data_loader_B_iter = iter(self.data_loader_B)
+ return self
+
+ def __next__(self):
+ A, A_paths = next(self.data_loader_A_iter)
+ B, B_paths = next(self.data_loader_B_iter)
+ return {'A': A, 'A_paths': A_paths,
+ 'B': B, 'B_paths': B_paths}
+
+
+class UnalignedDataLoader(BaseDataLoader):
+ def initialize(self, opt):
+ BaseDataLoader.initialize(self, opt)
+ transform = transforms.Compose([
+ transforms.Scale(opt.loadSize),
+ transforms.CenterCrop(opt.fineSize),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))])
+
+ # Dataset A
+ dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A',
+ transform=transform, return_paths=True)
+ data_loader_A = torch.utils.data.DataLoader(
+ dataset_A,
+ batch_size=self.opt.batchSize,
+ shuffle=not self.opt.serial_batches,
+ num_workers=int(self.opt.nThreads))
+
+ # Dataset B
+ dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B',
+ transform=transform, return_paths=True)
+ data_loader_B = torch.utils.data.DataLoader(
+ dataset_B,
+ batch_size=self.opt.batchSize,
+ shuffle=not self.opt.serial_batches,
+ num_workers=int(self.opt.nThreads))
+ self.dataset_A = dataset_A
+ self.dataset_B = dataset_B
+ self.paired_data = PairedData(data_loader_A, data_loader_B)
+
+ def name(self):
+ return 'UnalignedDataLoader'
+
+ def load_data(self):
+ return self.paired_data
+
+ def __len__(self):
+ return len(self.dataset_A)