summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md2
-rw-r--r--train.py2
-rw-r--r--util/display.py115
-rw-r--r--util/visualizer.py28
4 files changed, 22 insertions, 125 deletions
diff --git a/README.md b/README.md
index 372d349..d05d80e 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,7 @@ More example scripts can be found at `scripts` directory.
## Training/test Details
- See `options/train_options.py` and `options/base_options.py` for training flags; see `optoins/test_options.py` and `options/base_options.py` for test flags.
- CPU/GPU: Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
-- During training, you can visualize the result of current training. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [szym/display: a lightweight display server for Torch](https://github.com/szym/display). To do this, you should have Torch, Python 3, and the display package installed. You need to invoke `th -ldisplay.start 8000 0.0.0.0` to start the server.
+- During training, you can visualize the result of current training. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id` > 0, the results will be shown on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should visdom installed. You need to invoke `python -m visdom.server` to start the server.
### CycleGAN Datasets
Download the CycleGAN datasets using the following script:
diff --git a/train.py b/train.py
index e85042f..da497ce 100644
--- a/train.py
+++ b/train.py
@@ -32,7 +32,7 @@ for epoch in range(1, opt.niter + opt.niter_decay + 1):
errors = model.get_current_errors()
visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
if opt.display_id > 0:
- visualizer.plot_current_errors(epoch, epoch_iter, opt, errors)
+ visualizer.plot_current_errors(epoch, float(epoch_iter)/num_train, opt, errors)
if total_steps % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_steps %d)' %
diff --git a/util/display.py b/util/display.py
deleted file mode 100644
index 1403483..0000000
--- a/util/display.py
+++ /dev/null
@@ -1,115 +0,0 @@
-###############################################################################
-## Copied from https://github.com/szym/display.git
-## The python package installer is under development, so
-## the code was adopted.
-###############################################################################
-
-import base64
-import json
-import numpy
-
-try:
- from urllib.parse import urlparse, urlencode
- from urllib.request import urlopen, Request
- from urllib.error import HTTPError
-except ImportError:
- from urlparse import urlparse
- from urllib import urlencode
- from urllib2 import urlopen, Request, HTTPError
-from . import png
-
-__all__ = ['URL', 'image', 'images', 'plot']
-
-URL = 'http://localhost:8000/events'
-
-
-def uid():
- return 'pane_%s' % uuid.uuid4()
-
-
-def send(**command):
- command = json.dumps(command)
- req = Request(URL, method='POST')
- req.add_header('Content-Type', 'application/text')
- req.data = command.encode('ascii')
- try:
- resp = urlopen(req)
- return resp is not None
- except:
- raise
- return False
-
-
-def pane(panetype, win, title, content):
- win = win or uid()
- send(command='pane', type=panetype, id=win, title=title, content=content)
- return win
-
-
-def normalize(img, opts):
- minval = opts.get('min')
- if minval is None:
- minval = numpy.amin(img)
- maxval = opts.get('max')
- if maxval is None:
- maxval = numpy.amax(img)
-
- return numpy.uint8((img - minval) * (255/(maxval - minval)))
-
-
-def to_rgb(img):
- nchannels = img.shape[2] if img.ndim == 3 else 1
- if nchannels == 3:
- return img
- if nchannels == 1:
- return img[:, :, numpy.newaxis].repeat(3, axis=2)
- raise ValueError('Image must be RGB or gray-scale')
-
-
-def image(img, **opts):
- assert img.ndim == 2 or img.ndim == 3
-
- if isinstance(img, list):
- return images(img, opts)
- # TODO: if img is a 3d tensor, then unstack it into a list of images
-
- img = to_rgb(normalize(img, opts))
- pngbytes = png.encode(img.tostring(), img.shape[1], img.shape[0])
- imgdata = 'data:image/png;base64,' + base64.b64encode(pngbytes).decode('ascii')
-
- return pane('image', opts.get('win'), opts.get('title'), content={
- 'src': imgdata,
- 'labels': opts.get('labels'),
- 'width': opts.get('width'),
- })
-
-
-def images(images, **opts):
- # TODO: need to merge images into a single canvas
- raise Exception('Not implemented')
-
-
-def plot(data, **opts):
- """ Plot data as line chart.
- Params:
- data: either a 2-d numpy array or a list of lists.
- win: pane id
- labels: list of series names, first series is always the X-axis
- see http://dygraphs.com/options.html for other supported options
- """
- dataset = {}
- if type(data).__module__ == numpy.__name__:
- dataset = data.tolist()
- else:
- dataset = data
-
- # clone opts into options
- options = dict(opts)
- options['file'] = dataset
- if options.get('labels'):
- options['xlabel'] = options['labels'][0]
-
- # Don't pass our options to dygraphs.
- options.pop('win', None)
-
- return pane('plot', opts.get('win'), opts.get('title'), content=options)
diff --git a/util/visualizer.py b/util/visualizer.py
index 4daf506..7df9cce 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -10,14 +10,14 @@ class Visualizer():
def __init__(self, opt):
# self.opt = opt
self.display_id = opt.display_id
+ self.name = opt.name
if self.display_id > 0:
- from . import display
- self.display = display
+ import visdom
+ self.vis = visdom.Visdom()
else:
from . import html
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
- self.name = opt.name
self.win_size = opt.display_winsize
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
@@ -26,10 +26,10 @@ class Visualizer():
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch):
if self.display_id > 0: # show images in the browser
- idx = 0
+ idx = 1
for label, image_numpy in visuals.items():
- image_numpy = np.flipud(image_numpy)
- self.display.image(image_numpy, title=label,
+ #image_numpy = np.flipud(image_numpy)
+ self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
else: # save images to a web directory
@@ -53,8 +53,20 @@ class Visualizer():
webpage.save()
# st()
# errors: dictionary of error labels and values
- def plot_current_errors(self, epoch, i, opt, errors):
- pass
+ def plot_current_errors(self, epoch, counter_ratio, opt, errors):
+ if not hasattr(self, 'plot_data'):
+ self.plot_data = {'X':[],'Y':[], 'legend':errors.keys()}
+ self.plot_data['X'].append(epoch + counter_ratio)
+ self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
+ self.vis.line(
+ X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
+ Y=np.array(self.plot_data['Y']),
+ opts={
+ 'title': self.name + ' loss over time',
+ 'legend': self.plot_data['legend'],
+ 'xlabel': 'epoch',
+ 'ylabel': 'loss'},
+ win=self.display_id)
# errors: same format as |errors| of plotCurrentErrors
def print_current_errors(self, epoch, i, errors, start_time):