summaryrefslogtreecommitdiff
path: root/cli/app/search/util.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-05 17:01:44 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-05 17:01:44 +0100
commitd9ceb77e8700312c554cd3205d0c8db775db00c2 (patch)
tree70c0f0e177a7104988fb8b56698a59fc4886d5aa /cli/app/search/util.py
parent1d3c7428068c46568638db5ab547c8aeb2308b57 (diff)
biggan
Diffstat (limited to 'cli/app/search/util.py')
-rw-r--r--cli/app/search/util.py78
1 files changed, 78 insertions, 0 deletions
diff --git a/cli/app/search/util.py b/cli/app/search/util.py
new file mode 100644
index 0000000..a4cdfd9
--- /dev/null
+++ b/cli/app/search/util.py
@@ -0,0 +1,78 @@
+
+def truncated_z_sample(batch_size, truncation=1., seed=None):
+ state = None if seed is None else np.random.RandomState(seed)
+ values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
+ return truncation * values
+
+def one_hot(index, vocab_size=vocab_size):
+ index = np.asarray(index)
+ if len(index.shape) == 0:
+ index = np.asarray([index])
+ assert len(index.shape) == 1
+ num = index.shape[0]
+ output = np.zeros((num, vocab_size), dtype=np.float32)
+ output[np.arange(num), index] = 1
+ return output
+
+def imconvert_uint8(im):
+ im = np.clip(((im + 1) / 2.0) * 256, 0, 255)
+ im = np.uint8(im)
+ return im
+
+def imconvert_float32(im):
+ im = np.float32(im)
+ im = (im / 256) * 2.0 - 1
+ return im
+
+def imread(filename):
+ img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
+ if img is not None:
+ if len(img.shape) > 2:
+ img = img[...,::-1]
+ return img
+
+def imwrite(filename, img):
+ if img is not None:
+ if len(img.shape) > 2:
+ img = img[...,::-1]
+ return cv2.imwrite(filename, img)
+
+def imgrid(imarray, cols=5, pad=1):
+ if imarray.dtype != np.uint8:
+ raise ValueError('imgrid input imarray must be uint8')
+ pad = int(pad)
+ assert pad >= 0
+ cols = int(cols)
+ assert cols >= 1
+ N, H, W, C = imarray.shape
+ rows = int(np.ceil(N / float(cols)))
+ batch_pad = rows * cols - N
+ assert batch_pad >= 0
+ post_pad = [batch_pad, pad, pad, 0]
+ pad_arg = [[0, p] for p in post_pad]
+ imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)
+ H += pad
+ W += pad
+ grid = (imarray
+ .reshape(rows, cols, H, W, C)
+ .transpose(0, 2, 1, 3, 4)
+ .reshape(rows*H, cols*W, C))
+ if pad:
+ grid = grid[:-pad, :-pad]
+ return grid
+
+def imshow(a, format='png', jpeg_fallback=True):
+ a = np.asarray(a, dtype=np.uint8)
+ str_file = cStringIO.StringIO()
+ PIL.Image.fromarray(a).save(str_file, format)
+ im_data = str_file.getvalue()
+ try:
+ disp = IPython.display.display(IPython.display.Image(im_data))
+ except IOError:
+ if jpeg_fallback and format != 'jpeg':
+ print ('Warning: image was too large to display in format "{}"; '
+ 'trying jpeg instead.').format(format)
+ return imshow(a, format='jpeg')
+ else:
+ raise
+ return disp