diff options
Diffstat (limited to 'cli/app/search/util.py')
| -rw-r--r-- | cli/app/search/util.py | 78 |
1 files changed, 0 insertions, 78 deletions
diff --git a/cli/app/search/util.py b/cli/app/search/util.py deleted file mode 100644 index a4cdfd9..0000000 --- a/cli/app/search/util.py +++ /dev/null @@ -1,78 +0,0 @@ - -def truncated_z_sample(batch_size, truncation=1., seed=None): - state = None if seed is None else np.random.RandomState(seed) - values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state) - return truncation * values - -def one_hot(index, vocab_size=vocab_size): - index = np.asarray(index) - if len(index.shape) == 0: - index = np.asarray([index]) - assert len(index.shape) == 1 - num = index.shape[0] - output = np.zeros((num, vocab_size), dtype=np.float32) - output[np.arange(num), index] = 1 - return output - -def imconvert_uint8(im): - im = np.clip(((im + 1) / 2.0) * 256, 0, 255) - im = np.uint8(im) - return im - -def imconvert_float32(im): - im = np.float32(im) - im = (im / 256) * 2.0 - 1 - return im - -def imread(filename): - img = cv2.imread(filename, cv2.IMREAD_UNCHANGED) - if img is not None: - if len(img.shape) > 2: - img = img[...,::-1] - return img - -def imwrite(filename, img): - if img is not None: - if len(img.shape) > 2: - img = img[...,::-1] - return cv2.imwrite(filename, img) - -def imgrid(imarray, cols=5, pad=1): - if imarray.dtype != np.uint8: - raise ValueError('imgrid input imarray must be uint8') - pad = int(pad) - assert pad >= 0 - cols = int(cols) - assert cols >= 1 - N, H, W, C = imarray.shape - rows = int(np.ceil(N / float(cols))) - batch_pad = rows * cols - N - assert batch_pad >= 0 - post_pad = [batch_pad, pad, pad, 0] - pad_arg = [[0, p] for p in post_pad] - imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255) - H += pad - W += pad - grid = (imarray - .reshape(rows, cols, H, W, C) - .transpose(0, 2, 1, 3, 4) - .reshape(rows*H, cols*W, C)) - if pad: - grid = grid[:-pad, :-pad] - return grid - -def imshow(a, format='png', jpeg_fallback=True): - a = np.asarray(a, dtype=np.uint8) - str_file = cStringIO.StringIO() - PIL.Image.fromarray(a).save(str_file, format) - im_data = str_file.getvalue() - try: - disp = IPython.display.display(IPython.display.Image(im_data)) - except IOError: - if jpeg_fallback and format != 'jpeg': - print ('Warning: image was too large to display in format "{}"; ' - 'trying jpeg instead.').format(format) - return imshow(a, format='jpeg') - else: - raise - return disp |
