diff options
Diffstat (limited to 'data/aligned_data_loader.py')
| -rw-r--r-- | data/aligned_data_loader.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py new file mode 100644 index 0000000..01dbf89 --- /dev/null +++ b/data/aligned_data_loader.py @@ -0,0 +1,69 @@ +import random +import torch.utils.data +import torchvision.transforms as transforms +from data.base_data_loader import BaseDataLoader +from data.image_folder import ImageFolder +from pdb import set_trace as st +from builtins import object + +class PairedData(object): + def __init__(self, data_loader, fineSize): + self.data_loader = data_loader + self.fineSize = fineSize + # st() + + def __iter__(self): + self.data_loader_iter = iter(self.data_loader) + return self + + def __next__(self): + # st() + AB, AB_paths = next(self.data_loader_iter) + # st() + w_total = AB.size(3) + w = int(w_total / 2) + h = AB.size(2) + + w_offset = random.randint(0, max(0, w - self.fineSize - 1)) + h_offset = random.randint(0, max(0, h - self.fineSize - 1)) + + A = AB[:, :, h_offset:h_offset + self.fineSize, + w_offset:w_offset + self.fineSize] + B = AB[:, :, h_offset:h_offset + self.fineSize, + w + w_offset:w + w_offset + self.fineSize] + + return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths} + + +class AlignedDataLoader(BaseDataLoader): + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.fineSize = opt.fineSize + transform = transforms.Compose([ + # TODO: Scale + #transforms.Scale((opt.loadSize * 2, opt.loadSize)), + #transforms.CenterCrop(opt.fineSize), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))]) + + # Dataset A + dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase, + transform=transform, return_paths=True) + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.opt.batchSize, + shuffle=not self.opt.serial_batches, + num_workers=int(self.opt.nThreads)) + + self.dataset = dataset + self.paired_data = PairedData(data_loader, opt.fineSize) + + def name(self): + return 'AlignedDataLoader' + + def load_data(self): + return self.paired_data + + def __len__(self): + return len(self.dataset) |
