""" Index all of the FAISS datasets """ import os import click from app.utils.file_utils import load_recipe, load_csv from app.settings import app_cfg as cfg @click.command() @click.pass_context def cli(ctx): """train the FAISS index""" recipe = { "dim": 128, "factory_type": "Flat" } datasets = [] for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_DATASETS, "*")): name = os.path.basename(fn) recipe_fn = os.path.join(cfg.DIR_FAISS_RECIPES, name + ".json") if os.path.exists(recipe_fn): train(name, load_recipe(recipe_fn)) else: train(name, recipe) def train(name, recipe): vec_fn = os.path.join(cfg.DIR_FAISS_DATASETS, name, "vecs.csv") index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index") index = faiss.index_factory(recipe.dimension, recipe.factory) keys, rows = file_utils.load_csv_safe(vec_fn) feats = np.array([ float(x[1].split(",")) for x in rows]).astype('float32') n, d = feats.shape train_start = time.time() index.train(feats) train_end = time.time() train_time = train_end - train_start print("{} train time: {:.1f}s".format(name, train_time)) faiss.write_index(index, index_fn)