 data/base_dataset.py      | 37
 data/single_dataset.py    | 18
 data/unaligned_dataset.py | 19
 options/base_options.py   |  2
 util/visualizer.py        | 14
 5 files changed, 50 insertions(+), 40 deletions(-)
diff --git a/data/base_dataset.py b/data/base_dataset.py
index 49b9d98..a061a05 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -1,12 +1,45 @@
import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()
-
+
    def name(self):
        return 'BaseDataset'
-
+
    def initialize(self, opt):
        pass
+
+
+def get_transform(opt):
+    transform_list = []
+    if opt.resize_or_crop == 'resize_and_crop':
+        osize = [opt.loadSize, opt.loadSize]
+        transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'crop':
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'scale_width':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.fineSize)))
+    elif opt.resize_or_crop == 'scale_width_and_crop':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.loadSize)))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+    if opt.isTrain and not opt.no_flip:
+        transform_list.append(transforms.RandomHorizontalFlip())
+
+    transform_list += [transforms.ToTensor(),
+                       transforms.Normalize((0.5, 0.5, 0.5),
+                                            (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+def __scale_width(img, target_width):
+    ow, oh = img.size
+    if ow == target_width:
+        return img
+    w = target_width
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), Image.BICUBIC)
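Note: get_transform builds the whole preprocessing pipeline from the options object, so it can be exercised on its own. A minimal usage sketch follows (illustrative, not part of this commit; the opt values mirror the repository's usual loadSize/fineSize defaults, and 'example.jpg' is a placeholder). On newer torchvision releases, transforms.Scale has been renamed transforms.Resize.

    # Minimal sketch: drive get_transform with a hand-built options namespace.
    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_transform

    opt = Namespace(resize_or_crop='resize_and_crop',
                    loadSize=286, fineSize=256,
                    isTrain=True, no_flip=False)
    transform = get_transform(opt)
    img = Image.open('example.jpg').convert('RGB')  # placeholder input image
    tensor = transform(img)  # 3 x 256 x 256 tensor, normalized to [-1, 1]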
diff --git a/data/single_dataset.py b/data/single_dataset.py
index 106bea3..faf416a 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -1,6 +1,6 @@
import os.path
import torchvision.transforms as transforms
-from data.base_dataset import BaseDataset
+from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
@@ -15,21 +15,7 @@ class SingleDataset(BaseDataset):
        self.A_paths = sorted(self.A_paths)

-        transform_list = []
-        if opt.resize_or_crop == 'resize_and_crop':
-            transform_list.append(transforms.Scale(opt.loadSize))
-
-        if opt.isTrain and not opt.no_flip:
-            transform_list.append(transforms.RandomHorizontalFlip())
-
-        if opt.resize_or_crop != 'no_resize':
-            transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-        transform_list += [transforms.ToTensor(),
-                           transforms.Normalize((0.5, 0.5, 0.5),
-                                                (0.5, 0.5, 0.5))]
-
-        self.transform = transforms.Compose(transform_list)
+        self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index]
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index 7333d16..3864bf3 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -1,6 +1,6 @@
import os.path
import torchvision.transforms as transforms
-from data.base_dataset import BaseDataset
+from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import PIL
@@ -21,22 +21,7 @@ class UnalignedDataset(BaseDataset):
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
-
-        transform_list = []
-        if opt.resize_or_crop == 'resize_and_crop':
-            osize = [opt.loadSize, opt.loadSize]
-            transform_list.append(transforms.Scale(osize, Image.BICUBIC))
-
-        if opt.isTrain and not opt.no_flip:
-            transform_list.append(transforms.RandomHorizontalFlip())
-
-        if opt.resize_or_crop != 'no_resize':
-            transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-        transform_list += [transforms.ToTensor(),
-                           transforms.Normalize((0.5, 0.5, 0.5),
-                                                (0.5, 0.5, 0.5))]
-        self.transform = transforms.Compose(transform_list)
+        self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
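The surviving context line above shows the unaligned indexing scheme: the A and B sets may differ in size, so each side's index is typically taken modulo its own length and the smaller set wraps around. A standalone sketch of that wrapping (sizes are illustrative):

    # Index wrapping for unaligned A/B pairs (illustrative sizes).
    A_size, B_size = 3, 5
    for index in range(max(A_size, B_size)):  # one pass over the larger set
        a_idx = index % A_size                # the smaller set repeats
        b_idx = index % B_size
        print(index, a_idx, b_idx)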
diff --git a/options/base_options.py b/options/base_options.py
index 909136f..b5b92fb 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -37,7 +37,7 @@ class BaseOptions():
        self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting --identity to a value other than 1 has the effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set --identity 0.1')
        self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
        self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
-        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width]')
+        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        self.initialized = True
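With the new value wired into get_transform, a training run can select it from the command line; an illustrative invocation (dataset path and sizes are placeholders):

    python train.py --dataroot ./datasets/horse2zebra --resize_or_crop scale_width_and_crop --loadSize 286 --fineSize 256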
diff --git a/util/visualizer.py b/util/visualizer.py
index 38b3bac..3733525 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -31,6 +31,11 @@ class Visualizer():
    def display_current_results(self, visuals, epoch):
        if self.display_id > 0:  # show images in the browser
            if self.display_single_pane_ncols > 0:
+                h, w = next(iter(visuals.values())).shape[:2]
+                table_css = """<style>
+        table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
+        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+</style>""" % (w, h)
                ncols = self.display_single_pane_ncols
                title = self.name
                label_html = ''
@@ -45,17 +50,18 @@ class Visualizer():
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
+                white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
                while idx % ncols != 0:
-                    white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
+                # pane col = image row
                self.vis.images(images, nrow=ncols, win=self.display_id + 1,
-                               opts=dict(title=title + ' images'))  # pane col = image row
-                label_html = '<table style="border-collapse:separate;border-spacing:10px;">%s</table' % label_html
-                self.vis.text(label_html, win=self.display_id + 2,
+                               padding=2, opts=dict(title=title + ' images'))
+                label_html = '<table>%s</table>' % label_html
+                self.vis.text(table_css + label_html, win=self.display_id + 2,
                              opts=dict(title=title + ' labels'))
            else:
                idx = 1
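Hoisting white_image out of the while-loop builds the padding tile once instead of on every iteration; the loop then tops up the final grid row to a multiple of ncols so visdom renders a clean rectangle. A standalone sketch of that padding step (array shapes are illustrative):

    # Pad an image list with white tiles up to a multiple of ncols.
    import numpy as np

    ncols = 4
    images = [np.zeros((3, 8, 8))] * 6           # pretend we have six visuals
    idx = len(images)
    white_image = np.ones_like(images[0]) * 255  # built once, reused per pad
    while idx % ncols != 0:
        images.append(white_image)
        idx += 1
    assert len(images) == 8                      # six visuals -> two full rows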