def truncated_z_sample(batch_size, truncation=1., seed=None):
    """Sample a batch of latent vectors from a truncated distribution.

    Draws an array of shape ``(batch_size, dim_z)`` via ``truncnorm.rvs``
    with bounds ``-2`` and ``2`` and multiplies it by ``truncation``.
    ``dim_z`` is a module-level value defined elsewhere in this file.

    Args:
        batch_size: Number of rows to draw.
        truncation: Scalar multiplier applied to the drawn values.
        seed: Optional seed; when given, a dedicated
            ``np.random.RandomState(seed)`` is used so the draw is
            reproducible. When ``None``, no explicit state is passed.

    Returns:
        The scaled samples.
    """
    # NOTE(review): truncnorm is presumably scipy.stats.truncnorm - confirm
    # against the file's imports.
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
    return truncation * values


def one_hot(index, vocab_size=vocab_size):
    """Build one-hot float32 rows for a scalar index or a 1-D sequence.

    Args:
        index: A single integer index or a 1-D array-like of indices.
        vocab_size: Width of each one-hot row. The default is bound to the
            module-level ``vocab_size`` at function definition time.

    Returns:
        A ``(num, vocab_size)`` float32 array with a single 1 per row,
        where ``num`` is the number of indices.

    Raises:
        AssertionError: If ``index`` has more than one dimension.
    """
    index = np.asarray(index)
    if len(index.shape) == 0:
        # Promote a scalar to a length-1 vector so indexing below is uniform.
        index = np.asarray([index])
    assert len(index.shape) == 1
    num = index.shape[0]
    output = np.zeros((num, vocab_size), dtype=np.float32)
    output[np.arange(num), index] = 1
    return output


def imconvert_uint8(im):
    """Convert a float image to uint8.

    Computes ``(im + 1) / 2 * 256``, clips the result to ``[0, 255]`` and
    casts to ``np.uint8``.
    """
    # presumably `im` is in [-1, 1]; the clip keeps 1.0 from overflowing to 256.
    im = np.clip(((im + 1) / 2.0) * 256, 0, 255)
    im = np.uint8(im)
    return im


def imconvert_float32(im):
    """Convert an image to float32 using ``im / 256 * 2 - 1``."""
    im = np.float32(im)
    im = (im / 256) * 2.0 - 1
    return im


def imread(filename):
    """Read an image with OpenCV and reverse its channel axis.

    The file is loaded with ``cv2.IMREAD_UNCHANGED``. If the result has more
    than two dimensions, the last axis is reversed before returning.

    Returns:
        The loaded array, or ``None`` if ``cv2.imread`` returned ``None``.
    """
    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    if img is not None:
        if len(img.shape) > 2:
            # NOTE(review): this reverses the whole last axis, so a 4-channel
            # image has its alpha channel moved to the front as well -
            # confirm that is intended.
            img = img[..., ::-1]
    return img


def imwrite(filename, img):
    """Write an image with OpenCV, reversing its channel axis first.

    If ``img`` is not ``None`` and has more than two dimensions, the last
    axis is reversed (mirroring ``imread``) before calling ``cv2.imwrite``.

    Returns:
        Whatever ``cv2.imwrite`` returns.
    """
    if img is not None:
        if len(img.shape) > 2:
            img = img[..., ::-1]
    return cv2.imwrite(filename, img)


def imgrid(imarray, cols=5, pad=1):
    """Tile a batch of uint8 images into a single grid image.

    Args:
        imarray: uint8 array that unpacks as ``(N, H, W, C)``.
        cols: Number of images per grid row (converted with ``int``; >= 1).
        pad: Pixels of spacing between tiles (converted with ``int``; >= 0).

    Returns:
        An array of shape ``(rows * (H + pad) - pad, cols * (W + pad) - pad,
        C)`` where ``rows = ceil(N / cols)``. Spacing and any unused trailing
        cells are filled with 255.

    Raises:
        ValueError: If ``imarray`` is not uint8.
        AssertionError: If ``pad < 0`` or ``cols < 1``.
    """
    if imarray.dtype != np.uint8:
        raise ValueError('imgrid input imarray must be uint8')
    pad = int(pad)
    assert pad >= 0
    cols = int(cols)
    assert cols >= 1
    N, H, W, C = imarray.shape
    rows = int(np.ceil(N / float(cols)))
    # Number of blank tiles needed to complete the last row.
    batch_pad = rows * cols - N
    assert batch_pad >= 0
    # Pad only at the end of each axis: extra tiles on the batch axis and
    # `pad` pixels on the bottom/right of every image; channels untouched.
    post_pad = [batch_pad, pad, pad, 0]
    pad_arg = [[0, p] for p in post_pad]
    imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)
    H += pad
    W += pad
    grid = (imarray
            .reshape(rows, cols, H, W, C)
            .transpose(0, 2, 1, 3, 4)
            .reshape(rows * H, cols * W, C))
    if pad:
        # Drop the padding that trails the last row and last column.
        grid = grid[:-pad, :-pad]
    return grid


def imshow(a, format='png', jpeg_fallback=True):
    """Display an image array inline via IPython.

    The array is cast to uint8, encoded with PIL in ``format`` and handed to
    ``IPython.display.display``. If that raises ``IOError`` and
    ``jpeg_fallback`` is true (and ``format`` is not already ``'jpeg'``), a
    warning is printed and the image is retried as JPEG.

    Args:
        a: Array-like image data.
        format: Image format name passed to ``PIL.Image.save``.
        jpeg_fallback: Whether to retry as JPEG on ``IOError``.

    Returns:
        The value returned by ``IPython.display.display``.

    Raises:
        IOError: If display fails and no fallback applies.
    """
    # Local import keeps this block self-contained; io.BytesIO exists on both
    # Python 2.6+ and Python 3, unlike the Python 2-only cStringIO module.
    import io

    a = np.asarray(a, dtype=np.uint8)
    # Encoded image data is binary, so use a bytes buffer.
    buf = io.BytesIO()
    PIL.Image.fromarray(a).save(buf, format)
    im_data = buf.getvalue()
    try:
        disp = IPython.display.display(IPython.display.Image(im_data))
    except IOError:
        if jpeg_fallback and format != 'jpeg':
            # Format the message *before* printing: calling .format on the
            # result of print() fails on Python 3 because print returns None.
            print(('Warning: image was too large to display in format "{}"; '
                   'trying jpeg instead.').format(format))
            return imshow(a, format='jpeg')
        else:
            raise
    return disp