author    SsnL <tongzhou.wang.1994@gmail.com>    2017-07-06 22:19:53 -0500
committer SsnL <tongzhou.wang.1994@gmail.com>    2017-07-06 22:31:24 -0500
commit    25124b8389f80d7a509b2d98ef69589cab597c9a (patch)
tree      185d876bb0fed0e681f163e79ad810e597c8dd8c /data
parent    ee0a8292e2b87449c325bdb9439f90f911a0c0a1 (diff)
resize_or_crop and better display single image
Diffstat (limited to 'data')
-rw-r--r--  data/base_dataset.py       37
-rw-r--r--  data/single_dataset.py     18
-rw-r--r--  data/unaligned_dataset.py  19
3 files changed, 39 insertions, 35 deletions
diff --git a/data/base_dataset.py b/data/base_dataset.py
index 49b9d98..a061a05 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -1,12 +1,45 @@
import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()
-
+
    def name(self):
        return 'BaseDataset'
-
+
    def initialize(self, opt):
        pass
+
+def get_transform(opt):
+    transform_list = []
+    if opt.resize_or_crop == 'resize_and_crop':
+        osize = [opt.loadSize, opt.loadSize]
+        transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'crop':
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'scale_width':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.fineSize)))
+    elif opt.resize_or_crop == 'scale_width_and_crop':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.loadSize)))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+    if opt.isTrain and not opt.no_flip:
+        transform_list.append(transforms.RandomHorizontalFlip())
+
+    transform_list += [transforms.ToTensor(),
+                       transforms.Normalize((0.5, 0.5, 0.5),
+                                            (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+def __scale_width(img, target_width):
+    ow, oh = img.size
+    if (ow == target_width):
+        return img
+    w = target_width
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), Image.BICUBIC)
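
Note: the new get_transform builds one torchvision pipeline from opt.resize_or_crop, so every dataset shares a single preprocessing path. A minimal usage sketch; the Namespace fields mirror the five options the function reads, but the values and the image path are illustrative, not part of this commit:

    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_transform

    # Hypothetical options object; get_transform only reads these five fields.
    opt = Namespace(resize_or_crop='resize_and_crop', loadSize=286,
                    fineSize=256, isTrain=True, no_flip=False)
    transform = get_transform(opt)
    img = Image.open('example.jpg').convert('RGB')  # placeholder image
    out = transform(img)  # 3 x fineSize x fineSize tensor, normalized to [-1, 1]
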
diff --git a/data/single_dataset.py b/data/single_dataset.py
index 106bea3..faf416a 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -1,6 +1,6 @@
import os.path
import torchvision.transforms as transforms
-from data.base_dataset import BaseDataset
+from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
@@ -15,21 +15,7 @@ class SingleDataset(BaseDataset):
        self.A_paths = sorted(self.A_paths)
-        transform_list = []
-        if opt.resize_or_crop == 'resize_and_crop':
-            transform_list.append(transforms.Scale(opt.loadSize))
-
-        if opt.isTrain and not opt.no_flip:
-            transform_list.append(transforms.RandomHorizontalFlip())
-
-        if opt.resize_or_crop != 'no_resize':
-            transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-        transform_list += [transforms.ToTensor(),
-                           transforms.Normalize((0.5, 0.5, 0.5),
-                                                (0.5, 0.5, 0.5))]
-
-        self.transform = transforms.Compose(transform_list)
+        self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index]
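
With the transform centralized, SingleDataset.__getitem__ only has to load an image and apply self.transform. A sketch of the surrounding method; only the two context lines above come from this diff, and the loading and return steps are an assumption about the rest of the file:

    def __getitem__(self, index):
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')  # assumed loading step
        A = self.transform(A_img)
        return {'A': A, 'A_paths': A_path}
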
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index 7333d16..3864bf3 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -1,6 +1,6 @@
import os.path
import torchvision.transforms as transforms
-from data.base_dataset import BaseDataset
+from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import PIL
@@ -21,22 +21,7 @@ class UnalignedDataset(BaseDataset):
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
-
-        transform_list = []
-        if opt.resize_or_crop == 'resize_and_crop':
-            osize = [opt.loadSize, opt.loadSize]
-            transform_list.append(transforms.Scale(osize, Image.BICUBIC))
-
-        if opt.isTrain and not opt.no_flip:
-            transform_list.append(transforms.RandomHorizontalFlip())
-
-        if opt.resize_or_crop != 'no_resize':
-            transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-        transform_list += [transforms.ToTensor(),
-                           transforms.Normalize((0.5, 0.5, 0.5),
-                                                (0.5, 0.5, 0.5))]
-        self.transform = transforms.Compose(transform_list)
+        self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
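
UnalignedDataset applies the same shared transform to both domains; indices wrap modulo each domain's size because the A and B image sets need not be equal in length. A hedged sketch of the method; only the A_path line above appears in this diff, and the B-side pairing shown is one plausible scheme, not necessarily the author's:

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index % self.B_size]  # assumed pairing rule
        A = self.transform(Image.open(A_path).convert('RGB'))
        B = self.transform(Image.open(B_path).convert('RGB'))
        return {'A': A, 'B': B,
                'A_paths': A_path, 'B_paths': B_path}
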