summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjunyanz <junyanzhu89@gmail.com>2018-02-09 11:39:35 -0500
committerjunyanz <junyanzhu89@gmail.com>2018-02-09 11:39:35 -0500
commit0ae4f0500e415a6a67689ef9356e8e4779ae5833 (patch)
tree25392e96a1b64c8454f7f548886af7dd48aa6bd0
parent7a5e2cd5f5003e8ca9a0fc3dac14a74b81287881 (diff)
code reformatting
-rw-r--r--data/base_data_loader.py6
-rw-r--r--data/base_dataset.py3
-rw-r--r--data/data_loader.py1
-rw-r--r--data/single_dataset.py1
-rw-r--r--data/unaligned_dataset.py4
-rw-r--r--models/cycle_gan_model.py9
-rw-r--r--models/models.py1
-rw-r--r--models/networks.py5
-rw-r--r--models/pix2pix_model.py2
-rw-r--r--options/base_options.py5
-rw-r--r--options/train_options.py4
-rw-r--r--test.py1
-rw-r--r--train.py2
-rw-r--r--util/image_pool.py3
-rw-r--r--util/util.py4
-rw-r--r--util/visualizer.py2
16 files changed, 17 insertions, 36 deletions
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
index 0e1deb5..ae5a168 100644
--- a/data/base_data_loader.py
+++ b/data/base_data_loader.py
@@ -1,14 +1,10 @@
-
class BaseDataLoader():
def __init__(self):
pass
-
+
def initialize(self, opt):
self.opt = opt
pass
def load_data():
return None
-
-
-
diff --git a/data/base_dataset.py b/data/base_dataset.py
index a061a05..7cfac54 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -2,6 +2,7 @@ import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
+
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
@@ -12,6 +13,7 @@ class BaseDataset(data.Dataset):
def initialize(self, opt):
pass
+
def get_transform(opt):
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
@@ -36,6 +38,7 @@ def get_transform(opt):
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
+
def __scale_width(img, target_width):
ow, oh = img.size
if (ow == target_width):
diff --git a/data/data_loader.py b/data/data_loader.py
index 2a4433a..22b6a8f 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,4 +1,3 @@
-
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
diff --git a/data/single_dataset.py b/data/single_dataset.py
index f8b4f1d..12083b1 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -1,5 +1,4 @@
import os.path
-import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index ad0c11b..2f59b2a 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -1,11 +1,10 @@
import os.path
-import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
-import PIL
import random
+
class UnalignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
@@ -24,7 +23,6 @@ class UnalignedDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index % self.A_size]
- index_A = index % self.A_size
if self.opt.serial_batches:
index_B = index % self.B_size
else:
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b7b840d..bcc6a15 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -1,6 +1,4 @@
-import numpy as np
import torch
-import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
@@ -8,7 +6,6 @@ import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
-import sys
class CycleGANModel(BaseModel):
@@ -17,10 +14,6 @@ class CycleGANModel(BaseModel):
def initialize(self, opt):
BaseModel.initialize(self, opt)
-
- nb = opt.batchSize
- size = opt.fineSize
-
# load/define networks
# The naming conversion is different from those used in the paper
# Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
@@ -200,7 +193,7 @@ class CycleGANModel(BaseModel):
def get_current_errors(self):
ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
- ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
+ ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
if self.opt.identity > 0.0:
ret_errors['idt_A'] = self.loss_idt_A
ret_errors['idt_B'] = self.loss_idt_B
diff --git a/models/models.py b/models/models.py
index d5bb9d8..39cc020 100644
--- a/models/models.py
+++ b/models/models.py
@@ -1,4 +1,3 @@
-
def create_model(opt):
model = None
print(opt.model)
diff --git a/models/networks.py b/models/networks.py
index da2f59c..b118c6a 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -4,7 +4,6 @@ from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
-import numpy as np
###############################################################################
# Functions
###############################################################################
@@ -434,6 +433,7 @@ class NLayerDiscriminator(nn.Module):
else:
return self.model(input)
+
class PixelDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(PixelDiscriminator, self).__init__()
@@ -442,7 +442,7 @@ class PixelDiscriminator(nn.Module):
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
-
+
self.net = [
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
nn.LeakyReLU(0.2, True),
@@ -461,4 +461,3 @@ class PixelDiscriminator(nn.Module):
return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
else:
return self.net(input)
-
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 74a941e..9c46a19 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -1,6 +1,4 @@
-import numpy as np
import torch
-import os
from collections import OrderedDict
from torch.autograd import Variable
import util.util as util
diff --git a/options/base_options.py b/options/base_options.py
index 13466bf..ce58548 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -31,11 +31,12 @@ class BaseOptions():
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
- self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
+ self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
- self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
+ help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
diff --git a/options/train_options.py b/options/train_options.py
index 603d76a..f4627ce 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -25,6 +25,8 @@ class TrainOptions(BaseOptions):
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
- self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
+ self.parser.add_argument('--identity', type=float, default=0.5,
+ help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss.'
+ 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
self.isTrain = True
diff --git a/test.py b/test.py
index 65e79d7..863e550 100644
--- a/test.py
+++ b/test.py
@@ -1,4 +1,3 @@
-import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
diff --git a/train.py b/train.py
index 6dbd66b..f6072c7 100644
--- a/train.py
+++ b/train.py
@@ -35,7 +35,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
t = (time.time() - iter_start_time) / opt.batchSize
visualizer.print_current_errors(epoch, epoch_iter, errors, t)
if opt.display_id > 0:
- visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
+ visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
if total_steps % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_steps %d)' %
diff --git a/util/image_pool.py b/util/image_pool.py
index ada1627..634fd81 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -1,5 +1,4 @@
import random
-import numpy as np
import torch
from torch.autograd import Variable
@@ -24,7 +23,7 @@ class ImagePool():
else:
p = random.uniform(0, 1)
if p > 0.5:
- random_id = random.randint(0, self.pool_size-1)
+ random_id = random.randint(0, self.pool_size - 1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
diff --git a/util/util.py b/util/util.py
index 26b259a..7a452a6 100644
--- a/util/util.py
+++ b/util/util.py
@@ -2,11 +2,7 @@ from __future__ import print_function
import torch
import numpy as np
from PIL import Image
-import inspect
-import re
-import numpy as np
import os
-import collections
# Converts a Tensor into a Numpy array
diff --git a/util/visualizer.py b/util/visualizer.py
index 8bec8df..b22f235 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -56,7 +56,7 @@ class Visualizer():
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
- white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'