 README.md                 | 14
 data/aligned_dataset.py   | 22
 data/base_data_loader.py  |  6
 data/base_dataset.py      |  3
 data/data_loader.py       |  1
 data/single_dataset.py    |  1
 data/unaligned_dataset.py |  4
 models/cycle_gan_model.py | 16
 models/models.py          |  1
 models/networks.py        |  5
 models/pix2pix_model.py   |  3
 options/base_options.py   |  5
 options/train_options.py  |  4
 test.py                   |  1
 train.py                  |  8
 util/image_pool.py        |  3
 util/util.py              |  4
 util/visualizer.py        |  6
 18 files changed, 45 insertions(+), 62 deletions(-)
diff --git a/README.md b/README.md
index 02c2d36..2c93bfa 100644
--- a/README.md
+++ b/README.md
@@ -222,19 +222,21 @@ This will combine each pair of images (A,B) into a single image file, ready for
## Citation
If you use this code for your research, please cite our papers.
```
-@article{CycleGAN2017,
- title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
+@inproceedings{CycleGAN2017,
+ title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
- journal={arXiv preprint arXiv:1703.10593},
+ booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
year={2017}
}
-@article{pix2pix2016,
+
+@inproceedings{isola2017image,
title={Image-to-Image Translation with Conditional Adversarial Networks},
author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
- journal={arxiv},
- year={2016}
+ booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on},
+ year={2017}
}
+
```
## Related Projects
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 8899cb2..f153f26 100644
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -18,19 +18,17 @@ class AlignedDataset(BaseDataset):
def __getitem__(self, index):
AB_path = self.AB_paths[index]
AB = Image.open(AB_path).convert('RGB')
- AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
- AB = transforms.ToTensor()(AB)
+ w, h = AB.size
+ w2 = int(w / 2)
+ A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+ B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
+ A = transforms.ToTensor()(A)
+ B = transforms.ToTensor()(B)
+ w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
+ h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
- w_total = AB.size(2)
- w = int(w_total / 2)
- h = AB.size(1)
- w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
- h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
-
- A = AB[:, h_offset:h_offset + self.opt.fineSize,
- w_offset:w_offset + self.opt.fineSize]
- B = AB[:, h_offset:h_offset + self.opt.fineSize,
- w + w_offset:w + w_offset + self.opt.fineSize]
+ A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
+ B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
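Note on the hunk above: the rewritten `__getitem__` resizes each half of the side-by-side (A|B) image to `loadSize` and then crops both halves with a *shared* random offset, so the pair stays pixel-aligned. A minimal standalone sketch of the same idea (function name and the 286/256 defaults are illustrative, not copied from the repo):
```
import random
import torchvision.transforms as transforms
from PIL import Image

def paired_random_crop(ab_path, load_size=286, fine_size=256):
    # Split the side-by-side (A|B) image and resize each half to load_size.
    ab = Image.open(ab_path).convert('RGB')
    w, h = ab.size
    w2 = w // 2
    a = ab.crop((0, 0, w2, h)).resize((load_size, load_size), Image.BICUBIC)
    b = ab.crop((w2, 0, w, h)).resize((load_size, load_size), Image.BICUBIC)
    a, b = transforms.ToTensor()(a), transforms.ToTensor()(b)
    # One shared offset: cropping A and B independently would break alignment.
    x = random.randint(0, max(0, load_size - fine_size - 1))
    y = random.randint(0, max(0, load_size - fine_size - 1))
    a = a[:, y:y + fine_size, x:x + fine_size]
    b = b[:, y:y + fine_size, x:x + fine_size]
    norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return norm(a), norm(b)
```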
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
index 0e1deb5..ae5a168 100644
--- a/data/base_data_loader.py
+++ b/data/base_data_loader.py
@@ -1,14 +1,10 @@
-
class BaseDataLoader():
def __init__(self):
pass
-
+
def initialize(self, opt):
self.opt = opt
pass
    def load_data(self):
return None
-
-
-
diff --git a/data/base_dataset.py b/data/base_dataset.py
index a061a05..7cfac54 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -2,6 +2,7 @@ import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
+
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
@@ -12,6 +13,7 @@ class BaseDataset(data.Dataset):
def initialize(self, opt):
pass
+
def get_transform(opt):
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
@@ -36,6 +38,7 @@ def get_transform(opt):
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
+
def __scale_width(img, target_width):
ow, oh = img.size
if (ow == target_width):
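The hunk cuts off inside `__scale_width`; for context, a sketch of what that helper does (reconstructed under the assumption of a standard aspect-ratio-preserving resize):
```
from PIL import Image

def scale_width(img, target_width):
    # Resize so width == target_width while preserving the aspect ratio.
    ow, oh = img.size
    if ow == target_width:
        return img
    h = int(target_width * oh / ow)
    return img.resize((target_width, h), Image.BICUBIC)
```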
diff --git a/data/data_loader.py b/data/data_loader.py
index 2a4433a..22b6a8f 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,4 +1,3 @@
-
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
diff --git a/data/single_dataset.py b/data/single_dataset.py
index f8b4f1d..12083b1 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -1,5 +1,4 @@
import os.path
-import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index ad0c11b..2f59b2a 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -1,11 +1,10 @@
import os.path
-import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
-import PIL
import random
+
class UnalignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
@@ -24,7 +23,6 @@ class UnalignedDataset(BaseDataset):
def __getitem__(self, index):
A_path = self.A_paths[index % self.A_size]
- index_A = index % self.A_size
if self.opt.serial_batches:
index_B = index % self.B_size
else:
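The removed `index_A` was dead code; `A_path` already uses `index % self.A_size` directly. The sampling logic this hunk touches, as a self-contained sketch (function name illustrative): A is taken in order, while B is aligned by index only under `serial_batches` and otherwise drawn at random, which is what makes the dataset "unaligned".
```
import random

def pick_pair(index, a_paths, b_paths, serial_batches=False):
    # A is always taken in order; B is aligned by index only when
    # serial_batches is set, otherwise drawn at random (unpaired data).
    a_path = a_paths[index % len(a_paths)]
    if serial_batches:
        index_b = index % len(b_paths)
    else:
        index_b = random.randint(0, len(b_paths) - 1)
    return a_path, b_paths[index_b]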
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b7b840d..85432bb 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -1,6 +1,4 @@
-import numpy as np
import torch
-import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
@@ -8,7 +6,6 @@ import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
-import sys
class CycleGANModel(BaseModel):
@@ -17,10 +14,6 @@ class CycleGANModel(BaseModel):
def initialize(self, opt):
BaseModel.initialize(self, opt)
-
- nb = opt.batchSize
- size = opt.fineSize
-
# load/define networks
        # The naming convention is different from the one used in the paper
# Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
@@ -47,7 +40,6 @@ class CycleGANModel(BaseModel):
self.load_network(self.netD_B, 'D_B', which_epoch)
if self.isTrain:
- self.old_lr = opt.lr
self.fake_A_pool = ImagePool(opt.pool_size)
self.fake_B_pool = ImagePool(opt.pool_size)
# define loss functions
@@ -129,7 +121,7 @@ class CycleGANModel(BaseModel):
self.loss_D_B = loss_D_B.data[0]
def backward_G(self):
- lambda_idt = self.opt.identity
+ lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
# Identity loss
@@ -200,8 +192,8 @@ class CycleGANModel(BaseModel):
def get_current_errors(self):
ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
- ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
- if self.opt.identity > 0.0:
+ ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
+ if self.opt.lambda_identity > 0.0:
ret_errors['idt_A'] = self.loss_idt_A
ret_errors['idt_B'] = self.loss_idt_B
return ret_errors
@@ -215,7 +207,7 @@ class CycleGANModel(BaseModel):
rec_B = util.tensor2im(self.rec_B)
ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
- if self.opt.isTrain and self.opt.identity > 0.0:
+ if self.opt.isTrain and self.opt.lambda_identity > 0.0:
ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
return ret_visuals
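The rename from `opt.identity` to `opt.lambda_identity` matches the option moved into `options/train_options.py` below. For reference, a hedged sketch of how the identity term is typically weighted in `backward_G` (helper name and exact multipliers are assumptions based on the CycleGAN paper, not copied from this file):
```
import torch.nn as nn

criterion_idt = nn.L1Loss()

def identity_loss(net_g_a, net_g_b, real_a, real_b,
                  lambda_a=10.0, lambda_b=10.0, lambda_idt=0.5):
    # G_A should act as the identity on images already in domain B,
    # and G_B on images in domain A; lambda_idt rescales both terms
    # relative to the cycle-consistency weights lambda_A / lambda_B.
    if lambda_idt == 0.0:
        return 0.0
    loss_idt_a = criterion_idt(net_g_a(real_b), real_b) * lambda_b * lambda_idt
    loss_idt_b = criterion_idt(net_g_b(real_a), real_a) * lambda_a * lambda_idt
    return loss_idt_a + loss_idt_b
```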
diff --git a/models/models.py b/models/models.py
index d5bb9d8..39cc020 100644
--- a/models/models.py
+++ b/models/models.py
@@ -1,4 +1,3 @@
-
def create_model(opt):
model = None
print(opt.model)
diff --git a/models/networks.py b/models/networks.py
index da2f59c..b118c6a 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -4,7 +4,6 @@ from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
-import numpy as np
###############################################################################
# Functions
###############################################################################
@@ -434,6 +433,7 @@ class NLayerDiscriminator(nn.Module):
else:
return self.model(input)
+
class PixelDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(PixelDiscriminator, self).__init__()
@@ -442,7 +442,7 @@ class PixelDiscriminator(nn.Module):
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
-
+
self.net = [
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
nn.LeakyReLU(0.2, True),
@@ -461,4 +461,3 @@ class PixelDiscriminator(nn.Module):
return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
else:
return self.net(input)
-
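For context on the class this hunk touches: `PixelDiscriminator` is a 1x1 PatchGAN, meaning every convolution has `kernel_size=1`, so each output score judges a single pixel rather than a larger receptive field. A simplified sketch without the `gpu_ids`/`data_parallel` plumbing and the `functools.partial` norm handling (class name changed to signal it is not the repo's code):
```
import torch.nn as nn

class PixelGAN(nn.Module):
    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super().__init__()
        use_bias = norm_layer == nn.InstanceNorm2d
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # (N, 1, H, W): one real/fake score per pixel
```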
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 74a941e..78f8d69 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -1,6 +1,4 @@
-import numpy as np
import torch
-import os
from collections import OrderedDict
from torch.autograd import Variable
import util.util as util
@@ -32,7 +30,6 @@ class Pix2PixModel(BaseModel):
if self.isTrain:
self.fake_AB_pool = ImagePool(opt.pool_size)
- self.old_lr = opt.lr
# define loss functions
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionL1 = torch.nn.L1Loss()
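The two criteria defined here combine into the pix2pix generator objective. A hedged sketch of that combination (helper name and the `lambda_l1` parameter name are illustrative; the repo weights the L1 term with one of its lambda options):
```
def pix2pix_g_loss(criterion_gan, criterion_l1, pred_fake, fake_b, real_b,
                   lambda_l1=100.0):
    # Generator objective: fool D on the fake pair, plus an L1 term
    # pulling the translation toward the ground-truth B.
    loss_g_gan = criterion_gan(pred_fake, True)   # target: "real"
    loss_g_l1 = criterion_l1(fake_b, real_b) * lambda_l1
    return loss_g_gan + loss_g_l1
```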
diff --git a/options/base_options.py b/options/base_options.py
index 1b5ebaf..5bdd85e 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -31,12 +31,13 @@ class BaseOptions():
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
- self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
+ self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
- self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
+ help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
diff --git a/options/train_options.py b/options/train_options.py
index 603d76a..3d05a2b 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -21,10 +21,12 @@ class TrainOptions(BaseOptions):
        self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least-squares GAN; if set, the vanilla GAN loss is used instead')
self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
+ self.parser.add_argument('--lambda_identity', type=float, default=0.5,
+                                 help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. '
+ 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
- self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
self.isTrain = True
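With the rename, training runs pass the new flag on the command line; a hypothetical invocation (dataset path and experiment name are illustrative):
```
python train.py --dataroot ./datasets/horse2zebra --name horse2zebra_cyclegan \
                --model cycle_gan --lambda_identity 0.1
```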
diff --git a/test.py b/test.py
index 65e79d7..863e550 100644
--- a/test.py
+++ b/test.py
@@ -1,4 +1,3 @@
-import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
diff --git a/train.py b/train.py
index 6dbd66b..61b596a 100644
--- a/train.py
+++ b/train.py
@@ -16,10 +16,13 @@ total_steps = 0
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
+ iter_data_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
+ if total_steps % opt.print_freq == 0:
+ t_data = iter_start_time - iter_data_time
visualizer.reset()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
@@ -33,15 +36,16 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors()
t = (time.time() - iter_start_time) / opt.batchSize
- visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+ visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
if opt.display_id > 0:
- visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
+ visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
if total_steps % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_steps %d)' %
(epoch, total_steps))
model.save('latest')
+ iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, total_steps))
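train.py now also measures how long each batch spent in the data loader (`t_data`), separately from compute time. The timing pattern, as a self-contained sketch (`loader` and `step` stand in for the dataset iterator and the model update):
```
import time

def timed_epoch(loader, step):
    iter_data_time = time.time()                   # clock starts before the loader runs
    for data in loader:
        iter_start_time = time.time()
        t_data = iter_start_time - iter_data_time  # time blocked on data loading
        step(data)                                 # forward/backward/optimizer step
        t = time.time() - iter_start_time          # compute time for this batch
        print('compute: %.3fs, data: %.3fs' % (t, t_data))
        iter_data_time = time.time()               # restart the clock for the next batch
```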
diff --git a/util/image_pool.py b/util/image_pool.py
index ada1627..634fd81 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -1,5 +1,4 @@
import random
-import numpy as np
import torch
from torch.autograd import Variable
@@ -24,7 +23,7 @@ class ImagePool():
else:
p = random.uniform(0, 1)
if p > 0.5:
- random_id = random.randint(0, self.pool_size-1)
+ random_id = random.randint(0, self.pool_size - 1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
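`ImagePool` implements the history buffer of Shrivastava et al. 2017: half the time the discriminator is trained on an older generated image drawn from the pool, which damps oscillation during GAN training. A per-image sketch of the query logic shown above (the repo's version iterates over a whole batch and wraps the result in a `Variable`; class name changed to signal it is a sketch):
```
import random

class HistoryPool:
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):                  # `image` is a torch tensor
        if self.pool_size == 0:              # pool disabled: pass through
            return image
        if len(self.images) < self.pool_size:
            self.images.append(image)        # still filling the buffer
            return image
        if random.uniform(0, 1) > 0.5:       # swap: return an old image, keep the new one
            idx = random.randint(0, self.pool_size - 1)
            old = self.images[idx].clone()
            self.images[idx] = image
            return old
        return image                         # otherwise return the current image
```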
diff --git a/util/util.py b/util/util.py
index 26b259a..7a452a6 100644
--- a/util/util.py
+++ b/util/util.py
@@ -2,11 +2,7 @@ from __future__ import print_function
import torch
import numpy as np
from PIL import Image
-import inspect
-import re
-import numpy as np
import os
-import collections
# Converts a Tensor into a Numpy array
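The hunk ends at the comment introducing `tensor2im`; for context, a sketch of the conversion it describes (reconstructed; assumes images normalized to [-1, 1] as elsewhere in this commit):
```
import numpy as np

def tensor2im(image_tensor, imtype=np.uint8):
    # Map a 1xCxHxW tensor in [-1, 1] back to an HxWxC uint8 image.
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)
```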
diff --git a/util/visualizer.py b/util/visualizer.py
index 08eeb75..a98b512 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -56,7 +56,7 @@ class Visualizer():
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
- white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'
@@ -114,8 +114,8 @@ class Visualizer():
win=self.display_id)
    # errors: same format as |errors| of plot_current_errors
- def print_current_errors(self, epoch, i, errors, t):
- message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
+ def print_current_errors(self, epoch, i, errors, t, t_data):
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
for k, v in errors.items():
message += '%s: %.3f ' % (k, v)
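With the extra `t_data` argument, a training log line now looks roughly like this (values are illustrative):
```
(epoch: 4, iters: 400, time: 0.152, data: 0.031) D_A: 0.287 G_A: 0.455 Cyc_A: 1.203 D_B: 0.301 G_B: 0.512 Cyc_B: 1.150
```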