summaryrefslogtreecommitdiff
path: root/cli/app/search/vector.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/app/search/vector.py')
-rw-r--r--cli/app/search/vector.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/cli/app/search/vector.py b/cli/app/search/vector.py
index 89cd949..b118ef3 100644
--- a/cli/app/search/vector.py
+++ b/cli/app/search/vector.py
@@ -18,3 +18,11 @@ def create_labels(batch_size, vocab_size, num_classes):
label[i, j] = random.random()
label[i] /= label[i].sum()
return label
+
+def create_labels_uniform(batch_size, vocab_size):
+ label = np.zeros((batch_size, vocab_size))
+ for i in range(batch_size):
+ for j in range(vocab_size):
+ label[i, j] = random.uniform(1 / vocab_size, 2 / vocab_size)
+ label[i] /= label[i].sum()
+ return label