diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-04-20 02:43:45 -0700 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-04-20 02:43:45 -0700 |
| commit | 07274cd211e6fe180b8287eb65519e66f87d3e2a (patch) | |
| tree | 785e9cbe8cbed5cd9daf7f30be8c13f912723b5c /util | |
| parent | 685cfc40f5e9bb529babc2e5ebc1026b864fd119 (diff) | |
| parent | c41ab7acdb1e9191d91991d6461ddc4dff3243f7 (diff) | |
Merge branch 'master' of https://github.com/ruotianluo/pytorch-CycleGAN-and-pix2pix into ruotianluo-master
Diffstat (limited to 'util')
| -rw-r--r-- | util/display.py | 115 | ||||
| -rw-r--r-- | util/visualizer.py | 28 |
2 files changed, 20 insertions, 123 deletions
diff --git a/util/display.py b/util/display.py deleted file mode 100644 index 1403483..0000000 --- a/util/display.py +++ /dev/null @@ -1,115 +0,0 @@ -############################################################################### -## Copied from https://github.com/szym/display.git -## The python package installer is under development, so -## the code was adopted. -############################################################################### - -import base64 -import json -import numpy - -try: - from urllib.parse import urlparse, urlencode - from urllib.request import urlopen, Request - from urllib.error import HTTPError -except ImportError: - from urlparse import urlparse - from urllib import urlencode - from urllib2 import urlopen, Request, HTTPError -from . import png - -__all__ = ['URL', 'image', 'images', 'plot'] - -URL = 'http://localhost:8000/events' - - -def uid(): - return 'pane_%s' % uuid.uuid4() - - -def send(**command): - command = json.dumps(command) - req = Request(URL, method='POST') - req.add_header('Content-Type', 'application/text') - req.data = command.encode('ascii') - try: - resp = urlopen(req) - return resp is not None - except: - raise - return False - - -def pane(panetype, win, title, content): - win = win or uid() - send(command='pane', type=panetype, id=win, title=title, content=content) - return win - - -def normalize(img, opts): - minval = opts.get('min') - if minval is None: - minval = numpy.amin(img) - maxval = opts.get('max') - if maxval is None: - maxval = numpy.amax(img) - - return numpy.uint8((img - minval) * (255/(maxval - minval))) - - -def to_rgb(img): - nchannels = img.shape[2] if img.ndim == 3 else 1 - if nchannels == 3: - return img - if nchannels == 1: - return img[:, :, numpy.newaxis].repeat(3, axis=2) - raise ValueError('Image must be RGB or gray-scale') - - -def image(img, **opts): - assert img.ndim == 2 or img.ndim == 3 - - if isinstance(img, list): - return images(img, opts) - # TODO: if img is a 3d tensor, then unstack it into a list of images - - img = to_rgb(normalize(img, opts)) - pngbytes = png.encode(img.tostring(), img.shape[1], img.shape[0]) - imgdata = 'data:image/png;base64,' + base64.b64encode(pngbytes).decode('ascii') - - return pane('image', opts.get('win'), opts.get('title'), content={ - 'src': imgdata, - 'labels': opts.get('labels'), - 'width': opts.get('width'), - }) - - -def images(images, **opts): - # TODO: need to merge images into a single canvas - raise Exception('Not implemented') - - -def plot(data, **opts): - """ Plot data as line chart. - Params: - data: either a 2-d numpy array or a list of lists. - win: pane id - labels: list of series names, first series is always the X-axis - see http://dygraphs.com/options.html for other supported options - """ - dataset = {} - if type(data).__module__ == numpy.__name__: - dataset = data.tolist() - else: - dataset = data - - # clone opts into options - options = dict(opts) - options['file'] = dataset - if options.get('labels'): - options['xlabel'] = options['labels'][0] - - # Don't pass our options to dygraphs. - options.pop('win', None) - - return pane('plot', opts.get('win'), opts.get('title'), content=options) diff --git a/util/visualizer.py b/util/visualizer.py index 4daf506..7df9cce 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -10,14 +10,14 @@ class Visualizer(): def __init__(self, opt): # self.opt = opt self.display_id = opt.display_id + self.name = opt.name if self.display_id > 0: - from . import display - self.display = display + import visdom + self.vis = visdom.Visdom() else: from . import html self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') - self.name = opt.name self.win_size = opt.display_winsize print('create web directory %s...' % self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) @@ -26,10 +26,10 @@ class Visualizer(): # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch): if self.display_id > 0: # show images in the browser - idx = 0 + idx = 1 for label, image_numpy in visuals.items(): - image_numpy = np.flipud(image_numpy) - self.display.image(image_numpy, title=label, + #image_numpy = np.flipud(image_numpy) + self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 else: # save images to a web directory @@ -53,8 +53,20 @@ class Visualizer(): webpage.save() # st() # errors: dictionary of error labels and values - def plot_current_errors(self, epoch, i, opt, errors): - pass + def plot_current_errors(self, epoch, counter_ratio, opt, errors): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X':[],'Y':[], 'legend':errors.keys()} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, start_time): |
