summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-05-07 23:23:22 -0700
committerTaesung Park <taesung_park@berkeley.edu>2017-05-07 23:23:22 -0700
commit68d0d0dfc9fc18ad65752bf01180cc1668255ba0 (patch)
tree307acc6fcc087de6dfc49f58c4e72bcfe959197a /data
parent5f6e2c4a115a6a706cc011b3bf9ed9e3ef149d98 (diff)
fixed a bug about flipping
Diffstat (limited to 'data')
-rw-r--r--data/aligned_data_loader.py19
-rw-r--r--data/unaligned_data_loader.py15
2 files changed, 26 insertions, 8 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index 039c113..d1d4572 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -1,4 +1,5 @@
import random
+import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
@@ -8,10 +9,11 @@ from pdb import set_trace as st
from builtins import object
class PairedData(object):
- def __init__(self, data_loader, fineSize, max_dataset_size):
+ def __init__(self, data_loader, fineSize, max_dataset_size, flip):
self.data_loader = data_loader
self.fineSize = fineSize
self.max_dataset_size = max_dataset_size
+ self.flip = flip
# st()
def __iter__(self):
@@ -36,6 +38,14 @@ class PairedData(object):
B = AB[:, :, h_offset:h_offset + self.fineSize,
w + w_offset:w + w_offset + self.fineSize]
+ if self.flip and random.random() < 0.5:
+ idx = [i for i in range(A.size(3) - 1, -1, -1)]
+ idx = torch.LongTensor(idx)
+ A = A.index_select(3, idx)
+ B = B.index_select(3, idx)
+
+
+
return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths}
@@ -50,8 +60,6 @@ class AlignedDataLoader(BaseDataLoader):
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
- if opt.isTrain and not opt.no_flip:
- transformations.insert(1, transforms.RandomHorizontalFlip())
transform = transforms.Compose(transformations)
# Dataset A
@@ -64,7 +72,10 @@ class AlignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset = dataset
- self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size)
+
+ flip = opt.isTrain and not opt.no_flip
+ self.paired_data = PairedData(data_loader, opt.fineSize,
+ opt.max_dataset_size, flip)
def name(self):
return 'AlignedDataLoader'
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 3deb55b..bd0ea75 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -1,3 +1,4 @@
+import random
import torch.utils.data
import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
@@ -7,12 +8,13 @@ from builtins import object
from pdb import set_trace as st
class PairedData(object):
- def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
+ def __init__(self, data_loader_A, data_loader_B, max_dataset_size, flip):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
self.stop_A = False
self.stop_B = False
self.max_dataset_size = max_dataset_size
+ self.flip = flip
def __iter__(self):
self.stop_A = False
@@ -47,6 +49,11 @@ class PairedData(object):
raise StopIteration()
else:
self.iter += 1
+ if self.flip and random.random() < 0.5:
+ idx = [i for i in range(A.size(3) - 1, -1, -1)]
+ idx = torch.LongTensor(idx)
+ A = A.index_select(3, idx)
+ B = B.index_select(3, idx)
return {'A': A, 'A_paths': A_paths,
'B': B, 'B_paths': B_paths}
@@ -58,8 +65,6 @@ class UnalignedDataLoader(BaseDataLoader):
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
- if opt.isTrain and not opt.no_flip:
- transformations.insert(1, transforms.RandomHorizontalFlip())
transform = transforms.Compose(transformations)
# Dataset A
@@ -81,7 +86,9 @@ class UnalignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset_A = dataset_A
self.dataset_B = dataset_B
- self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size)
+ flip = opt.isTrain and not opt.no_flip
+ self.paired_data = PairedData(data_loader_A, data_loader_B,
+ self.opt.max_dataset_size, flip)
def name(self):
return 'UnalignedDataLoader'