summaryrefslogtreecommitdiff
path: root/util/visualizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'util/visualizer.py')
-rw-r--r--util/visualizer.py32
1 files changed, 19 insertions, 13 deletions
diff --git a/util/visualizer.py b/util/visualizer.py
index 02a36b7..22fe9da 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -4,7 +4,8 @@ import ntpath
import time
from . import util
from . import html
-from pdb import set_trace as st
+
+
class Visualizer():
def __init__(self, opt):
# self.opt = opt
@@ -12,9 +13,10 @@ class Visualizer():
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
+ self.saved = False
if self.display_id > 0:
import visdom
- self.vis = visdom.Visdom(port = opt.display_port)
+ self.vis = visdom.Visdom(port=opt.display_port)
self.display_single_pane_ncols = opt.display_single_pane_ncols
if self.use_html:
@@ -27,15 +29,18 @@ class Visualizer():
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
+ def reset(self):
+ self.saved = False
+
# |visuals|: dictionary of images to display or save
- def display_current_results(self, visuals, epoch):
- if self.display_id > 0: # show images in the browser
+ def display_current_results(self, visuals, epoch, save_result):
+ if self.display_id > 0: # show images in the browser
if self.display_single_pane_ncols > 0:
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
- table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
- table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
-</style>""" % (w, h)
+ table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
+ table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+ </style>""" % (w, h)
ncols = self.display_single_pane_ncols
title = self.name
label_html = ''
@@ -61,16 +66,17 @@ class Visualizer():
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
- self.vis.text(table_css + label_html, win = self.display_id + 2,
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
else:
idx = 1
for label, image_numpy in visuals.items():
- self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
- win=self.display_id + idx)
+ self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+ win=self.display_id + idx)
idx += 1
- if self.use_html: # save images to a html file
+ if self.use_html and (save_result or not self.saved): # save images to a html file
+ self.saved = True
for label, image_numpy in visuals.items():
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
@@ -93,11 +99,11 @@ class Visualizer():
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):
- self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
+ self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
self.vis.line(
- X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
+ X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',