summaryrefslogtreecommitdiff
path: root/util
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
committerjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
commitc99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch)
treeba99dfd56a47036d9c1f18620abf4efc248839ab /util
first commit
Diffstat (limited to 'util')
-rw-r--r--util/__init__.py0
-rw-r--r--util/display.py115
-rw-r--r--util/html.py64
-rw-r--r--util/image_pool.py33
-rw-r--r--util/png.py33
-rw-r--r--util/util.py71
-rw-r--r--util/visualizer.py86
7 files changed, 402 insertions, 0 deletions
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/util/__init__.py
diff --git a/util/display.py b/util/display.py
new file mode 100644
index 0000000..1403483
--- /dev/null
+++ b/util/display.py
@@ -0,0 +1,115 @@
+###############################################################################
+## Copied from https://github.com/szym/display.git
+## The python package installer is under development, so
+## the code was adopted.
+###############################################################################
+
+import base64
+import json
+import numpy
+
+try:
+ from urllib.parse import urlparse, urlencode
+ from urllib.request import urlopen, Request
+ from urllib.error import HTTPError
+except ImportError:
+ from urlparse import urlparse
+ from urllib import urlencode
+ from urllib2 import urlopen, Request, HTTPError
+from . import png
+
+__all__ = ['URL', 'image', 'images', 'plot']
+
+URL = 'http://localhost:8000/events'
+
+
+def uid():
+ return 'pane_%s' % uuid.uuid4()
+
+
+def send(**command):
+ command = json.dumps(command)
+ req = Request(URL, method='POST')
+ req.add_header('Content-Type', 'application/text')
+ req.data = command.encode('ascii')
+ try:
+ resp = urlopen(req)
+ return resp is not None
+ except:
+ raise
+ return False
+
+
+def pane(panetype, win, title, content):
+ win = win or uid()
+ send(command='pane', type=panetype, id=win, title=title, content=content)
+ return win
+
+
+def normalize(img, opts):
+ minval = opts.get('min')
+ if minval is None:
+ minval = numpy.amin(img)
+ maxval = opts.get('max')
+ if maxval is None:
+ maxval = numpy.amax(img)
+
+ return numpy.uint8((img - minval) * (255/(maxval - minval)))
+
+
+def to_rgb(img):
+ nchannels = img.shape[2] if img.ndim == 3 else 1
+ if nchannels == 3:
+ return img
+ if nchannels == 1:
+ return img[:, :, numpy.newaxis].repeat(3, axis=2)
+ raise ValueError('Image must be RGB or gray-scale')
+
+
+def image(img, **opts):
+ assert img.ndim == 2 or img.ndim == 3
+
+ if isinstance(img, list):
+ return images(img, opts)
+ # TODO: if img is a 3d tensor, then unstack it into a list of images
+
+ img = to_rgb(normalize(img, opts))
+ pngbytes = png.encode(img.tostring(), img.shape[1], img.shape[0])
+ imgdata = 'data:image/png;base64,' + base64.b64encode(pngbytes).decode('ascii')
+
+ return pane('image', opts.get('win'), opts.get('title'), content={
+ 'src': imgdata,
+ 'labels': opts.get('labels'),
+ 'width': opts.get('width'),
+ })
+
+
+def images(images, **opts):
+ # TODO: need to merge images into a single canvas
+ raise Exception('Not implemented')
+
+
+def plot(data, **opts):
+ """ Plot data as line chart.
+ Params:
+ data: either a 2-d numpy array or a list of lists.
+ win: pane id
+ labels: list of series names, first series is always the X-axis
+ see http://dygraphs.com/options.html for other supported options
+ """
+ dataset = {}
+ if type(data).__module__ == numpy.__name__:
+ dataset = data.tolist()
+ else:
+ dataset = data
+
+ # clone opts into options
+ options = dict(opts)
+ options['file'] = dataset
+ if options.get('labels'):
+ options['xlabel'] = options['labels'][0]
+
+ # Don't pass our options to dygraphs.
+ options.pop('win', None)
+
+ return pane('plot', opts.get('win'), opts.get('title'), content=options)
diff --git a/util/html.py b/util/html.py
new file mode 100644
index 0000000..c7956f1
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,64 @@
+import dominate
+from dominate.tags import *
+import os
+
+
+class HTML:
+ def __init__(self, web_dir, title, reflesh=0):
+ self.title = title
+ self.web_dir = web_dir
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if not os.path.exists(self.web_dir):
+ os.makedirs(self.web_dir)
+ if not os.path.exists(self.img_dir):
+ os.makedirs(self.img_dir)
+ # print(self.img_dir)
+
+ self.doc = dominate.document(title=title)
+ if reflesh > 0:
+ with self.doc.head:
+ meta(http_equiv="reflesh", content=str(reflesh))
+
+ def get_image_dir(self):
+ return self.img_dir
+
+ def add_header(self, str):
+ with self.doc:
+ h3(str)
+
+ def add_table(self, border=1):
+ self.t = table(border=border, style="table-layout: fixed;")
+ self.doc.add(self.t)
+
+ def add_images(self, ims, txts, links, width=400):
+ self.add_table()
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__':
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims = []
+ txts = []
+ links = []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100644
index 0000000..b59e185
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,33 @@
+import random
+import numpy as np
+import torch
+from pdb import set_trace as st
+from torch.autograd import Variable
+class ImagePool():
+ def __init__(self, pool_size):
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ if self.pool_size == 0:
+ return images
+ return_images = []
+ for image in images.data:
+ image = torch.unsqueeze(image, 0)
+ if self.num_imgs < self.pool_size:
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size-1)
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else:
+ return_images.append(image)
+ return_images = Variable(torch.cat(return_images, 0))
+ return return_images
diff --git a/util/png.py b/util/png.py
new file mode 100644
index 0000000..0936cf0
--- /dev/null
+++ b/util/png.py
@@ -0,0 +1,33 @@
+import struct
+import zlib
+
+def encode(buf, width, height):
+ """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
+ assert (width * height * 3 == len(buf))
+ bpp = 3
+
+ def raw_data():
+ # reverse the vertical line order and add null bytes at the start
+ row_bytes = width * bpp
+ for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
+ yield b'\x00'
+ yield buf[row_start:row_start + row_bytes]
+
+ def chunk(tag, data):
+ return [
+ struct.pack("!I", len(data)),
+ tag,
+ data,
+ struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
+ ]
+
+ SIGNATURE = b'\x89PNG\r\n\x1a\n'
+ COLOR_TYPE_RGB = 2
+ COLOR_TYPE_RGBA = 6
+ bit_depth = 8
+ return b''.join(
+ [ SIGNATURE ] +
+ chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
+ chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
+ chunk(b'IEND', b'')
+ )
diff --git a/util/util.py b/util/util.py
new file mode 100644
index 0000000..781239f
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,71 @@
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import inspect, re
+import numpy as np
+import os
+import collections
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8):
+ image_numpy = image_tensor[0].cpu().float().numpy()
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path):
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path)
+
+def info(object, spacing=10, collapse=1):
+ """Print methods and doc strings.
+ Takes module, class, list, dictionary, or string."""
+ methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)]
+ processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
+ print( "\n".join(["%s %s" %
+ (method.ljust(spacing),
+ processFunc(str(getattr(object, method).__doc__)))
+ for method in methodList]) )
+
+def varname(p):
+ for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
+ m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
+ if m:
+ return m.group(1)
+
+def print_numpy(x, val=True, shp=False):
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100644
index 0000000..0b8578e
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,86 @@
+import numpy as np
+import os
+import ntpath
+import time
+from . import util
+from . import html
+from pdb import set_trace as st
+
+class Visualizer():
+ def __init__(self, opt):
+ # self.opt = opt
+ self.display_id = opt.display_id
+ if self.display_id > 0:
+ from . import display
+ self.display = display
+ else:
+ from . import html
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ self.name = opt.name
+ self.win_size = opt.display_winsize
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+
+
+ # |visuals|: dictionary of images to display or save
+ def display_current_results(self, visuals, epoch):
+ if self.display_id > 0: # show images in the browser
+ idx = 0
+ for label, image_numpy in visuals:
+ image_numpy = np.flipud(image_numpy)
+ self.display.image(image_numpy, title=label,
+ win=self.display_id + idx)
+ idx += 1
+ else: # save images to a web directory
+ for label, image_numpy in visuals.items():
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+ # st()
+ # errors: dictionary of error labels and values
+ def plot_current_errors(self, epoch, i, opt, errors):
+ pass
+
+ # errors: same format as |errors| of plotCurrentErrors
+ def print_current_errors(self, epoch, i, errors, start_time):
+ message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, time.time() - start_time)
+ for k, v in errors.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message)
+
+ # save image to the disk
+ def save_images(self, webpage, visuals, image_path):
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ image_name = '%s_%s.png' % (name, label)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(image_numpy, save_path)
+
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=self.win_size)