diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-05 17:01:44 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-05 17:01:44 +0100 |
| commit | d9ceb77e8700312c554cd3205d0c8db775db00c2 (patch) | |
| tree | 70c0f0e177a7104988fb8b56698a59fc4886d5aa /cli/app/search/util.py | |
| parent | 1d3c7428068c46568638db5ab547c8aeb2308b57 (diff) | |
biggan
Diffstat (limited to 'cli/app/search/util.py')
| -rw-r--r-- | cli/app/search/util.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/cli/app/search/util.py b/cli/app/search/util.py new file mode 100644 index 0000000..a4cdfd9 --- /dev/null +++ b/cli/app/search/util.py @@ -0,0 +1,78 @@ + +def truncated_z_sample(batch_size, truncation=1., seed=None): + state = None if seed is None else np.random.RandomState(seed) + values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state) + return truncation * values + +def one_hot(index, vocab_size=vocab_size): + index = np.asarray(index) + if len(index.shape) == 0: + index = np.asarray([index]) + assert len(index.shape) == 1 + num = index.shape[0] + output = np.zeros((num, vocab_size), dtype=np.float32) + output[np.arange(num), index] = 1 + return output + +def imconvert_uint8(im): + im = np.clip(((im + 1) / 2.0) * 256, 0, 255) + im = np.uint8(im) + return im + +def imconvert_float32(im): + im = np.float32(im) + im = (im / 256) * 2.0 - 1 + return im + +def imread(filename): + img = cv2.imread(filename, cv2.IMREAD_UNCHANGED) + if img is not None: + if len(img.shape) > 2: + img = img[...,::-1] + return img + +def imwrite(filename, img): + if img is not None: + if len(img.shape) > 2: + img = img[...,::-1] + return cv2.imwrite(filename, img) + +def imgrid(imarray, cols=5, pad=1): + if imarray.dtype != np.uint8: + raise ValueError('imgrid input imarray must be uint8') + pad = int(pad) + assert pad >= 0 + cols = int(cols) + assert cols >= 1 + N, H, W, C = imarray.shape + rows = int(np.ceil(N / float(cols))) + batch_pad = rows * cols - N + assert batch_pad >= 0 + post_pad = [batch_pad, pad, pad, 0] + pad_arg = [[0, p] for p in post_pad] + imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255) + H += pad + W += pad + grid = (imarray + .reshape(rows, cols, H, W, C) + .transpose(0, 2, 1, 3, 4) + .reshape(rows*H, cols*W, C)) + if pad: + grid = grid[:-pad, :-pad] + return grid + +def imshow(a, format='png', jpeg_fallback=True): + a = np.asarray(a, dtype=np.uint8) + str_file = cStringIO.StringIO() + PIL.Image.fromarray(a).save(str_file, format) + im_data = str_file.getvalue() + try: + disp = IPython.display.display(IPython.display.Image(im_data)) + except IOError: + if jpeg_fallback and format != 'jpeg': + print ('Warning: image was too large to display in format "{}"; ' + 'trying jpeg instead.').format(format) + return imshow(a, format='jpeg') + else: + raise + return disp |
