diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-10 22:23:04 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-10 22:23:04 +0100 |
| commit | 0458542e4d06ae7aaae23c15e04ef43f54ad4f8d (patch) | |
| tree | bcf584264b950a62500a7f9bc3c4f5703848a8bd /cli/app/search/vector.py | |
| parent | c7ad87acbce1b307b49489eb11a6f5f8740a66e3 (diff) | |
refactor and add hdf5 support
Diffstat (limited to 'cli/app/search/vector.py')
| -rw-r--r-- | cli/app/search/vector.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/cli/app/search/vector.py b/cli/app/search/vector.py new file mode 100644 index 0000000..89cd949 --- /dev/null +++ b/cli/app/search/vector.py @@ -0,0 +1,20 @@ +import random +import numpy as np +from scipy.stats import truncnorm + +def truncated_z_sample(batch_size, z_dim, truncation): + values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim)) + return truncation * values + +def truncated_z_single(z_dim, truncation): + values = truncnorm.rvs(-2, 2, size=(1, z_dim)) + return truncation * values + +def create_labels(batch_size, vocab_size, num_classes): + label = np.zeros((batch_size, vocab_size)) + for i in range(batch_size): + for _ in range(random.randint(1, num_classes)): + j = random.randint(0, vocab_size-1) + label[i, j] = random.random() + label[i] /= label[i].sum() + return label |
