summaryrefslogtreecommitdiff
path: root/util/visualizer.py
diff options
context:
space:
mode:
authorRuotian Luo <rluo@ttic.edu>2017-04-18 21:10:01 -0500
committerRuotian Luo <rluo@ttic.edu>2017-04-18 21:10:01 -0500
commitc41ab7acdb1e9191d91991d6461ddc4dff3243f7 (patch)
tree29f109b0af57ed36ae5945aa5204388bc25b9818 /util/visualizer.py
parent97e896a587e4f98e57a3f282b6d7994c1fe637dc (diff)
Use visdom instead of display
Diffstat (limited to 'util/visualizer.py')
-rw-r--r--util/visualizer.py28
1 files changed, 20 insertions, 8 deletions
diff --git a/util/visualizer.py b/util/visualizer.py
index 4daf506..7df9cce 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -10,14 +10,14 @@ class Visualizer():
def __init__(self, opt):
# self.opt = opt
self.display_id = opt.display_id
+ self.name = opt.name
if self.display_id > 0:
- from . import display
- self.display = display
+ import visdom
+ self.vis = visdom.Visdom()
else:
from . import html
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
- self.name = opt.name
self.win_size = opt.display_winsize
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
@@ -26,10 +26,10 @@ class Visualizer():
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch):
if self.display_id > 0: # show images in the browser
- idx = 0
+ idx = 1
for label, image_numpy in visuals.items():
- image_numpy = np.flipud(image_numpy)
- self.display.image(image_numpy, title=label,
+ #image_numpy = np.flipud(image_numpy)
+ self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
else: # save images to a web directory
@@ -53,8 +53,20 @@ class Visualizer():
webpage.save()
# st()
# errors: dictionary of error labels and values
- def plot_current_errors(self, epoch, i, opt, errors):
- pass
+ def plot_current_errors(self, epoch, counter_ratio, opt, errors):
+ if not hasattr(self, 'plot_data'):
+ self.plot_data = {'X':[],'Y':[], 'legend':errors.keys()}
+ self.plot_data['X'].append(epoch + counter_ratio)
+ self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
+ self.vis.line(
+ X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
+ Y=np.array(self.plot_data['Y']),
+ opts={
+ 'title': self.name + ' loss over time',
+ 'legend': self.plot_data['legend'],
+ 'xlabel': 'epoch',
+ 'ylabel': 'loss'},
+ win=self.display_id)
# errors: same format as |errors| of plotCurrentErrors
def print_current_errors(self, epoch, i, errors, start_time):