summary | refs | log | tree | commit | diff
path: root/megapixels/commands/faiss/build.py
diff options
context:
space:
mode:
Diffstat (limited to 'megapixels/commands/faiss/build.py')
-rw-r--r-- megapixels/commands/faiss/build.py | 46
1 files changed, 46 insertions, 0 deletions
diff --git a/megapixels/commands/faiss/build.py b/megapixels/commands/faiss/build.py
new file mode 100644
index 00000000..e95619af
--- /dev/null
+++ b/megapixels/commands/faiss/build.py
@@ -0,0 +1,46 @@
+"""
+Index all of the FAISS datasets
+"""
+
+import glob
+import os
+import time
+
+import click
+import faiss
+import numpy as np
+
+from app.settings import app_cfg as cfg
+from app.utils import file_utils
+from app.utils.file_utils import load_recipe, load_csv
+
+@click.command()
+@click.pass_context
+def cli(ctx):
+ """train the FAISS index"""
+
+ recipe = {
+ "dim": 128,
+ "factory_type": "Flat"
+ }
+
+ datasets = []
+ for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_DATASETS, "*")):
+ name = os.path.basename(fn)
+ recipe_fn = os.path.join(cfg.DIR_FAISS_RECIPES, name + ".json")
+ if os.path.exists(recipe_fn):
+ train(name, load_recipe(recipe_fn))
+ else:
+ train(name, recipe)
+
+def train(name, recipe):
+ vec_fn = os.path.join(cfg.DIR_FAISS_DATASETS, name, "vecs.csv")
+ index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index")
+
+ index = faiss.index_factory(recipe.dimension, recipe.factory)
+
+ keys, rows = file_utils.load_csv_safe(vec_fn)
+ feats = np.array([ float(x[1].split(",")) for x in rows]).astype('float32')
+ n, d = feats.shape
+
+ train_start = time.time()
+ index.train(feats)
+ train_end = time.time()
+ train_time = train_end - train_start
+ print("{} train time: {:.1f}s".format(name, train_time))
+
+ faiss.write_index(index, index_fn)