From df16bf1b18e03213efe9b22a0155b833a06c6a9c Mon Sep 17 00:00:00 2001
From: Jean-Philippe Mercier
Date: Thu, 2 Nov 2017 18:02:02 -0400
Subject: gpu memory leaks

A tensor is now fed to ImagePool() instead of a Variable
---
 util/image_pool.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'util/image_pool.py')

diff --git a/util/image_pool.py b/util/image_pool.py
index 152ef5b..9f34a09 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -13,7 +13,7 @@ class ImagePool():
         if self.pool_size == 0:
             return images
         return_images = []
-        for image in images.data:
+        for image in images:
             image = torch.unsqueeze(image, 0)
             if self.num_imgs < self.pool_size:
                 self.num_imgs = self.num_imgs + 1
--
cgit v1.2.3-70-g09d2


From 6b8e96c4bbd73a1e1d4e126d795a26fd0dae983c Mon Sep 17 00:00:00 2001
From: junyanz
Date: Sat, 4 Nov 2017 02:27:18 -0700
Subject: add update_html_freq flag

---
 models/base_model.py      |  3 +-
 models/cycle_gan_model.py | 88 +++++++++++++++++++++--------------------
 models/networks.py        |  4 +--
 models/pix2pix_model.py   |  6 ++--
 options/base_options.py   |  2 +-
 options/train_options.py  |  4 ++-
 train.py                  |  4 ++-
 util/image_pool.py        |  2 ++
 util/visualizer.py        | 32 ++++++++++-------
 9 files changed, 74 insertions(+), 71 deletions(-)

(limited to 'util/image_pool.py')

diff --git a/models/base_model.py b/models/base_model.py
index d62d189..646a014 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -44,13 +44,14 @@ class BaseModel():
         save_path = os.path.join(self.save_dir, save_filename)
         torch.save(network.cpu().state_dict(), save_path)
         if len(gpu_ids) and torch.cuda.is_available():
-            network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version.
+            network.cuda(device_id=gpu_ids[0])  # network.cuda(device=gpu_ids[0]) for the latest version.

     # helper loading function that can be used by subclasses
     def load_network(self, network, network_label, epoch_label):
         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
         save_path = os.path.join(self.save_dir, save_filename)
         network.load_state_dict(torch.load(save_path))

+    # update learning rate (called once every epoch)
     def update_learning_rate(self):
         for scheduler in self.schedulers:
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index ff7330b..71a447d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -90,13 +90,15 @@ class CycleGANModel(BaseModel):
         self.real_B = Variable(self.input_B)

     def test(self):
-        self.real_A = Variable(self.input_A, volatile=True)
-        self.fake_B = self.netG_A.forward(self.real_A)
-        self.rec_A = self.netG_B.forward(self.fake_B)
+        real_A = Variable(self.input_A, volatile=True)
+        fake_B = self.netG_A.forward(real_A)
+        self.rec_A = self.netG_B.forward(fake_B).data
+        self.fake_B = fake_B.data

-        self.real_B = Variable(self.input_B, volatile=True)
-        self.fake_A = self.netG_B.forward(self.real_B)
-        self.rec_B = self.netG_A.forward(self.fake_A)
+        real_B = Variable(self.input_B, volatile=True)
+        fake_A = self.netG_B.forward(real_B)
+        self.rec_B = self.netG_A.forward(fake_A).data
+        self.fake_A = self.fake_A.data

     # get image paths
     def get_image_paths(self):
@@ -117,11 +119,13 @@ class CycleGANModel(BaseModel):

     def backward_D_A(self):
         fake_B = self.fake_B_pool.query(self.fake_B)
-        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+        self.loss_D_A = loss_D_A.data[0]

     def backward_D_B(self):
         fake_A = self.fake_A_pool.query(self.fake_A)
-        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        self.loss_D_B = loss_D_B.data[0]

     def backward_G(self):
         lambda_idt = self.opt.identity
@@ -135,53 +139,49 @@ class CycleGANModel(BaseModel):
             # G_B should be identity if real_A is fed.
             idt_B = self.netG_B.forward(self.real_A)
             loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt
-
+
             self.idt_A = idt_A.data
             self.idt_B = idt_B.data
             self.loss_idt_A = loss_idt_A.data[0]
-            self.loss_idt_B = loss_idt_B.data[0]
-
+            self.loss_idt_B = loss_idt_B.data[0]
+
         else:
             loss_idt_A = 0
             loss_idt_B = 0
             self.loss_idt_A = 0
             self.loss_idt_B = 0

-        # GAN loss
-        # D_A(G_A(A))
+        # GAN loss D_A(G_A(A))
         fake_B = self.netG_A.forward(self.real_A)
         pred_fake = self.netD_A.forward(fake_B)
         loss_G_A = self.criterionGAN(pred_fake, True)
-
-        # D_B(G_B(B))
+
+        # GAN loss D_B(G_B(B))
         fake_A = self.netG_B.forward(self.real_B)
         pred_fake = self.netD_B.forward(fake_A)
         loss_G_B = self.criterionGAN(pred_fake, True)
-
+
         # Forward cycle loss
         rec_A = self.netG_B.forward(fake_B)
         loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
-
+
         # Backward cycle loss
         rec_B = self.netG_A.forward(fake_A)
         loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
-
+
         # combined loss
         loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
         loss_G.backward()
-
+
         self.fake_B = fake_B.data
         self.fake_A = fake_A.data
         self.rec_A = rec_A.data
         self.rec_B = rec_B.data
-
+
         self.loss_G_A = loss_G_A.data[0]
         self.loss_G_B = loss_G_B.data[0]
         self.loss_cycle_A = loss_cycle_A.data[0]
         self.loss_cycle_B = loss_cycle_B.data[0]
-
-
-

     def optimize_parameters(self):
         # forward
@@ -200,36 +200,26 @@ class CycleGANModel(BaseModel):
         self.optimizer_D_B.step()

     def get_current_errors(self):
-        D_A = self.loss_D_A.data[0]
-        G_A = self.loss_G_A
-        Cyc_A = self.loss_cycle_A
-        D_B = self.loss_D_B.data[0]
-        G_B = self.loss_G_B
-        Cyc_B = self.loss_cycle_B
+        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
+                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
         if self.opt.identity > 0.0:
-            idt_A = self.loss_idt_A
-            idt_B = self.loss_idt_B
-            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
-                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
-        else:
-            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
-                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+            ret_errors['idt_A'] = self.loss_idt_A
+            ret_errors['idt_B'] = self.loss_idt_B
+        return ret_errors

     def get_current_visuals(self):
-        real_A = util.tensor2im(self.real_A.data)
-        fake_B = util.tensor2im(self.fake_B.data)
-        rec_A = util.tensor2im(self.rec_A.data)
-        real_B = util.tensor2im(self.real_B.data)
-        fake_A = util.tensor2im(self.fake_A.data)
-        rec_B = util.tensor2im(self.rec_B.data)
+        real_A = util.tensor2im(self.input_A)
+        fake_B = util.tensor2im(self.fake_B)
+        rec_A = util.tensor2im(self.rec_A)
+        real_B = util.tensor2im(self.input_B)
+        fake_A = util.tensor2im(self.fake_A)
+        rec_B = util.tensor2im(self.rec_B)
+        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+                                   ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
         if self.opt.isTrain and self.opt.identity > 0.0:
-            idt_A = util.tensor2im(self.idt_A)
-            idt_B = util.tensor2im(self.idt_B)
-            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
-                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
-        else:
-            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
-                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
+            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
+        return ret_visuals

     def save(self, label):
         self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
diff --git a/models/networks.py b/models/networks.py
index 949659d..d071ac4 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
     else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
-        netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version.
+        netG.cuda(device_id=gpu_ids[0])  # or netG.cuda(device=gpu_ids[0]) for latest version.
     init_weights(netG, init_type=init_type)
     return netG

@@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD,
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)

     if use_gpu:
-        netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version.
+        netD.cuda(device_id=gpu_ids[0])  # or netD.cuda(device=gpu_ids[0]) for latest version.
     init_weights(netD, init_type=init_type)
     return netD
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 18ba53f..8cd494f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -87,12 +87,12 @@ class Pix2PixModel(BaseModel):
         # Fake
         # stop backprop to the generator by detaching fake_B
         fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
-        self.pred_fake = self.netD.forward(fake_AB.detach())
-        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
+        pred_fake = self.netD.forward(fake_AB.detach())
+        self.loss_D_fake = self.criterionGAN(pred_fake, False)

         # Real
         real_AB = torch.cat((self.real_A, self.real_B), 1)
-        self.pred_real = self.netD.forward(real_AB)
+        pred_real = self.netD.forward(real_AB)
         self.loss_D_real = self.criterionGAN(self.pred_real, True)

         # Combined loss
diff --git a/options/base_options.py b/options/base_options.py
index b2d5360..28ca673 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -3,6 +3,7 @@ import os
 from util import util
 import torch

+
 class BaseOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -33,7 +34,6 @@ class BaseOptions():
         self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
         self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
         self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
-        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
         self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
         self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
diff --git a/options/train_options.py b/options/train_options.py
index 32120ec..603d76a 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -5,6 +5,8 @@ class TrainOptions(BaseOptions):
     def initialize(self):
         BaseOptions.initialize(self)
         self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
+        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
+        self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
         self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
         self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
         self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
@@ -23,6 +25,6 @@ class TrainOptions(BaseOptions):
         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
         self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
         self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
-        self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
+        self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')

         self.isTrain = True
diff --git a/train.py b/train.py
index 7d2a5e9..6dbd66b 100644
--- a/train.py
+++ b/train.py
@@ -20,13 +20,15 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):

     for i, data in enumerate(dataset):
         iter_start_time = time.time()
+        visualizer.reset()
         total_steps += opt.batchSize
         epoch_iter += opt.batchSize
         model.set_input(data)
         model.optimize_parameters()

         if total_steps % opt.display_freq == 0:
-            visualizer.display_current_results(model.get_current_visuals(), epoch)
+            save_result = total_steps % opt.update_html_freq == 0
+            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

         if total_steps % opt.print_freq == 0:
             errors = model.get_current_errors()
diff --git a/util/image_pool.py b/util/image_pool.py
index 9f34a09..5a242e6 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -2,6 +2,8 @@ import random
 import numpy as np
 import torch
 from torch.autograd import Variable
+
+
 class ImagePool():
     def __init__(self, pool_size):
         self.pool_size = pool_size
diff --git a/util/visualizer.py b/util/visualizer.py
index 02a36b7..22fe9da 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -4,7 +4,8 @@ import ntpath
 import time
 from . import util
 from . import html
-from pdb import set_trace as st
+
+
 class Visualizer():
     def __init__(self, opt):
         # self.opt = opt
@@ -12,9 +13,10 @@ class Visualizer():
         self.use_html = opt.isTrain and not opt.no_html
         self.win_size = opt.display_winsize
         self.name = opt.name
+        self.saved = False
         if self.display_id > 0:
             import visdom
-            self.vis = visdom.Visdom(port = opt.display_port)
+            self.vis = visdom.Visdom(port=opt.display_port)
             self.display_single_pane_ncols = opt.display_single_pane_ncols

         if self.use_html:
@@ -27,15 +29,18 @@
             now = time.strftime("%c")
             log_file.write('================ Training Loss (%s) ================\n' % now)

+    def reset(self):
+        self.saved = False
+
     # |visuals|: dictionary of images to display or save
-    def display_current_results(self, visuals, epoch):
-        if self.display_id > 0: # show images in the browser
+    def display_current_results(self, visuals, epoch, save_result):
+        if self.display_id > 0:  # show images in the browser
             if self.display_single_pane_ncols > 0:
                 h, w = next(iter(visuals.values())).shape[:2]
                 table_css = """<style>
-    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
-    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
-</style>""" % (w, h)
+    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
+    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+</style>""" % (w, h)
                 ncols = self.display_single_pane_ncols
                 title = self.name
                 label_html = ''
@@ -61,16 +66,17 @@
                 self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                 padding=2, opts=dict(title=title + ' images'))
                 label_html = '<table>%s</table>' % label_html
-                self.vis.text(table_css + label_html, win = self.display_id + 2,
+                self.vis.text(table_css + label_html, win=self.display_id + 2,
                               opts=dict(title=title + ' labels'))
             else:
                 idx = 1
                 for label, image_numpy in visuals.items():
-                    self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
-                                   win=self.display_id + idx)
+                    self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+                                   win=self.display_id + idx)
                     idx += 1

-        if self.use_html: # save images to a html file
+        if self.use_html and (save_result or not self.saved):  # save images to a html file
+            self.saved = True
             for label, image_numpy in visuals.items():
                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                 util.save_image(image_numpy, img_path)
@@ -93,11 +99,11 @@
     # errors: dictionary of error labels and values
     def plot_current_errors(self, epoch, counter_ratio, opt, errors):
        if not hasattr(self, 'plot_data'):
-            self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
+            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
         self.plot_data['X'].append(epoch + counter_ratio)
         self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
         self.vis.line(
-            X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
+            X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
             Y=np.array(self.plot_data['Y']),
             opts={
                 'title': self.name + ' loss over time',
--
cgit v1.2.3-70-g09d2


From 7a9021d4f131ee059d49ff9b2d135e6543f75763 Mon Sep 17 00:00:00 2001
From: junyanz
Date: Sat, 4 Nov 2017 02:47:39 -0700
Subject: fix small issues

---
 models/pix2pix_model.py | 4 ++--
 util/image_pool.py      | 2 +-
 util/visualizer.py      | 8 ++++----
 3 files changed, 7 insertions(+), 7 deletions(-)

(limited to 'util/image_pool.py')

diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 8cd494f..388a8d3 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -86,14 +86,14 @@ class Pix2PixModel(BaseModel):
     def backward_D(self):
         # Fake
         # stop backprop to the generator by detaching fake_B
-        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
+        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data)
         pred_fake = self.netD.forward(fake_AB.detach())
         self.loss_D_fake = self.criterionGAN(pred_fake, False)

         # Real
         real_AB = torch.cat((self.real_A, self.real_B), 1)
         pred_real = self.netD.forward(real_AB)
-        self.loss_D_real = self.criterionGAN(self.pred_real, True)
+        self.loss_D_real = self.criterionGAN(pred_real, True)

         # Combined loss
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
diff --git a/util/image_pool.py b/util/image_pool.py
index 5a242e6..ada1627 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -13,7 +13,7 @@ class ImagePool():

     def query(self, images):
         if self.pool_size == 0:
-            return images
+            return Variable(images)
         return_images = []
         for image in images:
             image = torch.unsqueeze(image, 0)
diff --git a/util/visualizer.py b/util/visualizer.py
index 22fe9da..e6e7cba 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -13,11 +13,11 @@ class Visualizer():
         self.use_html = opt.isTrain and not opt.no_html
         self.win_size = opt.display_winsize
         self.name = opt.name
+        self.opt = opt
         self.saved = False
         if self.display_id > 0:
             import visdom
             self.vis = visdom.Visdom(port=opt.display_port)
-            self.display_single_pane_ncols = opt.display_single_pane_ncols

         if self.use_html:
             self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
@@ -35,13 +35,13 @@ class Visualizer():
     # |visuals|: dictionary of images to display or save
     def display_current_results(self, visuals, epoch, save_result):
         if self.display_id > 0:  # show images in the browser
-            if self.display_single_pane_ncols > 0:
+            ncols = self.opt.display_single_pane_ncols
+            if ncols > 0:
                 h, w = next(iter(visuals.values())).shape[:2]
                 table_css = """<style>
     table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
     table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
 </style>""" % (w, h)
-                ncols = self.display_single_pane_ncols
                 title = self.name
                 label_html = ''
                 label_html_row = ''
@@ -76,7 +76,7 @@ class Visualizer():
                     idx += 1

         if self.use_html and (save_result or not self.saved):  # save images to a html file
-            self.saved = True
+            self.saved = True
             for label, image_numpy in visuals.items():
                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                 util.save_image(image_numpy, img_path)
--
cgit v1.2.3-70-g09d2
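
A note on the convention this series converges on: generator outputs are kept as detached tensors (via .data), ImagePool.query() is fed tensors rather than Variables, and query() wraps its result back into a Variable for the discriminator pass, so the pool no longer keeps old generator graphs and their GPU buffers alive. The sketch below is not taken from the patches; it assumes the 0.3-era torch.autograd.Variable API used here, assumes it is run from the repository root so util.image_pool is importable, and uses a single Conv2d as a stand-in generator.

import torch
from torch.autograd import Variable

from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
netG = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in generator, not the repo's define_G
real_A = Variable(torch.randn(1, 3, 64, 64))

fake_B = netG(real_A)        # Variable that still carries the generator graph
fake_B_data = fake_B.data    # detached tensor; safe to keep around in the pool

# query() takes the tensor and hands back a Variable (either this image or one
# drawn from the pool), so no stale computation graphs are retained on the GPU.
fake_B_for_D = pool.query(fake_B_data)
print(type(fake_B_for_D), fake_B_for_D.size())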
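
The update_html_freq logic added in the second patch can also be read on its own: display_current_results() writes the HTML page when save_result is true (total_steps is a multiple of update_html_freq) or when nothing has been saved since the last visualizer.reset(). The class below is a hypothetical, self-contained restatement of just that gate; HtmlSaveGate is not part of the repository, it only mirrors the condition `save_result or not self.saved` from the patched visualizer.

class HtmlSaveGate:
    # Hypothetical stand-in for the saved/reset bookkeeping in util/visualizer.py.
    def __init__(self, update_html_freq=1000):
        self.update_html_freq = update_html_freq
        self.saved = False

    def reset(self):
        # train.py calls visualizer.reset() once per training iteration.
        self.saved = False

    def should_save(self, total_steps):
        # Mirrors: if self.use_html and (save_result or not self.saved)
        save_result = total_steps % self.update_html_freq == 0
        if save_result or not self.saved:
            self.saved = True
            return True
        return False


gate = HtmlSaveGate(update_html_freq=1000)
gate.reset()
print(gate.should_save(100))   # True: nothing saved since the last reset()
print(gate.should_save(100))   # False: already saved and 100 is not a multiple of 1000
print(gate.should_save(2000))  # True: 2000 is a multiple of update_html_freq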