1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
import random
import numpy as np
from scipy.stats import truncnorm
def truncated_z_sample(batch_size, z_dim, truncation, seed=None):
    """Draw a batch of truncated-normal latent vectors scaled by ``truncation``.

    Samples come from a standard normal truncated to [-2, 2] (scipy
    ``truncnorm`` with standardized bounds a=-2, b=2), and are then multiplied
    by ``truncation``. For a scalar ``truncation`` every entry therefore lies
    in [-2 * |truncation|, 2 * |truncation|].

    Args:
        batch_size: Number of latent vectors (rows of the result).
        z_dim: Dimensionality of each latent vector (columns of the result).
        truncation: Multiplier applied to the raw samples.
        seed: Optional seed. When given, sampling uses a private
            ``np.random.RandomState(seed)``. Results are then reproducible and
            the global NumPy RNG is not advanced. When ``None`` (the default),
            sampling uses scipy's default (global NumPy) RNG, exactly as before.

    Returns:
        np.ndarray of shape (batch_size, z_dim).
    """
    # random_state=None falls back to scipy's default global RNG, so omitting
    # ``seed`` preserves the previous behaviour.
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
    return truncation * values
def truncated_z_single(z_dim, truncation, seed=None):
    """Draw one truncated-normal latent vector scaled by ``truncation``.

    Same sampling as ``truncated_z_sample`` with a batch size of 1: values come
    from a standard normal truncated to [-2, 2] and are multiplied by
    ``truncation``.

    Args:
        z_dim: Dimensionality of the latent vector.
        truncation: Multiplier applied to the raw samples.
        seed: Optional seed. When given, sampling uses a private
            ``np.random.RandomState(seed)`` for reproducible output. When
            ``None`` (the default), scipy's default (global NumPy) RNG is used,
            as before.

    Returns:
        np.ndarray of shape (1, z_dim). The leading axis is kept so the result
        has the same 2-D layout as a batch.
    """
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(1, z_dim), random_state=state)
    return truncation * values
def create_labels(batch_size, vocab_size, num_classes):
    """Build a batch of random sparse label distributions.

    For each row:

    1. Draw a class count ``k`` uniformly from
       ``1..min(num_classes, vocab_size)``.
    2. Choose ``k`` distinct class indices from ``range(vocab_size)``.
    3. Give each chosen class a random weight in (0, 1].
    4. Normalise the row to sum to 1.

    Args:
        batch_size: Number of rows.
        vocab_size: Number of classes (columns).
        num_classes: Upper bound on the number of active classes per row.

    Returns:
        np.ndarray of shape (batch_size, vocab_size). Each row sums to 1 and has
        exactly ``k`` non-zero entries.

    Raises:
        ValueError: From ``random.randint`` when ``batch_size > 0`` and
            ``vocab_size`` or ``num_classes`` is less than 1.
    """
    label = np.zeros((batch_size, vocab_size))
    for i in range(batch_size):
        # Clamp so we never ask random.sample for more classes than exist.
        k = random.randint(1, min(num_classes, vocab_size))
        # Sample without replacement. Drawing indices with replacement lets
        # collisions overwrite earlier weights and silently yield fewer than
        # k active classes.
        classes = random.sample(range(vocab_size), k)
        # 1.0 - random.random() lies in (0, 1]. Every chosen class gets a
        # strictly positive weight, so the row sum is never 0 (no 0/0 -> NaN).
        label[i, classes] = [1.0 - random.random() for _ in classes]
        label[i] /= label[i].sum()
    return label
def create_labels_uniform(batch_size, vocab_size):
    """Return ``batch_size`` near-uniform label distributions over ``vocab_size`` classes.

    Every entry is first drawn with
    ``random.uniform(1 / vocab_size, 2 / vocab_size)``. Each row is then
    divided by its own sum so it sums to 1. Draws happen in row-major order.

    Returns:
        np.ndarray of shape (batch_size, vocab_size).
    """
    label = np.zeros((batch_size, vocab_size))
    # Iterating a 2-D ndarray yields row views, so the slice assignment and
    # the in-place division below write straight into ``label``.
    for row in label:
        row[:] = [
            random.uniform(1 / vocab_size, 2 / vocab_size)
            for _ in range(vocab_size)
        ]
        row /= row.sum()
    return label
|