diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-11-04 02:27:18 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-11-04 02:27:18 -0700 |
| commit | 6b8e96c4bbd73a1e1d4e126d795a26fd0dae983c (patch) | |
| tree | 67072a0442b705b5d5b29840f4b41e13af1d4597 /util/visualizer.py | |
| parent | 5f858eb70a3c110238f74a592bad0e7be601c539 (diff) | |
add update_html_freq flag
Diffstat (limited to 'util/visualizer.py')
| -rw-r--r-- | util/visualizer.py | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/util/visualizer.py b/util/visualizer.py index 02a36b7..22fe9da 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -4,7 +4,8 @@ import ntpath import time from . import util from . import html -from pdb import set_trace as st + + class Visualizer(): def __init__(self, opt): # self.opt = opt @@ -12,9 +13,10 @@ class Visualizer(): self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name + self.saved = False if self.display_id > 0: import visdom - self.vis = visdom.Visdom(port = opt.display_port) + self.vis = visdom.Visdom(port=opt.display_port) self.display_single_pane_ncols = opt.display_single_pane_ncols if self.use_html: @@ -27,15 +29,18 @@ class Visualizer(): now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) + def reset(self): + self.saved = False + # |visuals|: dictionary of images to display or save - def display_current_results(self, visuals, epoch): - if self.display_id > 0: # show images in the browser + def display_current_results(self, visuals, epoch, save_result): + if self.display_id > 0: # show images in the browser if self.display_single_pane_ncols > 0: h, w = next(iter(visuals.values())).shape[:2] table_css = """<style> - table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center} - table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} -</style>""" % (w, h) + table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center} + table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} + </style>""" % (w, h) ncols = self.display_single_pane_ncols title = self.name label_html = '' @@ -61,16 +66,17 @@ class Visualizer(): self.vis.images(images, nrow=ncols, win=self.display_id + 1, padding=2, opts=dict(title=title + ' images')) label_html = '<table>%s</table>' % label_html - self.vis.text(table_css + label_html, win = self.display_id + 2, + self.vis.text(table_css + label_html, win=self.display_id + 2, opts=dict(title=title + ' labels')) else: idx = 1 for label, image_numpy in visuals.items(): - self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), - win=self.display_id + idx) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) idx += 1 - if self.use_html: # save images to a html file + if self.use_html and (save_result or not self.saved): # save images to a html file + self.saved = True for label, image_numpy in visuals.items(): img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) @@ -93,11 +99,11 @@ class Visualizer(): # errors: dictionary of error labels and values def plot_current_errors(self, epoch, counter_ratio, opt, errors): if not hasattr(self, 'plot_data'): - self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} + self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())} self.plot_data['X'].append(epoch + counter_ratio) self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) self.vis.line( - X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), Y=np.array(self.plot_data['Y']), opts={ 'title': self.name + ' loss over time', |
