1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
def truncated_z_sample(batch_size, truncation=1., seed=None):
    """Draw a batch of truncated-normal latent vectors, scaled by `truncation`.

    Samples come from `truncnorm.rvs` with bounds (-2, 2). `truncnorm` is
    presumably `scipy.stats.truncnorm`, whose bounds are in standard-deviation
    units; confirm against the file's imports.

    Args:
        batch_size: number of latent vectors (rows) to draw.
        truncation: multiplier applied to every sampled value.
        seed: optional seed for a dedicated `np.random.RandomState`. When it
            is None, `random_state=None` is passed and the sampler's default
            RNG is used.

    Returns:
        Array of shape (batch_size, dim_z), where `dim_z` is a module-level
        global defined elsewhere in this file.
    """
    if seed is None:
        rng = None
    else:
        rng = np.random.RandomState(seed)
    z = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=rng)
    return truncation * z
def one_hot(index, vocab_size=vocab_size):
    """Encode integer class indices as a float32 one-hot matrix.

    Args:
        index: a scalar index or a 1-D sequence/array of indices.
        vocab_size: number of columns (classes). The default is the
            module-level `vocab_size`, captured when the function is defined.

    Returns:
        float32 array of shape (len(index), vocab_size). Row i is all zeros
        except for a 1 in column index[i]. A scalar index yields one row.
    """
    labels = np.asarray(index)
    if labels.ndim == 0:
        # Promote a single scalar index to a batch of one.
        labels = np.asarray([labels])
    assert labels.ndim == 1
    count = labels.shape[0]
    encoded = np.zeros((count, vocab_size), dtype=np.float32)
    encoded[np.arange(count), labels] = 1
    return encoded
def imconvert_uint8(im):
    """Map pixel values from the [-1, 1] range to uint8.

    The input is scaled to [0, 256] and then clipped to [0, 255], so the top
    value 1.0 lands on 255. Out-of-range inputs saturate at 0 or 255. The
    uint8 cast truncates any fractional part.
    """
    scaled = (im + 1) / 2.0 * 256
    return np.uint8(np.clip(scaled, 0, 255))
def imconvert_float32(im):
    """Map pixel values in [0, 256) to float32 in [-1, 1).

    This is the inverse scaling of `imconvert_uint8`: v -> v / 256 * 2 - 1.
    """
    pixels = np.float32(im)
    return (pixels / 256) * 2.0 - 1
def imread(filename):
    """Read an image file with OpenCV and return it in RGB(A) channel order.

    The file is read with `cv2.IMREAD_UNCHANGED`, so an alpha channel is kept
    if present. OpenCV stores colour channels as BGR or BGRA. Only the three
    colour channels are reversed and any trailing channel (alpha) stays last.
    A 4-channel file therefore comes back as RGBA, not ARGB.
    Single-channel (2-D) images are returned unchanged.

    Args:
        filename: path of the image file to read.

    Returns:
        The image array, or None when `cv2.imread` returned None (the file
        could not be read).
    """
    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    if img is not None and img.ndim > 2:
        # Reverse the colour channels (BGR -> RGB) and keep extra channels
        # such as alpha in place. A full [..., ::-1] would move alpha first.
        img = np.concatenate([img[..., 2::-1], img[..., 3:]], axis=-1)
    return img
def imwrite(filename, img):
    """Write an RGB(A) image array to disk with OpenCV.

    OpenCV expects BGR or BGRA channel order. Only the three colour channels
    are reversed and any trailing channel (alpha) stays last, so RGBA input
    is written as BGRA rather than ABGR. 2-D (single-channel) arrays and None
    are passed to `cv2.imwrite` unchanged.

    Args:
        filename: destination path. Its extension selects the encoder.
        img: image array in RGB(A) channel order.

    Returns:
        Whatever `cv2.imwrite` returns.
    """
    if img is not None and img.ndim > 2:
        # RGB(A) -> BGR(A): swap only the colour channels, keep alpha last.
        img = np.concatenate([img[..., 2::-1], img[..., 3:]], axis=-1)
    return cv2.imwrite(filename, img)
def imgrid(imarray, cols=5, pad=1):
    """Tile a batch of images into one grid image.

    Args:
        imarray: uint8 array of shape (N, H, W, C).
        cols: number of tiles per grid row. Must be >= 1.
        pad: width in pixels of the white (255) separator between tiles.
            Must be >= 0.

    Returns:
        uint8 array of shape (rows*(H+pad) - pad, cols*(W+pad) - pad, C),
        where rows = ceil(N / cols). Unused cells in the last row are white.

    Raises:
        ValueError: if `imarray` is not uint8, `pad` < 0, or `cols` < 1.
    """
    if imarray.dtype != np.uint8:
        raise ValueError('imgrid input imarray must be uint8')
    pad = int(pad)
    if pad < 0:
        raise ValueError('imgrid pad must be >= 0, got {}'.format(pad))
    cols = int(cols)
    if cols < 1:
        raise ValueError('imgrid cols must be >= 1, got {}'.format(cols))
    n, h, w, c = imarray.shape
    rows = -(-n // cols)  # integer ceil(n / cols)
    batch_pad = rows * cols - n  # blank cells needed to fill the last row
    # Pad at the end of each axis: blank images on the batch axis, plus a
    # bottom/right white border on every tile.
    padded = np.pad(imarray, [[0, batch_pad], [0, pad], [0, pad], [0, 0]],
                    'constant', constant_values=255)
    h += pad
    w += pad
    grid = (padded
            .reshape(rows, cols, h, w, c)
            .transpose(0, 2, 1, 3, 4)
            .reshape(rows * h, cols * w, c))
    if pad:
        # Drop the border that trails the last tile row and column.
        grid = grid[:-pad, :-pad]
    return grid
def imshow(a, format='png', jpeg_fallback=True):
    """Encode an image array with PIL and display it inline via IPython.

    Args:
        a: image array-like. It is converted with `np.asarray(..., uint8)`.
        format: PIL format name used for encoding, e.g. 'png' or 'jpeg'.
        jpeg_fallback: if True and displaying raises IOError, print a warning
            and retry once with format='jpeg'.

    Returns:
        The value returned by `IPython.display.display`.

    Raises:
        IOError: re-raised when displaying fails and no fallback applies.
    """
    # Encoded images are bytes. io.BytesIO works on Python 2 and 3, while
    # cStringIO exists only on Python 2.
    import io

    a = np.asarray(a, dtype=np.uint8)
    with io.BytesIO() as buf:
        PIL.Image.fromarray(a).save(buf, format)
        im_data = buf.getvalue()
    try:
        disp = IPython.display.display(IPython.display.Image(im_data))
    except IOError:
        if jpeg_fallback and format != 'jpeg':
            # Format the message before printing. Calling .format() on the
            # result of print(...) fails on Python 3, where print returns None.
            print('Warning: image was too large to display in format "{}"; '
                  'trying jpeg instead.'.format(format))
            return imshow(a, format='jpeg')
        else:
            raise
    return disp
|