diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-06-12 23:52:56 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-06-12 23:52:56 -0700 |
| commit | e6858e35f0a08c6139c133122d222d0d85e8005d (patch) | |
| tree | 2647ff13a164c684113eab455123394a49a65dad /data/unaligned_data_loader.py | |
| parent | 3b72a659c38141e502b74bee65ca08d51dc3eabf (diff) | |
update dataset mode
Diffstat (limited to 'data/unaligned_data_loader.py')
| -rw-r--r-- | data/unaligned_data_loader.py | 100 |
1 files changed, 0 insertions, 100 deletions
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py deleted file mode 100644 index bd0ea75..0000000 --- a/data/unaligned_data_loader.py +++ /dev/null @@ -1,100 +0,0 @@ -import random -import torch.utils.data -import torchvision.transforms as transforms -from data.base_data_loader import BaseDataLoader -from data.image_folder import ImageFolder -# pip install future --upgrade -from builtins import object -from pdb import set_trace as st - -class PairedData(object): - def __init__(self, data_loader_A, data_loader_B, max_dataset_size, flip): - self.data_loader_A = data_loader_A - self.data_loader_B = data_loader_B - self.stop_A = False - self.stop_B = False - self.max_dataset_size = max_dataset_size - self.flip = flip - - def __iter__(self): - self.stop_A = False - self.stop_B = False - self.data_loader_A_iter = iter(self.data_loader_A) - self.data_loader_B_iter = iter(self.data_loader_B) - self.iter = 0 - return self - - def __next__(self): - A, A_paths = None, None - B, B_paths = None, None - try: - A, A_paths = next(self.data_loader_A_iter) - except StopIteration: - if A is None or A_paths is None: - self.stop_A = True - self.data_loader_A_iter = iter(self.data_loader_A) - A, A_paths = next(self.data_loader_A_iter) - - try: - B, B_paths = next(self.data_loader_B_iter) - except StopIteration: - if B is None or B_paths is None: - self.stop_B = True - self.data_loader_B_iter = iter(self.data_loader_B) - B, B_paths = next(self.data_loader_B_iter) - - if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size: - self.stop_A = False - self.stop_B = False - raise StopIteration() - else: - self.iter += 1 - if self.flip and random.random() < 0.5: - idx = [i for i in range(A.size(3) - 1, -1, -1)] - idx = torch.LongTensor(idx) - A = A.index_select(3, idx) - B = B.index_select(3, idx) - return {'A': A, 'A_paths': A_paths, - 'B': B, 'B_paths': B_paths} - -class UnalignedDataLoader(BaseDataLoader): - def initialize(self, opt): - BaseDataLoader.initialize(self, opt) - transformations = [transforms.Scale(opt.loadSize), - transforms.RandomCrop(opt.fineSize), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - transform = transforms.Compose(transformations) - - # Dataset A - dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A', - transform=transform, return_paths=True) - data_loader_A = torch.utils.data.DataLoader( - dataset_A, - batch_size=self.opt.batchSize, - shuffle=not self.opt.serial_batches, - num_workers=int(self.opt.nThreads)) - - # Dataset B - dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B', - transform=transform, return_paths=True) - data_loader_B = torch.utils.data.DataLoader( - dataset_B, - batch_size=self.opt.batchSize, - shuffle=not self.opt.serial_batches, - num_workers=int(self.opt.nThreads)) - self.dataset_A = dataset_A - self.dataset_B = dataset_B - flip = opt.isTrain and not opt.no_flip - self.paired_data = PairedData(data_loader_A, data_loader_B, - self.opt.max_dataset_size, flip) - - def name(self): - return 'UnalignedDataLoader' - - def load_data(self): - return self.paired_data - - def __len__(self): - return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size) |
