-rw-r--r--  data/base_dataset.py      | 37
-rw-r--r--  data/single_dataset.py    | 18
-rw-r--r--  data/unaligned_dataset.py | 19
-rw-r--r--  options/base_options.py   |  2
-rw-r--r--  util/visualizer.py        | 14

5 files changed, 50 insertions(+), 40 deletions(-)
diff --git a/data/base_dataset.py b/data/base_dataset.py
index 49b9d98..a061a05 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -1,12 +1,45 @@
 import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
 
 
 class BaseDataset(data.Dataset):
     def __init__(self):
         super(BaseDataset, self).__init__()
-
+
     def name(self):
         return 'BaseDataset'
-
+
     def initialize(self, opt):
         pass
+def get_transform(opt):
+    transform_list = []
+    if opt.resize_or_crop == 'resize_and_crop':
+        osize = [opt.loadSize, opt.loadSize]
+        transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'crop':
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'scale_width':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.fineSize)))
+    elif opt.resize_or_crop == 'scale_width_and_crop':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.loadSize)))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+    if opt.isTrain and not opt.no_flip:
+        transform_list.append(transforms.RandomHorizontalFlip())
+
+    transform_list += [transforms.ToTensor(),
+                       transforms.Normalize((0.5, 0.5, 0.5),
+                                            (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+def __scale_width(img, target_width):
+    ow, oh = img.size
+    if (ow == target_width):
+        return img
+    w = target_width
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), Image.BICUBIC)
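The hunk above consolidates the transform pipelines that were previously duplicated across the dataset classes into a single get_transform helper, with __scale_width providing an aspect-preserving resize for the two new scale_width modes. As a rough illustration of how the helper behaves, a sketch that builds a transform by hand; the Namespace stands in for the object that BaseOptions.parse() returns, and the image path is hypothetical (note also that transforms.Scale was later renamed transforms.Resize in torchvision):

    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_transform

    # Assumed option values; in the training pipeline these come from BaseOptions.
    opt = Namespace(resize_or_crop='resize_and_crop', loadSize=286,
                    fineSize=256, isTrain=True, no_flip=False)

    transform = get_transform(opt)
    img = Image.open('example.jpg').convert('RGB')  # hypothetical input image
    out = transform(img)                            # 3 x 256 x 256 FloatTensor
    # ToTensor maps pixels to [0, 1]; Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # then shifts that to [-1, 1], matching the generator's tanh output range.
    print(out.size())

In the scale_width mode there is no crop, so the output height follows the source aspect ratio and downstream code cannot assume square tensors.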
diff --git a/data/single_dataset.py b/data/single_dataset.py
index 106bea3..faf416a 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -1,6 +1,6 @@
 import os.path
 import torchvision.transforms as transforms
-from data.base_dataset import BaseDataset
+from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
 
@@ -15,21 +15,7 @@ class SingleDataset(BaseDataset):
 
         self.A_paths = sorted(self.A_paths)
 
-        transform_list = []
-        if opt.resize_or_crop == 'resize_and_crop':
-            transform_list.append(transforms.Scale(opt.loadSize))
-
-        if opt.isTrain and not opt.no_flip:
-            transform_list.append(transforms.RandomHorizontalFlip())
-
-        if opt.resize_or_crop != 'no_resize':
-            transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-        transform_list += [transforms.ToTensor(),
-                           transforms.Normalize((0.5, 0.5, 0.5),
-                                                (0.5, 0.5, 0.5))]
-
-        self.transform = transforms.Compose(transform_list)
+        self.transform = get_transform(opt)
 
     def __getitem__(self, index):
         A_path = self.A_paths[index]
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index 7333d16..3864bf3 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -1,6 +1,6 @@
 import os.path
 import torchvision.transforms as transforms
-from data.base_dataset import BaseDataset
+from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
 import PIL
@@ -21,22 +21,7 @@ class UnalignedDataset(BaseDataset):
         self.B_paths = sorted(self.B_paths)
         self.A_size = len(self.A_paths)
         self.B_size = len(self.B_paths)
-
-        transform_list = []
-        if opt.resize_or_crop == 'resize_and_crop':
-            osize = [opt.loadSize, opt.loadSize]
-            transform_list.append(transforms.Scale(osize, Image.BICUBIC))
-
-        if opt.isTrain and not opt.no_flip:
-            transform_list.append(transforms.RandomHorizontalFlip())
-
-        if opt.resize_or_crop != 'no_resize':
-            transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-        transform_list += [transforms.ToTensor(),
-                           transforms.Normalize((0.5, 0.5, 0.5),
-                                                (0.5, 0.5, 0.5))]
-        self.transform = transforms.Compose(transform_list)
+        self.transform = get_transform(opt)
 
     def __getitem__(self, index):
         A_path = self.A_paths[index % self.A_size]
diff --git a/options/base_options.py b/options/base_options.py
index 909136f..b5b92fb 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -37,7 +37,7 @@ class BaseOptions():
         self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
         self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
-        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width]')
+        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
         self.initialized = True
 
diff --git a/util/visualizer.py b/util/visualizer.py
index 38b3bac..3733525 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -31,6 +31,11 @@ class Visualizer():
     def display_current_results(self, visuals, epoch):
         if self.display_id > 0:  # show images in the browser
             if self.display_single_pane_ncols > 0:
+                h, w = next(iter(visuals.values())).shape[:2]
+                table_css = """<style>
+    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
+    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+</style>""" % (w, h)
                 ncols = self.display_single_pane_ncols
                 title = self.name
                 label_html = ''
@@ -45,17 +50,18 @@
                     if idx % ncols == 0:
                         label_html += '<tr>%s</tr>' % label_html_row
                         label_html_row = ''
+                white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
                 while idx % ncols != 0:
-                    white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
                     images.append(white_image)
                     label_html_row += '<td></td>'
                     idx += 1
                 if label_html_row != '':
                     label_html += '<tr>%s</tr>' % label_html_row
+                # pane col = image row
                 self.vis.images(images, nrow=ncols, win=self.display_id + 1,
-                                opts=dict(title=title + ' images'))  # pane col = image row
-                label_html = '<table style="border-collapse:separate;border-spacing:10px;">%s</table' % label_html
-                self.vis.text(label_html, win = self.display_id + 2,
+                                padding=2, opts=dict(title=title + ' images'))
+                label_html = '<table>%s</table>' % label_html
+                self.vis.text(table_css + label_html, win = self.display_id + 2,
                               opts=dict(title=title + ' labels'))
             else:
                 idx = 1
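In the visualizer hunk, the label table's cell dimensions now come from the first image in visuals, and the styling moves from an inline style attribute into a <style> block prepended to the HTML that vis.text renders; hoisting white_image out of the while loop also avoids rebuilding the same filler array on every pass. A standalone sketch of the same pattern, assuming a visdom server is running; the window ids, labels, and 64 x 64 image size are made up:

    import numpy as np
    import visdom

    vis = visdom.Visdom()  # assumes a server on the default localhost:8097

    # Three fake 3 x 64 x 64 images plus one white filler, as the padding loop builds.
    images = [np.random.rand(3, 64, 64) * 255 for _ in range(3)]
    images.append(np.ones_like(images[0]) * 255)

    ncols = 2
    h, w = 64, 64
    table_css = """<style>
        table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
    </style>""" % (w, h)
    label_html = ('<table><tr><td>real_A</td><td>fake_B</td></tr>'
                  '<tr><td>rec_A</td><td></td></tr></table>')

    # pane col = image row: nrow is the number of images per grid row
    vis.images(images, nrow=ncols, padding=2, win=1, opts=dict(title='demo images'))
    vis.text(table_css + label_html, win=2, opts=dict(title='demo labels'))

The fixed td width and height size each label cell to match an image tile, which the old inline border-spacing alone could not guarantee, since unstyled table cells shrink to the width of their text.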
