diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-05-07 23:23:22 -0700 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-05-07 23:23:22 -0700 |
| commit | 68d0d0dfc9fc18ad65752bf01180cc1668255ba0 (patch) | |
| tree | 307acc6fcc087de6dfc49f58c4e72bcfe959197a /data | |
| parent | 5f6e2c4a115a6a706cc011b3bf9ed9e3ef149d98 (diff) | |
fixed a bug about flipping
Diffstat (limited to 'data')
| -rw-r--r-- | data/aligned_data_loader.py | 19 | ||||
| -rw-r--r-- | data/unaligned_data_loader.py | 15 |
2 files changed, 26 insertions, 8 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py index 039c113..d1d4572 100644 --- a/data/aligned_data_loader.py +++ b/data/aligned_data_loader.py @@ -1,4 +1,5 @@ import random +import numpy as np import torch.utils.data import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader @@ -8,10 +9,11 @@ from pdb import set_trace as st from builtins import object class PairedData(object): - def __init__(self, data_loader, fineSize, max_dataset_size): + def __init__(self, data_loader, fineSize, max_dataset_size, flip): self.data_loader = data_loader self.fineSize = fineSize self.max_dataset_size = max_dataset_size + self.flip = flip # st() def __iter__(self): @@ -36,6 +38,14 @@ class PairedData(object): B = AB[:, :, h_offset:h_offset + self.fineSize, w + w_offset:w + w_offset + self.fineSize] + if self.flip and random.random() < 0.5: + idx = [i for i in range(A.size(3) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A = A.index_select(3, idx) + B = B.index_select(3, idx) + + + return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths} @@ -50,8 +60,6 @@ class AlignedDataLoader(BaseDataLoader): transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - if opt.isTrain and not opt.no_flip: - transformations.insert(1, transforms.RandomHorizontalFlip()) transform = transforms.Compose(transformations) # Dataset A @@ -64,7 +72,10 @@ class AlignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset = dataset - self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size) + + flip = opt.isTrain and not opt.no_flip + self.paired_data = PairedData(data_loader, opt.fineSize, + opt.max_dataset_size, flip) def name(self): return 'AlignedDataLoader' diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 3deb55b..bd0ea75 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -1,3 +1,4 @@ +import random import torch.utils.data import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader @@ -7,12 +8,13 @@ from builtins import object from pdb import set_trace as st class PairedData(object): - def __init__(self, data_loader_A, data_loader_B, max_dataset_size): + def __init__(self, data_loader_A, data_loader_B, max_dataset_size, flip): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B self.stop_A = False self.stop_B = False self.max_dataset_size = max_dataset_size + self.flip = flip def __iter__(self): self.stop_A = False @@ -47,6 +49,11 @@ class PairedData(object): raise StopIteration() else: self.iter += 1 + if self.flip and random.random() < 0.5: + idx = [i for i in range(A.size(3) - 1, -1, -1)] + idx = torch.LongTensor(idx) + A = A.index_select(3, idx) + B = B.index_select(3, idx) return {'A': A, 'A_paths': A_paths, 'B': B, 'B_paths': B_paths} @@ -58,8 +65,6 @@ class UnalignedDataLoader(BaseDataLoader): transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - if opt.isTrain and not opt.no_flip: - transformations.insert(1, transforms.RandomHorizontalFlip()) transform = transforms.Compose(transformations) # Dataset A @@ -81,7 +86,9 @@ class UnalignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset_A = dataset_A self.dataset_B = dataset_B - self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size) + flip = opt.isTrain and not opt.no_flip + self.paired_data = PairedData(data_loader_A, data_loader_B, + self.opt.max_dataset_size, flip) def name(self): return 'UnalignedDataLoader' |
