diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:38:47 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:38:47 -0700 |
| commit | c99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch) | |
| tree | ba99dfd56a47036d9c1f18620abf4efc248839ab /data | |
first commit
Diffstat (limited to 'data')
| -rw-r--r-- | data/__init__.py | 0 | ||||
| -rw-r--r-- | data/aligned_data_loader.py | 69 | ||||
| -rw-r--r-- | data/base_data_loader.py | 14 | ||||
| -rw-r--r-- | data/data_loader.py | 12 | ||||
| -rw-r--r-- | data/image_folder.py | 67 | ||||
| -rw-r--r-- | data/unaligned_data_loader.py | 63 |
6 files changed, 225 insertions, 0 deletions
diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/data/__init__.py diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py new file mode 100644 index 0000000..01dbf89 --- /dev/null +++ b/data/aligned_data_loader.py @@ -0,0 +1,69 @@ +import random +import torch.utils.data +import torchvision.transforms as transforms +from data.base_data_loader import BaseDataLoader +from data.image_folder import ImageFolder +from pdb import set_trace as st +from builtins import object + +class PairedData(object): + def __init__(self, data_loader, fineSize): + self.data_loader = data_loader + self.fineSize = fineSize + # st() + + def __iter__(self): + self.data_loader_iter = iter(self.data_loader) + return self + + def __next__(self): + # st() + AB, AB_paths = next(self.data_loader_iter) + # st() + w_total = AB.size(3) + w = int(w_total / 2) + h = AB.size(2) + + w_offset = random.randint(0, max(0, w - self.fineSize - 1)) + h_offset = random.randint(0, max(0, h - self.fineSize - 1)) + + A = AB[:, :, h_offset:h_offset + self.fineSize, + w_offset:w_offset + self.fineSize] + B = AB[:, :, h_offset:h_offset + self.fineSize, + w + w_offset:w + w_offset + self.fineSize] + + return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths} + + +class AlignedDataLoader(BaseDataLoader): + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.fineSize = opt.fineSize + transform = transforms.Compose([ + # TODO: Scale + #transforms.Scale((opt.loadSize * 2, opt.loadSize)), + #transforms.CenterCrop(opt.fineSize), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))]) + + # Dataset A + dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase, + transform=transform, return_paths=True) + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.opt.batchSize, + shuffle=not self.opt.serial_batches, + num_workers=int(self.opt.nThreads)) + + self.dataset = dataset + self.paired_data = PairedData(data_loader, opt.fineSize) + + def name(self): + return 'AlignedDataLoader' + + def load_data(self): + return self.paired_data + + def __len__(self): + return len(self.dataset) diff --git a/data/base_data_loader.py b/data/base_data_loader.py new file mode 100644 index 0000000..0e1deb5 --- /dev/null +++ b/data/base_data_loader.py @@ -0,0 +1,14 @@ + +class BaseDataLoader(): + def __init__(self): + pass + + def initialize(self, opt): + self.opt = opt + pass + + def load_data(): + return None + + + diff --git a/data/data_loader.py b/data/data_loader.py new file mode 100644 index 0000000..69035ea --- /dev/null +++ b/data/data_loader.py @@ -0,0 +1,12 @@ + +def CreateDataLoader(opt): + data_loader = None + if opt.align_data > 0: + from data.aligned_data_loader import AlignedDataLoader + data_loader = AlignedDataLoader() + else: + from data.unaligned_data_loader import UnalignedDataLoader + data_loader = UnalignedDataLoader() + print(data_loader.name()) + data_loader.initialize(opt) + return data_loader diff --git a/data/image_folder.py b/data/image_folder.py new file mode 100644 index 0000000..44e15cb --- /dev/null +++ b/data/image_folder.py @@ -0,0 +1,67 @@ +################################################################################ +# Code from +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py +# Modified the original code so that it also loads images from the current +# directory as well as the subdirectories +################################################################################ + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py new file mode 100644 index 0000000..95d9ac7 --- /dev/null +++ b/data/unaligned_data_loader.py @@ -0,0 +1,63 @@ +import torch.utils.data +import torchvision.transforms as transforms +from data.base_data_loader import BaseDataLoader +from data.image_folder import ImageFolder +from builtins import object + + +class PairedData(object): + def __init__(self, data_loader_A, data_loader_B): + self.data_loader_A = data_loader_A + self.data_loader_B = data_loader_B + + def __iter__(self): + self.data_loader_A_iter = iter(self.data_loader_A) + self.data_loader_B_iter = iter(self.data_loader_B) + return self + + def __next__(self): + A, A_paths = next(self.data_loader_A_iter) + B, B_paths = next(self.data_loader_B_iter) + return {'A': A, 'A_paths': A_paths, + 'B': B, 'B_paths': B_paths} + + +class UnalignedDataLoader(BaseDataLoader): + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + transform = transforms.Compose([ + transforms.Scale(opt.loadSize), + transforms.CenterCrop(opt.fineSize), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))]) + + # Dataset A + dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A', + transform=transform, return_paths=True) + data_loader_A = torch.utils.data.DataLoader( + dataset_A, + batch_size=self.opt.batchSize, + shuffle=not self.opt.serial_batches, + num_workers=int(self.opt.nThreads)) + + # Dataset B + dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B', + transform=transform, return_paths=True) + data_loader_B = torch.utils.data.DataLoader( + dataset_B, + batch_size=self.opt.batchSize, + shuffle=not self.opt.serial_batches, + num_workers=int(self.opt.nThreads)) + self.dataset_A = dataset_A + self.dataset_B = dataset_B + self.paired_data = PairedData(data_loader_A, data_loader_B) + + def name(self): + return 'UnalignedDataLoader' + + def load_data(self): + return self.paired_data + + def __len__(self): + return len(self.dataset_A) |
