summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/base_model.py3
-rw-r--r--models/cycle_gan_model.py88
-rw-r--r--models/networks.py4
-rw-r--r--models/pix2pix_model.py6
-rw-r--r--options/base_options.py2
-rw-r--r--options/train_options.py4
-rw-r--r--train.py4
-rw-r--r--util/image_pool.py2
-rw-r--r--util/visualizer.py32
9 files changed, 74 insertions, 71 deletions
diff --git a/models/base_model.py b/models/base_model.py
index d62d189..646a014 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -44,13 +44,14 @@ class BaseModel():
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
- network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version.
+ network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version.
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
network.load_state_dict(torch.load(save_path))
+
# update learning rate (called once every epoch)
def update_learning_rate(self):
for scheduler in self.schedulers:
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index ff7330b..71a447d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -90,13 +90,15 @@ class CycleGANModel(BaseModel):
self.real_B = Variable(self.input_B)
def test(self):
- self.real_A = Variable(self.input_A, volatile=True)
- self.fake_B = self.netG_A.forward(self.real_A)
- self.rec_A = self.netG_B.forward(self.fake_B)
+ real_A = Variable(self.input_A, volatile=True)
+ fake_B = self.netG_A.forward(real_A)
+ self.rec_A = self.netG_B.forward(fake_B).data
+ self.fake_B = fake_B.data
- self.real_B = Variable(self.input_B, volatile=True)
- self.fake_A = self.netG_B.forward(self.real_B)
- self.rec_B = self.netG_A.forward(self.fake_A)
+ real_B = Variable(self.input_B, volatile=True)
+ fake_A = self.netG_B.forward(real_B)
+ self.rec_B = self.netG_A.forward(fake_A).data
+ self.fake_A = self.fake_A.data
# get image paths
def get_image_paths(self):
@@ -117,11 +119,13 @@ class CycleGANModel(BaseModel):
def backward_D_A(self):
fake_B = self.fake_B_pool.query(self.fake_B)
- self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+ loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+ self.loss_D_A = loss_D_A.data[0]
def backward_D_B(self):
fake_A = self.fake_A_pool.query(self.fake_A)
- self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+ loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+ self.loss_D_B = loss_D_B.data[0]
def backward_G(self):
lambda_idt = self.opt.identity
@@ -135,53 +139,49 @@ class CycleGANModel(BaseModel):
# G_B should be identity if real_A is fed.
idt_B = self.netG_B.forward(self.real_A)
loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
-
+
self.idt_A = idt_A.data
self.idt_B = idt_B.data
self.loss_idt_A = loss_idt_A.data[0]
- self.loss_idt_B = loss_idt_B.data[0]
-
+ self.loss_idt_B = loss_idt_B.data[0]
+
else:
loss_idt_A = 0
loss_idt_B = 0
self.loss_idt_A = 0
self.loss_idt_B = 0
- # GAN loss
- # D_A(G_A(A))
+ # GAN loss D_A(G_A(A))
fake_B = self.netG_A.forward(self.real_A)
pred_fake = self.netD_A.forward(fake_B)
loss_G_A = self.criterionGAN(pred_fake, True)
-
- # D_B(G_B(B))
+
+ # GAN loss D_B(G_B(B))
fake_A = self.netG_B.forward(self.real_B)
pred_fake = self.netD_B.forward(fake_A)
loss_G_B = self.criterionGAN(pred_fake, True)
-
+
# Forward cycle loss
rec_A = self.netG_B.forward(fake_B)
loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
-
+
# Backward cycle loss
rec_B = self.netG_A.forward(fake_A)
loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
-
+
# combined loss
loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
loss_G.backward()
-
+
self.fake_B = fake_B.data
self.fake_A = fake_A.data
self.rec_A = rec_A.data
self.rec_B = rec_B.data
-
+
self.loss_G_A = loss_G_A.data[0]
self.loss_G_B = loss_G_B.data[0]
self.loss_cycle_A = loss_cycle_A.data[0]
self.loss_cycle_B = loss_cycle_B.data[0]
-
-
-
def optimize_parameters(self):
# forward
@@ -200,36 +200,26 @@ class CycleGANModel(BaseModel):
self.optimizer_D_B.step()
def get_current_errors(self):
- D_A = self.loss_D_A.data[0]
- G_A = self.loss_G_A
- Cyc_A = self.loss_cycle_A
- D_B = self.loss_D_B.data[0]
- G_B = self.loss_G_B
- Cyc_B = self.loss_cycle_B
+ ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
+ ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
if self.opt.identity > 0.0:
- idt_A = self.loss_idt_A
- idt_B = self.loss_idt_B
- return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
- ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
- else:
- return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
- ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+ ret_errors['idt_A'] = self.loss_idt_A
+ ret_errors['idt_B'] = self.loss_idt_B
+ return ret_errors
def get_current_visuals(self):
- real_A = util.tensor2im(self.real_A.data)
- fake_B = util.tensor2im(self.fake_B.data)
- rec_A = util.tensor2im(self.rec_A.data)
- real_B = util.tensor2im(self.real_B.data)
- fake_A = util.tensor2im(self.fake_A.data)
- rec_B = util.tensor2im(self.rec_B.data)
+ real_A = util.tensor2im(self.input_A)
+ fake_B = util.tensor2im(self.fake_B)
+ rec_A = util.tensor2im(self.rec_A)
+ real_B = util.tensor2im(self.input_B)
+ fake_A = util.tensor2im(self.fake_A)
+ rec_B = util.tensor2im(self.rec_B)
+ ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+ ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
if self.opt.isTrain and self.opt.identity > 0.0:
- idt_A = util.tensor2im(self.idt_A)
- idt_B = util.tensor2im(self.idt_B)
- return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
- ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
- else:
- return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
- ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+ ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
+ ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
+ return ret_visuals
def save(self, label):
self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
diff --git a/models/networks.py b/models/networks.py
index 949659d..d071ac4 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
- netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version.
+ netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version.
init_weights(netG, init_type=init_type)
return netG
@@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD,
raise NotImplementedError('Discriminator model name [%s] is not recognized' %
which_model_netD)
if use_gpu:
- netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version.
+ netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version.
init_weights(netD, init_type=init_type)
return netD
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 18ba53f..8cd494f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel):
# Fake
# stop backprop to the generator by detaching fake_B
fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
- self.pred_fake = self.netD.forward(fake_AB.detach())
- self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
+ pred_fake = self.netD.forward(fake_AB.detach())
+ self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real
real_AB = torch.cat((self.real_A, self.real_B), 1)
- self.pred_real = self.netD.forward(real_AB)
+ pred_real = self.netD.forward(real_AB)
self.loss_D_real = self.criterionGAN(self.pred_real, True)
# Combined loss
diff --git a/options/base_options.py b/options/base_options.py
index b2d5360..28ca673 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -3,6 +3,7 @@ import os
from util import util
import torch
+
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -33,7 +34,6 @@ class BaseOptions():
self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
- self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
diff --git a/options/train_options.py b/options/train_options.py
index 32120ec..603d76a 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -5,6 +5,8 @@ class TrainOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
+ self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
+ self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
@@ -23,6 +25,6 @@ class TrainOptions(BaseOptions):
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
- self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
+ self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
self.isTrain = True
diff --git a/train.py b/train.py
index 7d2a5e9..6dbd66b 100644
--- a/train.py
+++ b/train.py
@@ -20,13 +20,15 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
for i, data in enumerate(dataset):
iter_start_time = time.time()
+ visualizer.reset()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
model.set_input(data)
model.optimize_parameters()
if total_steps % opt.display_freq == 0:
- visualizer.display_current_results(model.get_current_visuals(), epoch)
+ save_result = total_steps % opt.update_html_freq == 0
+ visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors()
diff --git a/util/image_pool.py b/util/image_pool.py
index 9f34a09..5a242e6 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -2,6 +2,8 @@ import random
import numpy as np
import torch
from torch.autograd import Variable
+
+
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
diff --git a/util/visualizer.py b/util/visualizer.py
index 02a36b7..22fe9da 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -4,7 +4,8 @@ import ntpath
import time
from . import util
from . import html
-from pdb import set_trace as st
+
+
class Visualizer():
def __init__(self, opt):
# self.opt = opt
@@ -12,9 +13,10 @@ class Visualizer():
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
+ self.saved = False
if self.display_id > 0:
import visdom
- self.vis = visdom.Visdom(port = opt.display_port)
+ self.vis = visdom.Visdom(port=opt.display_port)
self.display_single_pane_ncols = opt.display_single_pane_ncols
if self.use_html:
@@ -27,15 +29,18 @@ class Visualizer():
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
+ def reset(self):
+ self.saved = False
+
# |visuals|: dictionary of images to display or save
- def display_current_results(self, visuals, epoch):
- if self.display_id > 0: # show images in the browser
+ def display_current_results(self, visuals, epoch, save_result):
+ if self.display_id > 0: # show images in the browser
if self.display_single_pane_ncols > 0:
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
- table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
- table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
-</style>""" % (w, h)
+ table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
+ table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+ </style>""" % (w, h)
ncols = self.display_single_pane_ncols
title = self.name
label_html = ''
@@ -61,16 +66,17 @@ class Visualizer():
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
- self.vis.text(table_css + label_html, win = self.display_id + 2,
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
else:
idx = 1
for label, image_numpy in visuals.items():
- self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
- win=self.display_id + idx)
+ self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+ win=self.display_id + idx)
idx += 1
- if self.use_html: # save images to a html file
+ if self.use_html and (save_result or not self.saved): # save images to a html file
+ self.saved = True
for label, image_numpy in visuals.items():
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
@@ -93,11 +99,11 @@ class Visualizer():
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):
- self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
+ self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
self.vis.line(
- X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
+ X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',