summaryrefslogtreecommitdiff
path: root/cli/app/search/vector.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-10 22:23:04 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-10 22:23:04 +0100
commit0458542e4d06ae7aaae23c15e04ef43f54ad4f8d (patch)
treebcf584264b950a62500a7f9bc3c4f5703848a8bd /cli/app/search/vector.py
parentc7ad87acbce1b307b49489eb11a6f5f8740a66e3 (diff)
refactor and add hdf5 support
Diffstat (limited to 'cli/app/search/vector.py')
-rw-r--r--cli/app/search/vector.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/cli/app/search/vector.py b/cli/app/search/vector.py
new file mode 100644
index 0000000..89cd949
--- /dev/null
+++ b/cli/app/search/vector.py
@@ -0,0 +1,20 @@
+import random
+import numpy as np
+from scipy.stats import truncnorm
+
+def truncated_z_sample(batch_size, z_dim, truncation):
+ values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim))
+ return truncation * values
+
+def truncated_z_single(z_dim, truncation):
+ values = truncnorm.rvs(-2, 2, size=(1, z_dim))
+ return truncation * values
+
+def create_labels(batch_size, vocab_size, num_classes):
+ label = np.zeros((batch_size, vocab_size))
+ for i in range(batch_size):
+ for _ in range(random.randint(1, num_classes)):
+ j = random.randint(0, vocab_size-1)
+ label[i, j] = random.random()
+ label[i] /= label[i].sum()
+ return label