diff options
| -rw-r--r-- | data/base_data_loader.py | 6 | ||||
| -rw-r--r-- | data/base_dataset.py | 3 | ||||
| -rw-r--r-- | data/data_loader.py | 1 | ||||
| -rw-r--r-- | data/single_dataset.py | 1 | ||||
| -rw-r--r-- | data/unaligned_dataset.py | 4 | ||||
| -rw-r--r-- | models/cycle_gan_model.py | 9 | ||||
| -rw-r--r-- | models/models.py | 1 | ||||
| -rw-r--r-- | models/networks.py | 5 | ||||
| -rw-r--r-- | models/pix2pix_model.py | 2 | ||||
| -rw-r--r-- | options/base_options.py | 5 | ||||
| -rw-r--r-- | options/train_options.py | 4 | ||||
| -rw-r--r-- | test.py | 1 | ||||
| -rw-r--r-- | train.py | 2 | ||||
| -rw-r--r-- | util/image_pool.py | 3 | ||||
| -rw-r--r-- | util/util.py | 4 | ||||
| -rw-r--r-- | util/visualizer.py | 2 |
16 files changed, 17 insertions, 36 deletions
diff --git a/data/base_data_loader.py b/data/base_data_loader.py index 0e1deb5..ae5a168 100644 --- a/data/base_data_loader.py +++ b/data/base_data_loader.py @@ -1,14 +1,10 @@ - class BaseDataLoader(): def __init__(self): pass - + def initialize(self, opt): self.opt = opt pass def load_data(): return None - - - diff --git a/data/base_dataset.py b/data/base_dataset.py index a061a05..7cfac54 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -2,6 +2,7 @@ import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms + class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() @@ -12,6 +13,7 @@ class BaseDataset(data.Dataset): def initialize(self, opt): pass + def get_transform(opt): transform_list = [] if opt.resize_or_crop == 'resize_and_crop': @@ -36,6 +38,7 @@ def get_transform(opt): (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) + def __scale_width(img, target_width): ow, oh = img.size if (ow == target_width): diff --git a/data/data_loader.py b/data/data_loader.py index 2a4433a..22b6a8f 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -1,4 +1,3 @@ - def CreateDataLoader(opt): from data.custom_dataset_data_loader import CustomDatasetDataLoader data_loader = CustomDatasetDataLoader() diff --git a/data/single_dataset.py b/data/single_dataset.py index f8b4f1d..12083b1 100644 --- a/data/single_dataset.py +++ b/data/single_dataset.py @@ -1,5 +1,4 @@ import os.path -import torchvision.transforms as transforms from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index ad0c11b..2f59b2a 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -1,11 +1,10 @@ import os.path -import torchvision.transforms as transforms from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image -import PIL import random + class UnalignedDataset(BaseDataset): def initialize(self, opt): self.opt = opt @@ -24,7 +23,6 @@ class UnalignedDataset(BaseDataset): def __getitem__(self, index): A_path = self.A_paths[index % self.A_size] - index_A = index % self.A_size if self.opt.serial_batches: index_B = index % self.B_size else: diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index b7b840d..bcc6a15 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -1,6 +1,4 @@ -import numpy as np import torch -import os from collections import OrderedDict from torch.autograd import Variable import itertools @@ -8,7 +6,6 @@ import util.util as util from util.image_pool import ImagePool from .base_model import BaseModel from . import networks -import sys class CycleGANModel(BaseModel): @@ -17,10 +14,6 @@ class CycleGANModel(BaseModel): def initialize(self, opt): BaseModel.initialize(self, opt) - - nb = opt.batchSize - size = opt.fineSize - # load/define networks # The naming conversion is different from those used in the paper # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) @@ -200,7 +193,7 @@ class CycleGANModel(BaseModel): def get_current_errors(self): ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A), - ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) + ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) if self.opt.identity > 0.0: ret_errors['idt_A'] = self.loss_idt_A ret_errors['idt_B'] = self.loss_idt_B diff --git a/models/models.py b/models/models.py index d5bb9d8..39cc020 100644 --- a/models/models.py +++ b/models/models.py @@ -1,4 +1,3 @@ - def create_model(opt): model = None print(opt.model) diff --git a/models/networks.py b/models/networks.py index da2f59c..b118c6a 100644 --- a/models/networks.py +++ b/models/networks.py @@ -4,7 +4,6 @@ from torch.nn import init import functools from torch.autograd import Variable from torch.optim import lr_scheduler -import numpy as np ############################################################################### # Functions ############################################################################### @@ -434,6 +433,7 @@ class NLayerDiscriminator(nn.Module): else: return self.model(input) + class PixelDiscriminator(nn.Module): def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]): super(PixelDiscriminator, self).__init__() @@ -442,7 +442,7 @@ class PixelDiscriminator(nn.Module): use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d - + self.net = [ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), nn.LeakyReLU(0.2, True), @@ -461,4 +461,3 @@ class PixelDiscriminator(nn.Module): return nn.parallel.data_parallel(self.net, input, self.gpu_ids) else: return self.net(input) - diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 74a941e..9c46a19 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -1,6 +1,4 @@ -import numpy as np import torch -import os from collections import OrderedDict from torch.autograd import Variable import util.util as util diff --git a/options/base_options.py b/options/base_options.py index 13466bf..ce58548 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -31,11 +31,12 @@ class BaseOptions(): self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') - self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') + self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') - self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), + help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') diff --git a/options/train_options.py b/options/train_options.py index 603d76a..f4627ce 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -25,6 +25,8 @@ class TrainOptions(BaseOptions): self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') - self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1') + self.parser.add_argument('--identity', type=float, default=0.5, + help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss.' + 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1') self.isTrain = True @@ -1,4 +1,3 @@ -import time import os from options.test_options import TestOptions from data.data_loader import CreateDataLoader @@ -35,7 +35,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): t = (time.time() - iter_start_time) / opt.batchSize visualizer.print_current_errors(epoch, epoch_iter, errors, t) if opt.display_id > 0: - visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) + visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors) if total_steps % opt.save_latest_freq == 0: print('saving the latest model (epoch %d, total_steps %d)' % diff --git a/util/image_pool.py b/util/image_pool.py index ada1627..634fd81 100644 --- a/util/image_pool.py +++ b/util/image_pool.py @@ -1,5 +1,4 @@ import random -import numpy as np import torch from torch.autograd import Variable @@ -24,7 +23,7 @@ class ImagePool(): else: p = random.uniform(0, 1) if p > 0.5: - random_id = random.randint(0, self.pool_size-1) + random_id = random.randint(0, self.pool_size - 1) tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) diff --git a/util/util.py b/util/util.py index 26b259a..7a452a6 100644 --- a/util/util.py +++ b/util/util.py @@ -2,11 +2,7 @@ from __future__ import print_function import torch import numpy as np from PIL import Image -import inspect -import re -import numpy as np import os -import collections # Converts a Tensor into a Numpy array diff --git a/util/visualizer.py b/util/visualizer.py index 8bec8df..b22f235 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -56,7 +56,7 @@ class Visualizer(): if idx % ncols == 0: label_html += '<tr>%s</tr>' % label_html_row label_html_row = '' - white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 + white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 while idx % ncols != 0: images.append(white_image) label_html_row += '<td></td>' |
