diff options
Diffstat (limited to 'megapixels/commands/faiss/build.py')
| -rw-r--r-- | megapixels/commands/faiss/build.py | 46 |
1 files changed, 46 insertions, 0 deletions
"""
Index all of the FAISS datasets
"""

import glob
import os
import time
from collections.abc import Mapping

import click
import faiss
import numpy as np

from app.utils import file_utils
from app.utils.file_utils import load_recipe, load_csv
from app.settings import app_cfg as cfg

# Recipe used for any dataset that has no "<name>.json" recipe file.
DEFAULT_RECIPE = {
    "dim": 128,
    "factory_type": "Flat",
}


@click.command()
@click.pass_context
def cli(ctx):
    """train the FAISS index"""
    for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_DATASETS, "*")):
        # train() reads "<dataset dir>/vecs.csv", so only directories can be
        # datasets; skip stray files sitting next to them.
        if not os.path.isdir(fn):
            continue
        name = os.path.basename(fn)
        recipe_fn = os.path.join(cfg.DIR_FAISS_RECIPES, name + ".json")
        if os.path.exists(recipe_fn):
            train(name, load_recipe(recipe_fn))
        else:
            train(name, DEFAULT_RECIPE)


def _recipe_field(recipe, *names):
    """Return the first of *names* found on *recipe*.

    The built-in default recipe is a plain dict keyed by "dim" and
    "factory_type", while the original code read ``recipe.dimension`` and
    ``recipe.factory`` as attributes.  Both access styles and both spellings
    are accepted so either kind of recipe works.

    Raises:
        KeyError: if none of *names* is present on *recipe*.
    """
    for field in names:
        if isinstance(recipe, Mapping):
            if field in recipe:
                return recipe[field]
        elif hasattr(recipe, field):
            return getattr(recipe, field)
    raise KeyError("recipe has none of the fields: {}".format(", ".join(names)))


def train(name, recipe):
    """Train a FAISS index for dataset *name* and write it to disk.

    Reads ``<DIR_FAISS_DATASETS>/<name>/vecs.csv``, builds an index with
    ``faiss.index_factory`` from the recipe's dimension and factory string,
    trains it on the loaded vectors, prints the training time and writes the
    index to ``<DIR_FAISS_INDEXES>/<name>.index``.

    Args:
        name: dataset directory name; also used as the index file stem.
        recipe: mapping or attribute object providing the vector dimension
            ("dim" or "dimension") and the factory string ("factory_type"
            or "factory").

    Raises:
        KeyError: if the recipe lacks a dimension or factory field.
        ValueError: if the loaded vectors are not a 2-D array whose width
            matches the recipe dimension.
    """
    vec_fn = os.path.join(cfg.DIR_FAISS_DATASETS, name, "vecs.csv")
    index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index")

    dim = _recipe_field(recipe, "dim", "dimension")
    factory = _recipe_field(recipe, "factory_type", "factory")
    index = faiss.index_factory(dim, factory)

    keys, rows = file_utils.load_csv_safe(vec_fn)
    # Assumes column 1 of each row holds the vector as a comma-separated
    # string of floats (as the original indexing implied) -- TODO confirm
    # against load_csv_safe.  Each component is converted individually;
    # calling float() on the split list raised TypeError.
    feats = np.array(
        [[float(v) for v in row[1].split(",")] for row in rows],
        dtype="float32",
    )
    if feats.ndim != 2:
        raise ValueError(
            "{}: expected a non-empty 2-D array of vectors from {}, got shape {}".format(
                name, vec_fn, feats.shape
            )
        )
    n, d = feats.shape
    if d != dim:
        raise ValueError(
            "{}: vectors have dimension {} but the recipe specifies {}".format(
                name, d, dim
            )
        )

    train_start = time.time()
    index.train(feats)
    train_end = time.time()
    train_time = train_end - train_start
    print("{} train time: {:.1f}s".format(name, train_time))

    # NOTE(review): the index is trained but no vectors are added before it
    # is written.  If a populated index is expected here, index.add(feats)
    # is missing -- confirm whether another command adds the vectors.
    faiss.write_index(index, index_fn)
