From c41ab7acdb1e9191d91991d6461ddc4dff3243f7 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Tue, 18 Apr 2017 21:10:01 -0500 Subject: Use visdom instead of display --- util/visualizer.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) (limited to 'util/visualizer.py') diff --git a/util/visualizer.py b/util/visualizer.py index 4daf506..7df9cce 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -10,14 +10,14 @@ class Visualizer(): def __init__(self, opt): # self.opt = opt self.display_id = opt.display_id + self.name = opt.name if self.display_id > 0: - from . import display - self.display = display + import visdom + self.vis = visdom.Visdom() else: from . import html self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') - self.name = opt.name self.win_size = opt.display_winsize print('create web directory %s...' % self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) @@ -26,10 +26,10 @@ class Visualizer(): # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch): if self.display_id > 0: # show images in the browser - idx = 0 + idx = 1 for label, image_numpy in visuals.items(): - image_numpy = np.flipud(image_numpy) - self.display.image(image_numpy, title=label, + #image_numpy = np.flipud(image_numpy) + self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 else: # save images to a web directory @@ -53,8 +53,20 @@ class Visualizer(): webpage.save() # st() # errors: dictionary of error labels and values - def plot_current_errors(self, epoch, i, opt, errors): - pass + def plot_current_errors(self, epoch, counter_ratio, opt, errors): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X':[],'Y':[], 'legend':errors.keys()} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, start_time): -- cgit v1.2.3-70-g09d2