summaryrefslogtreecommitdiff
path: root/util
diff options
context:
space:
mode:
authortingchunw <tingchunw@nvidia.com>2017-12-04 16:52:46 -0800
committertingchunw <tingchunw@nvidia.com>2017-12-04 16:52:46 -0800
commit9054cf9b0c327a5077fd0793abe178f400da3315 (patch)
tree3c69c07bdcba86c47d8442648fd69c0434e04136 /util
parentf9e9999541d67a908a169cc88407675133130e1f (diff)
first commit
Diffstat (limited to 'util')
-rwxr-xr-xutil/__init__.py0
-rwxr-xr-xutil/html.py63
-rwxr-xr-xutil/image_pool.py32
-rwxr-xr-xutil/util.py99
-rwxr-xr-xutil/visualizer.py133
5 files changed, 327 insertions, 0 deletions
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/util/__init__.py
diff --git a/util/html.py b/util/html.py
new file mode 100755
index 0000000..a80aa59
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,63 @@
+import dominate
+from dominate.tags import *
+import os
+
+
+class HTML:
+ def __init__(self, web_dir, title, reflesh=0):
+ self.title = title
+ self.web_dir = web_dir
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if not os.path.exists(self.web_dir):
+ os.makedirs(self.web_dir)
+ if not os.path.exists(self.img_dir):
+ os.makedirs(self.img_dir)
+
+ self.doc = dominate.document(title=title)
+ if reflesh > 0:
+ with self.doc.head:
+ meta(http_equiv="reflesh", content=str(reflesh))
+
+ def get_image_dir(self):
+ return self.img_dir
+
+ def add_header(self, str):
+ with self.doc:
+ h3(str)
+
+ def add_table(self, border=1):
+ self.t = table(border=border, style="table-layout: fixed;")
+ self.doc.add(self.t)
+
+ def add_images(self, ims, txts, links, width=512):
+ self.add_table()
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % (width), src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__':
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims = []
+ txts = []
+ links = []
+ for n in range(4):
+ ims.append('image_%d.jpg' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.jpg' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100755
index 0000000..152ef5b
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,32 @@
+import random
+import numpy as np
+import torch
+from torch.autograd import Variable
+class ImagePool():
+ def __init__(self, pool_size):
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ if self.pool_size == 0:
+ return images
+ return_images = []
+ for image in images.data:
+ image = torch.unsqueeze(image, 0)
+ if self.num_imgs < self.pool_size:
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size-1)
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else:
+ return_images.append(image)
+ return_images = Variable(torch.cat(return_images, 0))
+ return return_images
diff --git a/util/util.py b/util/util.py
new file mode 100755
index 0000000..0898f7a
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,99 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import inspect, re
+import numpy as np
+import os
+import collections
+from PIL import Image
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
+ if isinstance(image_tensor, list):
+ image_numpy = []
+ for i in range(len(image_tensor)):
+ image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
+ return image_numpy
+ image_numpy = image_tensor.cpu().float().numpy()
+ if normalize:
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ else:
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+ image_numpy = np.clip(image_numpy, 0, 255)
+ if image_numpy.shape[2] == 1:
+ image_numpy = image_numpy[:,:,0]
+ return image_numpy.astype(imtype)
+
+def tensor2label(output, n_label, imtype=np.uint8):
+ output = output.cpu().float()
+ if output.size()[0] > 1:
+ output = output.max(0, keepdim=True)[1]
+ output = Colorize(n_label)(output)
+ output = np.transpose(output.numpy(), (1, 2, 0))
+ return output.astype(imtype)
+
+def save_image(image_numpy, image_path):
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path)
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def uint82bin(n, count=8):
+ """returns the binary of integer n, count refers to amount of bits"""
+ return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
+
+def labelcolormap(N):
+ if N == 35: # cityscape
+ cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
+ (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
+ (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
+ (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
+ ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
+ dtype=np.uint8)
+ else:
+ cmap = np.zeros((N, 3), dtype=np.uint8)
+ for i in range(N):
+ r = 0
+ g = 0
+ b = 0
+ id = i
+ for j in range(7):
+ str_id = uint82bin(id)
+ r = r ^ (np.uint8(str_id[-1]) << (7-j))
+ g = g ^ (np.uint8(str_id[-2]) << (7-j))
+ b = b ^ (np.uint8(str_id[-3]) << (7-j))
+ id = id >> 3
+ cmap[i, 0] = r
+ cmap[i, 1] = g
+ cmap[i, 2] = b
+ return cmap
+
+class Colorize(object):
+ def __init__(self, n=35):
+ self.cmap = labelcolormap(n)
+ self.cmap = torch.from_numpy(self.cmap[:n])
+
+ def __call__(self, gray_image):
+ size = gray_image.size()
+ color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
+
+ for label in range(0, len(self.cmap)):
+ mask = (label == gray_image[0]).cpu()
+ color_image[0][mask] = self.cmap[label][0]
+ color_image[1][mask] = self.cmap[label][1]
+ color_image[2][mask] = self.cmap[label][2]
+
+ return color_image
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100755
index 0000000..f41c55a
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,133 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import numpy as np
+import os
+import ntpath
+import time
+from . import util
+from . import html
+import scipy.misc
+try:
+ from StringIO import StringIO # Python 2.7
+except ImportError:
+ from io import BytesIO # Python 3.x
+
+class Visualizer():
+ def __init__(self, opt):
+ # self.opt = opt
+ self.tf_log = opt.tf_log
+ self.use_html = opt.isTrain and not opt.no_html
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ if self.tf_log:
+ import tensorflow as tf
+ self.tf = tf
+ self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
+ self.writer = tf.summary.FileWriter(self.log_dir)
+
+ if self.use_html:
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ # |visuals|: dictionary of images to display or save
+ def display_current_results(self, visuals, epoch, step):
+ if self.tf_log: # show images in tensorboard output
+ img_summaries = []
+ for label, image_numpy in visuals.items():
+ # Write the image to a string
+ try:
+ s = StringIO()
+ except:
+ s = BytesIO()
+ scipy.misc.toimage(image_numpy).save(s, format="jpeg")
+ # Create an Image object
+ img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
+ # Create a Summary value
+ img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
+
+ # Create and write Summary
+ summary = self.tf.Summary(value=img_summaries)
+ self.writer.add_summary(summary, step)
+
+ if self.use_html: # save images to a html file
+ for label, image_numpy in visuals.items():
+ if isinstance(image_numpy, list):
+ for i in range(len(image_numpy)):
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
+ util.save_image(image_numpy[i], img_path)
+ else:
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ if isinstance(image_numpy, list):
+ for i in range(len(image_numpy)):
+ img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
+ ims.append(img_path)
+ txts.append(label+str(i))
+ links.append(img_path)
+ else:
+ img_path = 'epoch%.3d_%s.jpg' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ if len(ims) < 10:
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ else:
+ num = int(round(len(ims)/2.0))
+ webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
+ webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
+ webpage.save()
+
+ # errors: dictionary of error labels and values
+ def plot_current_errors(self, errors, step):
+ if self.tf_log:
+ for tag, value in errors.items():
+ summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
+ self.writer.add_summary(summary, step)
+
+ # errors: same format as |errors| of plotCurrentErrors
+ def print_current_errors(self, epoch, i, errors, t):
+ message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
+ for k, v in errors.items():
+ if v != 0:
+ message += '%s: %.3f ' % (k, v)
+
+ print(message)
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message)
+
+ # save image to the disk
+ def save_images(self, webpage, visuals, image_path):
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ image_name = '%s_%s.jpg' % (name, label)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(image_numpy, save_path)
+
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=self.win_size)