| author | junyanz <junyanz@berkeley.edu> | 2017-11-04 02:27:18 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-11-04 02:27:18 -0700 |
| commit | 6b8e96c4bbd73a1e1d4e126d795a26fd0dae983c (patch) | |
| tree | 67072a0442b705b5d5b29840f4b41e13af1d4597 | |
| parent | 5f858eb70a3c110238f74a592bad0e7be601c539 (diff) | |
add update_html_freq flag
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | models/base_model.py | 3 |
| -rw-r--r-- | models/cycle_gan_model.py | 88 |
| -rw-r--r-- | models/networks.py | 4 |
| -rw-r--r-- | models/pix2pix_model.py | 6 |
| -rw-r--r-- | options/base_options.py | 2 |
| -rw-r--r-- | options/train_options.py | 4 |
| -rw-r--r-- | train.py | 4 |
| -rw-r--r-- | util/image_pool.py | 2 |
| -rw-r--r-- | util/visualizer.py | 32 |
9 files changed, 74 insertions, 71 deletions
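
Before the raw diff: the core of this change is a new `--update_html_freq` train option (default 1000) plus a `saved` flag on `Visualizer`, so that pushing images to visdom (every `--display_freq` steps) no longer forces a write to the HTML page on every display call. Below is a minimal sketch of that gating logic, condensed from the patch; the stripped-down `Visualizer` and the direct calls are illustrative stand-ins, not the real `util/visualizer.py`.

```python
class Visualizer:
    """Stripped-down stand-in for util/visualizer.py, keeping only the new HTML-save gating."""

    def __init__(self):
        self.saved = False  # has anything been written to the HTML page since the last reset()?

    def reset(self):
        # called from the training loop (see the train.py hunk below)
        self.saved = False

    def display_current_results(self, visuals, epoch, save_result):
        # in the real code the visdom display (display_id > 0) is unconditional;
        # only the HTML write is gated on save_result / self.saved
        if save_result or not self.saved:
            self.saved = True
            for label in visuals:
                print('epoch %d: writing %s to the HTML page' % (epoch, label))


v = Visualizer()
v.display_current_results({'real_A': None}, epoch=1, save_result=False)  # writes: nothing saved since reset
v.display_current_results({'real_A': None}, epoch=1, save_result=False)  # skipped: already saved
v.display_current_results({'real_A': None}, epoch=1, save_result=True)   # writes: forced by update_html_freq
v.reset()                                                                 # next iteration in train.py
v.display_current_results({'real_A': None}, epoch=1, save_result=False)  # writes again
```

In the patched `train.py`, `save_result` is computed as `total_steps % opt.update_html_freq == 0`, so the forced HTML write happens on `update_html_freq` boundaries while visdom still refreshes every `display_freq` steps.
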
diff --git a/models/base_model.py b/models/base_model.py
index d62d189..646a014 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -44,13 +44,14 @@ class BaseModel():
         save_path = os.path.join(self.save_dir, save_filename)
         torch.save(network.cpu().state_dict(), save_path)
         if len(gpu_ids) and torch.cuda.is_available():
-            network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version.
+            network.cuda(device_id=gpu_ids[0])  # network.cuda(device=gpu_ids[0]) for the latest version.
 
     # helper loading function that can be used by subclasses
     def load_network(self, network, network_label, epoch_label):
         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
         save_path = os.path.join(self.save_dir, save_filename)
         network.load_state_dict(torch.load(save_path))
 
+    # update learning rate (called once every epoch)
     def update_learning_rate(self):
         for scheduler in self.schedulers:
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index ff7330b..71a447d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -90,13 +90,15 @@ class CycleGANModel(BaseModel):
         self.real_B = Variable(self.input_B)
 
     def test(self):
-        self.real_A = Variable(self.input_A, volatile=True)
-        self.fake_B = self.netG_A.forward(self.real_A)
-        self.rec_A = self.netG_B.forward(self.fake_B)
+        real_A = Variable(self.input_A, volatile=True)
+        fake_B = self.netG_A.forward(real_A)
+        self.rec_A = self.netG_B.forward(fake_B).data
+        self.fake_B = fake_B.data
 
-        self.real_B = Variable(self.input_B, volatile=True)
-        self.fake_A = self.netG_B.forward(self.real_B)
-        self.rec_B = self.netG_A.forward(self.fake_A)
+        real_B = Variable(self.input_B, volatile=True)
+        fake_A = self.netG_B.forward(real_B)
+        self.rec_B = self.netG_A.forward(fake_A).data
+        self.fake_A = self.fake_A.data
 
     # get image paths
     def get_image_paths(self):
@@ -117,11 +119,13 @@ class CycleGANModel(BaseModel):
 
     def backward_D_A(self):
         fake_B = self.fake_B_pool.query(self.fake_B)
-        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        self.loss_D_A = loss_D_A.data[0]
 
     def backward_D_B(self):
         fake_A = self.fake_A_pool.query(self.fake_A)
-        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        self.loss_D_B = loss_D_B.data[0]
 
     def backward_G(self):
         lambda_idt = self.opt.identity
@@ -135,53 +139,49 @@ class CycleGANModel(BaseModel):
             # G_B should be identity if real_A is fed.
             idt_B = self.netG_B.forward(self.real_A)
             loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
-
+
             self.idt_A = idt_A.data
             self.idt_B = idt_B.data
             self.loss_idt_A = loss_idt_A.data[0]
-            self.loss_idt_B = loss_idt_B.data[0]
-
+            self.loss_idt_B = loss_idt_B.data[0]
+
         else:
             loss_idt_A = 0
             loss_idt_B = 0
             self.loss_idt_A = 0
             self.loss_idt_B = 0
 
-        # GAN loss
-        # D_A(G_A(A))
+        # GAN loss D_A(G_A(A))
         fake_B = self.netG_A.forward(self.real_A)
         pred_fake = self.netD_A.forward(fake_B)
         loss_G_A = self.criterionGAN(pred_fake, True)
-
-        # D_B(G_B(B))
+
+        # GAN loss D_B(G_B(B))
         fake_A = self.netG_B.forward(self.real_B)
         pred_fake = self.netD_B.forward(fake_A)
         loss_G_B = self.criterionGAN(pred_fake, True)
-
+
         # Forward cycle loss
         rec_A = self.netG_B.forward(fake_B)
         loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
-
+
         # Backward cycle loss
         rec_B = self.netG_A.forward(fake_A)
         loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
-
+
         # combined loss
         loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
         loss_G.backward()
-
+
         self.fake_B = fake_B.data
         self.fake_A = fake_A.data
         self.rec_A = rec_A.data
         self.rec_B = rec_B.data
-
+
         self.loss_G_A = loss_G_A.data[0]
         self.loss_G_B = loss_G_B.data[0]
         self.loss_cycle_A = loss_cycle_A.data[0]
         self.loss_cycle_B = loss_cycle_B.data[0]
-
-
-
 
     def optimize_parameters(self):
         # forward
@@ -200,36 +200,26 @@ class CycleGANModel(BaseModel):
         self.optimizer_D_B.step()
 
     def get_current_errors(self):
-        D_A = self.loss_D_A.data[0]
-        G_A = self.loss_G_A
-        Cyc_A = self.loss_cycle_A
-        D_B = self.loss_D_B.data[0]
-        G_B = self.loss_G_B
-        Cyc_B = self.loss_cycle_B
+        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
+                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
         if self.opt.identity > 0.0:
-            idt_A = self.loss_idt_A
-            idt_B = self.loss_idt_B
-            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
-                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
-        else:
-            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
-                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+            ret_errors['idt_A'] = self.loss_idt_A
+            ret_errors['idt_B'] = self.loss_idt_B
+        return ret_errors
 
     def get_current_visuals(self):
-        real_A = util.tensor2im(self.real_A.data)
-        fake_B = util.tensor2im(self.fake_B.data)
-        rec_A = util.tensor2im(self.rec_A.data)
-        real_B = util.tensor2im(self.real_B.data)
-        fake_A = util.tensor2im(self.fake_A.data)
-        rec_B = util.tensor2im(self.rec_B.data)
+        real_A = util.tensor2im(self.input_A)
+        fake_B = util.tensor2im(self.fake_B)
+        rec_A = util.tensor2im(self.rec_A)
+        real_B = util.tensor2im(self.input_B)
+        fake_A = util.tensor2im(self.fake_A)
+        rec_B = util.tensor2im(self.rec_B)
+        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+                                   ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
         if self.opt.isTrain and self.opt.identity > 0.0:
-            idt_A = util.tensor2im(self.idt_A)
-            idt_B = util.tensor2im(self.idt_B)
-            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
-                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
-        else:
-            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
-                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
+            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
+        return ret_visuals
 
     def save(self, label):
         self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
diff --git a/models/networks.py b/models/networks.py
index 949659d..d071ac4 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
     else:
         raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
-        netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version.
+        netG.cuda(device_id=gpu_ids[0])  # or netG.cuda(device=gpu_ids[0]) for latest version.
     init_weights(netG, init_type=init_type)
     return netG
 
@@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD,
         raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                   which_model_netD)
     if use_gpu:
-        netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version.
+        netD.cuda(device_id=gpu_ids[0])  # or netD.cuda(device=gpu_ids[0]) for latest version.
     init_weights(netD, init_type=init_type)
     return netD
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 18ba53f..8cd494f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel):
         # Fake
         # stop backprop to the generator by detaching fake_B
         fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
-        self.pred_fake = self.netD.forward(fake_AB.detach())
-        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
+        pred_fake = self.netD.forward(fake_AB.detach())
+        self.loss_D_fake = self.criterionGAN(pred_fake, False)
 
         # Real
         real_AB = torch.cat((self.real_A, self.real_B), 1)
-        self.pred_real = self.netD.forward(real_AB)
+        pred_real = self.netD.forward(real_AB)
         self.loss_D_real = self.criterionGAN(self.pred_real, True)
 
         # Combined loss
diff --git a/options/base_options.py b/options/base_options.py
index b2d5360..28ca673 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -3,6 +3,7 @@ import os
 from util import util
 import torch
 
+
 class BaseOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -33,7 +34,6 @@ class BaseOptions():
         self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
         self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
         self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
-        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
         self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
         self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
diff --git a/options/train_options.py b/options/train_options.py
index 32120ec..603d76a 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -5,6 +5,8 @@ class TrainOptions(BaseOptions):
     def initialize(self):
         BaseOptions.initialize(self)
         self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
+        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
+        self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
         self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
         self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
         self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
@@ -23,6 +25,6 @@ class TrainOptions(BaseOptions):
         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
         self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
         self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
-        self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
+        self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
 
         self.isTrain = True
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -20,13 +20,15 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
     for i, data in enumerate(dataset):
         iter_start_time = time.time()
+        visualizer.reset()
         total_steps += opt.batchSize
         epoch_iter += opt.batchSize
         model.set_input(data)
         model.optimize_parameters()
 
         if total_steps % opt.display_freq == 0:
-            visualizer.display_current_results(model.get_current_visuals(), epoch)
+            save_result = total_steps % opt.update_html_freq == 0
+            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
 
         if total_steps % opt.print_freq == 0:
             errors = model.get_current_errors()
diff --git a/util/image_pool.py b/util/image_pool.py
index 9f34a09..5a242e6 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -2,6 +2,8 @@ import random
 import numpy as np
 import torch
 from torch.autograd import Variable
+
+
 class ImagePool():
     def __init__(self, pool_size):
         self.pool_size = pool_size
diff --git a/util/visualizer.py b/util/visualizer.py
index 02a36b7..22fe9da 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -4,7 +4,8 @@ import ntpath
 import time
 from . import util
 from . import html
-from pdb import set_trace as st
+
+
 class Visualizer():
     def __init__(self, opt):
         # self.opt = opt
@@ -12,9 +13,10 @@ class Visualizer():
         self.use_html = opt.isTrain and not opt.no_html
         self.win_size = opt.display_winsize
         self.name = opt.name
+        self.saved = False
         if self.display_id > 0:
             import visdom
-            self.vis = visdom.Visdom(port = opt.display_port)
+            self.vis = visdom.Visdom(port=opt.display_port)
             self.display_single_pane_ncols = opt.display_single_pane_ncols
 
         if self.use_html:
@@ -27,15 +29,18 @@ class Visualizer():
             now = time.strftime("%c")
             log_file.write('================ Training Loss (%s) ================\n' % now)
 
+    def reset(self):
+        self.saved = False
+
     # |visuals|: dictionary of images to display or save
-    def display_current_results(self, visuals, epoch):
-        if self.display_id > 0: # show images in the browser
+    def display_current_results(self, visuals, epoch, save_result):
+        if self.display_id > 0:  # show images in the browser
             if self.display_single_pane_ncols > 0:
                 h, w = next(iter(visuals.values())).shape[:2]
                 table_css = """<style>
-    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
-    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
-</style>""" % (w, h)
+        table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
+        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+        </style>""" % (w, h)
                 ncols = self.display_single_pane_ncols
                 title = self.name
                 label_html = ''
@@ -61,16 +66,17 @@ class Visualizer():
                 self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                 padding=2, opts=dict(title=title + ' images'))
                 label_html = '<table>%s</table>' % label_html
-                self.vis.text(table_css + label_html, win = self.display_id + 2,
+                self.vis.text(table_css + label_html, win=self.display_id + 2,
                               opts=dict(title=title + ' labels'))
             else:
                 idx = 1
                 for label, image_numpy in visuals.items():
-                    self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
-                                   win=self.display_id + idx)
+                    self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+                                   win=self.display_id + idx)
                     idx += 1
 
-        if self.use_html: # save images to a html file
+        if self.use_html and (save_result or not self.saved):  # save images to a html file
+            self.saved = True
             for label, image_numpy in visuals.items():
                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                 util.save_image(image_numpy, img_path)
@@ -93,11 +99,11 @@ class Visualizer():
     # errors: dictionary of error labels and values
     def plot_current_errors(self, epoch, counter_ratio, opt, errors):
         if not hasattr(self, 'plot_data'):
-            self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
+            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
         self.plot_data['X'].append(epoch + counter_ratio)
         self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
         self.vis.line(
-            X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
+            X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
             Y=np.array(self.plot_data['Y']),
             opts={
                 'title': self.name + ' loss over time',
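
A quick sanity check on how the two frequencies interact, assuming `batchSize=1` so that `total_steps` advances by one per iteration (this is an illustration, not part of the patch): `save_result` is only evaluated on display steps, so forced HTML updates land on steps that are multiples of both `--display_freq` and `--update_html_freq`.

```python
# With the defaults added in this patch: display every 100 steps,
# force an HTML update only on display steps that are also multiples of 1000.
display_freq, update_html_freq = 100, 1000

display_steps = [s for s in range(1, 3001) if s % display_freq == 0]
forced_html_steps = [s for s in display_steps if s % update_html_freq == 0]

print(display_steps[:5])    # [100, 200, 300, 400, 500]
print(forced_html_steps)    # [1000, 2000, 3000]
```
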
