diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-12-14 15:50:07 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-12-14 15:50:07 +0100 |
| commit | e1fba831b7c22f9840c5e92227f688079b9a206e (patch) | |
| tree | 834e4160b2a9ae54800c55e9d0e42d20f83513f5 /megapixels/commands/faiss/build.py | |
| parent | 9e7713e83a99d8ca50ffff49def7085bb8f4e09c (diff) | |
mysql import script
Diffstat (limited to 'megapixels/commands/faiss/build.py')
| -rw-r--r-- | megapixels/commands/faiss/build.py | 58 |
1 files changed, 37 insertions, 21 deletions
diff --git a/megapixels/commands/faiss/build.py b/megapixels/commands/faiss/build.py index e95619af..e525542a 100644 --- a/megapixels/commands/faiss/build.py +++ b/megapixels/commands/faiss/build.py @@ -3,44 +3,60 @@ Index all of the FAISS datasets """ import os +import glob import click +import faiss +import time +import numpy as np -from app.utils.file_utils import load_recipe, load_csv +from app.utils.file_utils import load_recipe, load_csv_safe from app.settings import app_cfg as cfg +engine = create_engine('sqlite:///:memory:') + +class DefaultRecipe: + def __init__(self): + self.dim = 128 + self.factory_type = 'Flat' + @click.command() @click.pass_context def cli(ctx): - """train the FAISS index""" - - recipe = { - "dim": 128, - "factory_type": "Flat" - } - + """build the FAISS index. + - looks for all datasets in faiss/metadata/ + - uses the recipe above by default + - however you can override this by adding a new recipe in faiss/recipes/{name}.json + """ datasets = [] - for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_DATASETS, "*")): + for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")): name = os.path.basename(fn) recipe_fn = os.path.join(cfg.DIR_FAISS_RECIPES, name + ".json") if os.path.exists(recipe_fn): - train(name, load_recipe(recipe_fn)) + build_faiss(name, load_recipe(recipe_fn)) else: - train(name, recipe) + build_faiss(name, DefaultRecipe()) + # index identities + # certain CSV files should be loaded into mysql + # User.__table__.drop() + SQLemployees.create(engine) -def train(name, recipe): - vec_fn = os.path.join(cfg.DIR_FAISS_DATASETS, name, "vecs.csv") +def build_faiss(name, recipe): + vec_fn = os.path.join(cfg.DIR_FAISS_METADATA, name, "vecs.csv") index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index") - index = faiss.index_factory(recipe.dimension, recipe.factory) + index = faiss.index_factory(recipe.dim, recipe.factory_type) - keys, rows = file_utils.load_csv_safe(vec_fn) - feats = np.array([ float(x[1].split(",")) for x in rows]).astype('float32') + keys, rows = load_csv_safe(vec_fn) + feats = np.array([ list(map(float, row[3].split(","))) for row in rows ]).astype('float32') n, d = feats.shape - train_start = time.time() - index.train(feats) - train_end = time.time() - train_time = train_end - train_start - print("{} train time: {:.1f}s".format(name, train_time)) + print("{}: training {} x {} dim vectors".format(name, n, d)) + print(recipe.factory_type) + + add_start = time.time() + index.add(feats) + add_end = time.time() + add_time = add_end - add_start + print("{}: add time: {:.1f}s".format(name, add_time)) faiss.write_index(index, index_fn) |
