diff options
| author | Ruotian Luo <rluo@ttic.edu> | 2017-04-18 21:10:01 -0500 |
|---|---|---|
| committer | Ruotian Luo <rluo@ttic.edu> | 2017-04-18 21:10:01 -0500 |
| commit | c41ab7acdb1e9191d91991d6461ddc4dff3243f7 (patch) | |
| tree | 29f109b0af57ed36ae5945aa5204388bc25b9818 /util/visualizer.py | |
| parent | 97e896a587e4f98e57a3f282b6d7994c1fe637dc (diff) | |
Use visdom instead of display
Diffstat (limited to 'util/visualizer.py')
| -rw-r--r-- | util/visualizer.py | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/util/visualizer.py b/util/visualizer.py index 4daf506..7df9cce 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -10,14 +10,14 @@ class Visualizer(): def __init__(self, opt): # self.opt = opt self.display_id = opt.display_id + self.name = opt.name if self.display_id > 0: - from . import display - self.display = display + import visdom + self.vis = visdom.Visdom() else: from . import html self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') - self.name = opt.name self.win_size = opt.display_winsize print('create web directory %s...' % self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) @@ -26,10 +26,10 @@ class Visualizer(): # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch): if self.display_id > 0: # show images in the browser - idx = 0 + idx = 1 for label, image_numpy in visuals.items(): - image_numpy = np.flipud(image_numpy) - self.display.image(image_numpy, title=label, + #image_numpy = np.flipud(image_numpy) + self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 else: # save images to a web directory @@ -53,8 +53,20 @@ class Visualizer(): webpage.save() # st() # errors: dictionary of error labels and values - def plot_current_errors(self, epoch, i, opt, errors): - pass + def plot_current_errors(self, epoch, counter_ratio, opt, errors): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X':[],'Y':[], 'legend':errors.keys()} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, start_time): |
