summaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-05-07 22:18:20 -0700
committerTaesung Park <taesung_park@berkeley.edu>2017-05-07 22:18:20 -0700
commit5f6e2c4a115a6a706cc011b3bf9ed9e3ef149d98 (patch)
tree29708b3526fcb354893982eab0b8d003c63bb12e /data
parent349614a2f168654ba59bf1461ea61e1cb9358eb6 (diff)
1. Added flipping functionality
2. Changed the default options
Diffstat (limited to 'data')
-rw-r--r--data/aligned_data_loader.py8
-rw-r--r--data/unaligned_data_loader.py14
2 files changed, 14 insertions, 8 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index a1efde8..039c113 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -43,12 +43,16 @@ class AlignedDataLoader(BaseDataLoader):
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.fineSize = opt.fineSize
- transform = transforms.Compose([
+
+ transformations = [
# TODO: Scale
transforms.Scale(opt.loadSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
- (0.5, 0.5, 0.5))])
+ (0.5, 0.5, 0.5))]
+ if opt.isTrain and not opt.no_flip:
+ transformations.insert(1, transforms.RandomHorizontalFlip())
+ transform = transforms.Compose(transformations)
# Dataset A
dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 77f9274..3deb55b 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -53,12 +53,14 @@ class PairedData(object):
class UnalignedDataLoader(BaseDataLoader):
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
- transform = transforms.Compose([
- transforms.Scale(opt.loadSize),
- transforms.RandomCrop(opt.fineSize),
- transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5),
- (0.5, 0.5, 0.5))])
+ transformations = [transforms.Scale(opt.loadSize),
+ transforms.RandomCrop(opt.fineSize),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ if opt.isTrain and not opt.no_flip:
+ transformations.insert(1, transforms.RandomHorizontalFlip())
+ transform = transforms.Compose(transformations)
# Dataset A
dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A',