import random

import numpy as np
from scipy.stats import truncnorm


def truncated_z_sample(batch_size, z_dim, truncation):
    """Draw a batch of truncated-normal vectors scaled by ``truncation``.

    Values are sampled from a standard normal truncated to ``[-2, 2]`` and
    then multiplied by ``truncation``, so every entry lies in
    ``[-2 * truncation, 2 * truncation]``.

    Args:
        batch_size: Number of rows to sample.
        z_dim: Number of columns (entries per row).
        truncation: Multiplicative scale applied to the sampled values.

    Returns:
        Array of shape ``(batch_size, z_dim)``.
    """
    values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim))
    return truncation * values


def truncated_z_single(z_dim, truncation):
    """Draw a single truncated-normal vector scaled by ``truncation``.

    Equivalent to ``truncated_z_sample(1, z_dim, truncation)``.

    Args:
        z_dim: Number of entries in the vector.
        truncation: Multiplicative scale applied to the sampled values.

    Returns:
        Array of shape ``(1, z_dim)``.
    """
    # Delegate so the truncation bounds are defined in exactly one place.
    return truncated_z_sample(1, z_dim, truncation)


def create_labels(batch_size, vocab_size, num_classes):
    """Build random sparse label rows that each sum to 1.

    For every row, between 1 and ``num_classes`` draws are made. Each draw
    picks a random column and assigns it a random weight in ``[0, 1)``; a
    column picked more than once keeps only its latest weight, so a row has
    at most ``num_classes`` non-zero entries. The row is then normalised.

    Args:
        batch_size: Number of rows.
        vocab_size: Number of columns; must be at least 1 when
            ``batch_size > 0``.
        num_classes: Upper bound on the number of draws per row; must be at
            least 1 when ``batch_size > 0``.

    Returns:
        Array of shape ``(batch_size, vocab_size)``.

    Raises:
        ValueError: Propagated from ``random.randint`` when ``num_classes``
            or ``vocab_size`` is smaller than 1 and ``batch_size > 0``.
    """
    label = np.zeros((batch_size, vocab_size))
    for i in range(batch_size):
        # randint(1, ...) guarantees at least one draw, so ``j`` is always
        # bound after this loop.
        for _ in range(random.randint(1, num_classes)):
            j = random.randint(0, vocab_size - 1)
            label[i, j] = random.random()
        total = label[i].sum()
        if total > 0:
            label[i] /= total
        else:
            # random.random() may return exactly 0.0; dividing by a zero sum
            # would fill the row with NaN. Put all mass on the last sampled
            # column instead so the row is still a valid distribution.
            label[i, j] = 1.0
    return label


def create_labels_uniform(batch_size, vocab_size):
    """Build dense label rows of near-uniform weights that each sum to 1.

    Every entry is drawn uniformly between ``1 / vocab_size`` and
    ``2 / vocab_size`` and each row is then normalised by its sum.

    Args:
        batch_size: Number of rows.
        vocab_size: Number of columns.

    Returns:
        Array of shape ``(batch_size, vocab_size)``.
    """
    label = np.zeros((batch_size, vocab_size))
    if vocab_size == 0:
        # Nothing to fill; also avoids dividing by zero for the bounds below.
        return label
    low = 1 / vocab_size
    high = 2 / vocab_size
    for i in range(batch_size):
        # Row-major fill keeps the order of RNG draws stable under a seed.
        label[i] = [random.uniform(low, high) for _ in range(vocab_size)]
        label[i] /= label[i].sum()
    return label