diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:38:47 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:38:47 -0700 |
| commit | c99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch) | |
| tree | ba99dfd56a47036d9c1f18620abf4efc248839ab /util | |
first commit
Diffstat (limited to 'util')
| -rw-r--r-- | util/__init__.py | 0 | ||||
| -rw-r--r-- | util/display.py | 115 | ||||
| -rw-r--r-- | util/html.py | 64 | ||||
| -rw-r--r-- | util/image_pool.py | 33 | ||||
| -rw-r--r-- | util/png.py | 33 | ||||
| -rw-r--r-- | util/util.py | 71 | ||||
| -rw-r--r-- | util/visualizer.py | 86 |
7 files changed, 402 insertions, 0 deletions
diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/util/__init__.py diff --git a/util/display.py b/util/display.py new file mode 100644 index 0000000..1403483 --- /dev/null +++ b/util/display.py @@ -0,0 +1,115 @@ +############################################################################### +## Copied from https://github.com/szym/display.git +## The python package installer is under development, so +## the code was adopted. +############################################################################### + +import base64 +import json +import numpy + +try: + from urllib.parse import urlparse, urlencode + from urllib.request import urlopen, Request + from urllib.error import HTTPError +except ImportError: + from urlparse import urlparse + from urllib import urlencode + from urllib2 import urlopen, Request, HTTPError +from . import png + +__all__ = ['URL', 'image', 'images', 'plot'] + +URL = 'http://localhost:8000/events' + + +def uid(): + return 'pane_%s' % uuid.uuid4() + + +def send(**command): + command = json.dumps(command) + req = Request(URL, method='POST') + req.add_header('Content-Type', 'application/text') + req.data = command.encode('ascii') + try: + resp = urlopen(req) + return resp is not None + except: + raise + return False + + +def pane(panetype, win, title, content): + win = win or uid() + send(command='pane', type=panetype, id=win, title=title, content=content) + return win + + +def normalize(img, opts): + minval = opts.get('min') + if minval is None: + minval = numpy.amin(img) + maxval = opts.get('max') + if maxval is None: + maxval = numpy.amax(img) + + return numpy.uint8((img - minval) * (255/(maxval - minval))) + + +def to_rgb(img): + nchannels = img.shape[2] if img.ndim == 3 else 1 + if nchannels == 3: + return img + if nchannels == 1: + return img[:, :, numpy.newaxis].repeat(3, axis=2) + raise ValueError('Image must be RGB or gray-scale') + + +def image(img, **opts): + assert img.ndim == 2 or img.ndim == 3 + + if isinstance(img, list): + return images(img, opts) + # TODO: if img is a 3d tensor, then unstack it into a list of images + + img = to_rgb(normalize(img, opts)) + pngbytes = png.encode(img.tostring(), img.shape[1], img.shape[0]) + imgdata = 'data:image/png;base64,' + base64.b64encode(pngbytes).decode('ascii') + + return pane('image', opts.get('win'), opts.get('title'), content={ + 'src': imgdata, + 'labels': opts.get('labels'), + 'width': opts.get('width'), + }) + + +def images(images, **opts): + # TODO: need to merge images into a single canvas + raise Exception('Not implemented') + + +def plot(data, **opts): + """ Plot data as line chart. + Params: + data: either a 2-d numpy array or a list of lists. + win: pane id + labels: list of series names, first series is always the X-axis + see http://dygraphs.com/options.html for other supported options + """ + dataset = {} + if type(data).__module__ == numpy.__name__: + dataset = data.tolist() + else: + dataset = data + + # clone opts into options + options = dict(opts) + options['file'] = dataset + if options.get('labels'): + options['xlabel'] = options['labels'][0] + + # Don't pass our options to dygraphs. + options.pop('win', None) + + return pane('plot', opts.get('win'), opts.get('title'), content=options) diff --git a/util/html.py b/util/html.py new file mode 100644 index 0000000..c7956f1 --- /dev/null +++ b/util/html.py @@ -0,0 +1,64 @@ +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, reflesh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=400): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/image_pool.py b/util/image_pool.py new file mode 100644 index 0000000..b59e185 --- /dev/null +++ b/util/image_pool.py @@ -0,0 +1,33 @@ +import random +import numpy as np +import torch +from pdb import set_trace as st +from torch.autograd import Variable +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size-1) + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = Variable(torch.cat(return_images, 0)) + return return_images diff --git a/util/png.py b/util/png.py new file mode 100644 index 0000000..0936cf0 --- /dev/null +++ b/util/png.py @@ -0,0 +1,33 @@ +import struct +import zlib + +def encode(buf, width, height): + """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """ + assert (width * height * 3 == len(buf)) + bpp = 3 + + def raw_data(): + # reverse the vertical line order and add null bytes at the start + row_bytes = width * bpp + for row_start in range((height - 1) * width * bpp, -1, -row_bytes): + yield b'\x00' + yield buf[row_start:row_start + row_bytes] + + def chunk(tag, data): + return [ + struct.pack("!I", len(data)), + tag, + data, + struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) + ] + + SIGNATURE = b'\x89PNG\r\n\x1a\n' + COLOR_TYPE_RGB = 2 + COLOR_TYPE_RGBA = 6 + bit_depth = 8 + return b''.join( + [ SIGNATURE ] + + chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + + chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + + chunk(b'IEND', b'') + ) diff --git a/util/util.py b/util/util.py new file mode 100644 index 0000000..781239f --- /dev/null +++ b/util/util.py @@ -0,0 +1,71 @@ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import inspect, re +import numpy as np +import os +import collections + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + +def info(object, spacing=10, collapse=1): + """Print methods and doc strings. + Takes module, class, list, dictionary, or string.""" + methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)] + processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) + print( "\n".join(["%s %s" % + (method.ljust(spacing), + processFunc(str(getattr(object, method).__doc__))) + for method in methodList]) ) + +def varname(p): + for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: + m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) + if m: + return m.group(1) + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100644 index 0000000..0b8578e --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,86 @@ +import numpy as np +import os +import ntpath +import time +from . import util +from . import html +from pdb import set_trace as st + +class Visualizer(): + def __init__(self, opt): + # self.opt = opt + self.display_id = opt.display_id + if self.display_id > 0: + from . import display + self.display = display + else: + from . import html + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + self.name = opt.name + self.win_size = opt.display_winsize + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch): + if self.display_id > 0: # show images in the browser + idx = 0 + for label, image_numpy in visuals: + image_numpy = np.flipud(image_numpy) + self.display.image(image_numpy, title=label, + win=self.display_id + idx) + idx += 1 + else: # save images to a web directory + for label, image_numpy in visuals.items(): + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + # st() + # errors: dictionary of error labels and values + def plot_current_errors(self, epoch, i, opt, errors): + pass + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, start_time): + message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, time.time() - start_time) + for k, v in errors.items(): + message += '%s: %.3f ' % (k, v) + + print(message) + + # save image to the disk + def save_images(self, webpage, visuals, image_path): + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + image_name = '%s_%s.png' % (name, label) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=self.win_size) |
