diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-04-20 03:04:34 -0700 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-04-20 03:04:34 -0700 |
| commit | 3aeb748074628fd033d0fa798813aa226b8645d6 (patch) | |
| tree | 9332a8e5d560c20a3502b9061607f3ea55e008da | |
| parent | 685cfc40f5e9bb529babc2e5ebc1026b864fd119 (diff) | |
| parent | 9cecc5fc9e19e08bf4c0d91b64fe9d45c83c610c (diff) | |
Merge branch 'ruotianluo-master'
| -rw-r--r-- | README.md | 2 | ||||
| -rw-r--r-- | train.py | 2 | ||||
| -rw-r--r-- | util/display.py | 115 | ||||
| -rw-r--r-- | util/visualizer.py | 28 |
4 files changed, 22 insertions, 125 deletions
@@ -88,7 +88,7 @@ More example scripts can be found at `scripts` directory. ## Training/test Details - See `options/train_options.py` and `options/base_options.py` for training flags; see `optoins/test_options.py` and `options/base_options.py` for test flags. - CPU/GPU: Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. -- During training, you can visualize the result of current training. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [szym/display: a lightweight display server for Torch](https://github.com/szym/display). To do this, you should have Torch, Python 3, and the display package installed. You need to invoke `th -ldisplay.start 8000 0.0.0.0` to start the server. +- During training, you can visualize the result of current training. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should visdom installed. You need to invoke `python -m visdom.server` to start the server. ### CycleGAN Datasets Download the CycleGAN datasets using the following script: @@ -32,7 +32,7 @@ for epoch in range(1, opt.niter + opt.niter_decay + 1): errors = model.get_current_errors() visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time) if opt.display_id > 0: - visualizer.plot_current_errors(epoch, epoch_iter, opt, errors) + visualizer.plot_current_errors(epoch, float(epoch_iter)/num_train, opt, errors) if total_steps % opt.save_latest_freq == 0: print('saving the latest model (epoch %d, total_steps %d)' % diff --git a/util/display.py b/util/display.py deleted file mode 100644 index 1403483..0000000 --- a/util/display.py +++ /dev/null @@ -1,115 +0,0 @@ -############################################################################### -## Copied from https://github.com/szym/display.git -## The python package installer is under development, so -## the code was adopted. -############################################################################### - -import base64 -import json -import numpy - -try: - from urllib.parse import urlparse, urlencode - from urllib.request import urlopen, Request - from urllib.error import HTTPError -except ImportError: - from urlparse import urlparse - from urllib import urlencode - from urllib2 import urlopen, Request, HTTPError -from . import png - -__all__ = ['URL', 'image', 'images', 'plot'] - -URL = 'http://localhost:8000/events' - - -def uid(): - return 'pane_%s' % uuid.uuid4() - - -def send(**command): - command = json.dumps(command) - req = Request(URL, method='POST') - req.add_header('Content-Type', 'application/text') - req.data = command.encode('ascii') - try: - resp = urlopen(req) - return resp is not None - except: - raise - return False - - -def pane(panetype, win, title, content): - win = win or uid() - send(command='pane', type=panetype, id=win, title=title, content=content) - return win - - -def normalize(img, opts): - minval = opts.get('min') - if minval is None: - minval = numpy.amin(img) - maxval = opts.get('max') - if maxval is None: - maxval = numpy.amax(img) - - return numpy.uint8((img - minval) * (255/(maxval - minval))) - - -def to_rgb(img): - nchannels = img.shape[2] if img.ndim == 3 else 1 - if nchannels == 3: - return img - if nchannels == 1: - return img[:, :, numpy.newaxis].repeat(3, axis=2) - raise ValueError('Image must be RGB or gray-scale') - - -def image(img, **opts): - assert img.ndim == 2 or img.ndim == 3 - - if isinstance(img, list): - return images(img, opts) - # TODO: if img is a 3d tensor, then unstack it into a list of images - - img = to_rgb(normalize(img, opts)) - pngbytes = png.encode(img.tostring(), img.shape[1], img.shape[0]) - imgdata = 'data:image/png;base64,' + base64.b64encode(pngbytes).decode('ascii') - - return pane('image', opts.get('win'), opts.get('title'), content={ - 'src': imgdata, - 'labels': opts.get('labels'), - 'width': opts.get('width'), - }) - - -def images(images, **opts): - # TODO: need to merge images into a single canvas - raise Exception('Not implemented') - - -def plot(data, **opts): - """ Plot data as line chart. - Params: - data: either a 2-d numpy array or a list of lists. - win: pane id - labels: list of series names, first series is always the X-axis - see http://dygraphs.com/options.html for other supported options - """ - dataset = {} - if type(data).__module__ == numpy.__name__: - dataset = data.tolist() - else: - dataset = data - - # clone opts into options - options = dict(opts) - options['file'] = dataset - if options.get('labels'): - options['xlabel'] = options['labels'][0] - - # Don't pass our options to dygraphs. - options.pop('win', None) - - return pane('plot', opts.get('win'), opts.get('title'), content=options) diff --git a/util/visualizer.py b/util/visualizer.py index 4daf506..a839896 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -10,14 +10,14 @@ class Visualizer(): def __init__(self, opt): # self.opt = opt self.display_id = opt.display_id + self.name = opt.name if self.display_id > 0: - from . import display - self.display = display + import visdom + self.vis = visdom.Visdom() else: from . import html self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') - self.name = opt.name self.win_size = opt.display_winsize print('create web directory %s...' % self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) @@ -26,10 +26,10 @@ class Visualizer(): # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch): if self.display_id > 0: # show images in the browser - idx = 0 + idx = 1 for label, image_numpy in visuals.items(): - image_numpy = np.flipud(image_numpy) - self.display.image(image_numpy, title=label, + #image_numpy = np.flipud(image_numpy) + self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 else: # save images to a web directory @@ -53,8 +53,20 @@ class Visualizer(): webpage.save() # st() # errors: dictionary of error labels and values - def plot_current_errors(self, epoch, i, opt, errors): - pass + def plot_current_errors(self, epoch, counter_ratio, opt, errors): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, start_time): |
