| field | value | date |
|---|---|---|
| author | Jules Laplace <julescarbon@gmail.com> | 2018-12-17 00:35:19 +0100 |
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-12-17 00:35:19 +0100 |
| commit | 4cf8581655c34698f8869bb364b6d436b881d17a (patch) | |
| tree | 449d6c5a8cd5a3f5bbd277e67f66a734ae0f51c8 /megapixels | |
| parent | 0bbaef7c889f2bf17cdf7e4584a6946085d0a7eb (diff) | |
returning results...!
Diffstat (limited to 'megapixels')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | megapixels/app/models/sql_factory.py | 64 |
| -rw-r--r-- | megapixels/app/processors/faiss.py | 58 |
| -rw-r--r-- | megapixels/app/server/api.py | 53 |
| -rw-r--r-- | megapixels/app/server/json_encoder.py | 17 |
| -rw-r--r-- | megapixels/commands/faiss/build_faiss.py | 36 |
5 files changed, 179 insertions, 49 deletions
```diff
diff --git a/megapixels/app/models/sql_factory.py b/megapixels/app/models/sql_factory.py
index e35c3e15..0f7e73a0 100644
--- a/megapixels/app/models/sql_factory.py
+++ b/megapixels/app/models/sql_factory.py
@@ -19,6 +19,7 @@ connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
 
 datasets = {}
 loaded = False
+Session = None
 
 def list_datasets():
   return [dataset.describe() for dataset in datasets.values()]
@@ -31,10 +32,11 @@ def get_table(name, table_name):
   return dataset.get_table(table_name) if dataset else None
 
 def load_sql_datasets(replace=False, base_model=None):
-  global datasets, loaded
+  global datasets, loaded, Session
   if loaded:
     return datasets
-  engine = create_engine(connection_url) if replace else None
+  engine = create_engine(connection_url)
+  Session = sessionmaker(bind=engine)
   for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
     dataset = load_sql_dataset(path, replace, engine, base_model)
     datasets[dataset.name] = dataset
@@ -79,6 +81,27 @@ class SqlDataset:
       'tables': list(self.tables.keys()),
     }
 
+  def get_identity(self, id):
+    table = self.get_table('identity_meta')
+    identity = table.query.filter(table.image_id >= id).order_by(table.image_id.asc()).first().toJSON()
+    print(identity)
+    return {
+      'uuid': self.select('uuids', id),
+      'identity': identity,
+      'roi': self.select('roi', id),
+      'pose': self.select('pose', id),
+    }
+
+  def select(self, table, id):
+    table = self.get_table(table)
+    if not table:
+      return None
+    session = Session()
+    # for obj in session.query(table).filter_by(id=id):
+    print(table)
+    obj = session.query(table).filter(table.id == id).first()
+    return obj.toJSON()
+
   def get_table(self, type):
     if type in self.tables:
       return self.tables[type]
@@ -102,6 +125,11 @@ class SqlDataset:
       __tablename__ = self.name + "_uuid"
       id = Column(Integer, primary_key=True)
       uuid = Column(String(36), nullable=False)
+      def toJSON(self):
+        return {
+          'id': self.id,
+          'uuid': self.uuid,
+        }
     return UUID
 
   # ==> roi.csv <==
@@ -118,6 +146,17 @@ class SqlDataset:
       w = Column(Float, nullable=False)
       x = Column(Float, nullable=False)
       y = Column(Float, nullable=False)
+      def toJSON(self):
+        return {
+          'id': self.id,
+          'image_index': self.image_index,
+          'image_height': self.image_height,
+          'image_width': self.image_width,
+          'w': self.w,
+          'h': self.h,
+          'x': self.x,
+          'y': self.y,
+        }
     return ROI
 
   # ==> identity.csv <==
@@ -132,6 +171,15 @@ class SqlDataset:
      gender = Column(String(1), nullable=False)
      images = Column(Integer, nullable=False)
      image_id = Column(Integer, nullable=False)
+      def toJSON(self):
+        return {
+          'id': self.id,
+          'image_id': self.image_id,
+          'fullname': self.fullname,
+          'images': self.images,
+          'gender': self.gender,
+          'description': self.description,
+        }
     return Identity
 
   # ==> pose.csv <==
@@ -145,8 +193,12 @@ class SqlDataset:
       pitch = Column(Float, nullable=False)
       roll = Column(Float, nullable=False)
       yaw = Column(Float, nullable=False)
+      def toJSON(self):
+        return {
+          'id': self.id,
+          'image_id': self.image_id,
+          'pitch': self.pitch,
+          'roll': self.roll,
+          'yaw': self.yaw,
+        }
     return Pose
-
-
-# Session = sessionmaker(bind=engine)
-# session = Session()
```
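A note on the new helpers: `get_identity` stitches together one row from each per-dataset table, and `select` pulls a single row through the module-level `Session` factory that `load_sql_datasets` now binds. A minimal usage sketch, assuming the datasets have been loaded and that a dataset named `faces` exists (the name is hypothetical):

```python
from app.models.sql_factory import load_sql_datasets, get_dataset

# populate the registry; this also binds the module-level Session factory
load_sql_datasets()

# 'faces' is a placeholder; any directory under cfg.DIR_FAISS_METADATA works
dataset = get_dataset('faces')

# returns {'uuid': ..., 'identity': ..., 'roi': ..., 'pose': ...},
# each value built from the matching table's toJSON()
print(dataset.get_identity(42))
```

One design note: `select` opens a fresh `Session` per lookup and never closes it; a context-managed session (supported as `with Session() as session:` in SQLAlchemy 1.4+) would bound connection use more predictably.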
```diff
diff --git a/megapixels/app/processors/faiss.py b/megapixels/app/processors/faiss.py
new file mode 100644
index 00000000..5156ad71
--- /dev/null
+++ b/megapixels/app/processors/faiss.py
@@ -0,0 +1,58 @@
+"""
+Index all of the FAISS datasets
+"""
+
+import os
+import glob
+import faiss
+import time
+import numpy as np
+
+from app.utils.file_utils import load_recipe, load_csv_safe
+from app.settings import app_cfg as cfg
+
+class DefaultRecipe:
+  def __init__(self):
+    self.dim = 128
+    self.factory_type = 'Flat'
+
+def build_all_faiss_databases():
+  datasets = []
+  for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+    name = os.path.basename(fn)
+    recipe_fn = os.path.join(cfg.DIR_FAISS_RECIPES, name + ".json")
+    if os.path.exists(recipe_fn):
+      build_faiss_database(name, load_recipe(recipe_fn))
+    else:
+      build_faiss_database(name, DefaultRecipe())
+
+def build_faiss_database(name, recipe):
+  vec_fn = os.path.join(cfg.DIR_FAISS_METADATA, name, "vecs.csv")
+  index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index")
+
+  index = faiss.index_factory(recipe.dim, recipe.factory_type)
+
+  keys, rows = load_csv_safe(vec_fn)
+  feats = np.array([ list(map(float, row[3].split(","))) for row in rows ]).astype('float32')
+  n, d = feats.shape
+
+  print("{}: training {} x {} dim vectors".format(name, n, d))
+  print(recipe.factory_type)
+
+  add_start = time.time()
+  index.add(feats)
+  add_end = time.time()
+  add_time = add_end - add_start
+  print("{}: add time: {:.1f}s".format(name, add_time))
+
+  faiss.write_index(index, index_fn)
+
+def load_faiss_databases():
+  faiss_datasets = {}
+  for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+    name = os.path.basename(fn)
+    index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index")
+    if os.path.exists(index_fn):
+      index = faiss.read_index(index_fn)
+      faiss_datasets[name] = index
+  return faiss_datasets
```
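For reference, a sketch of the query side this module sets up: load an index written by `build_faiss_database` and run a k-nearest-neighbour search (the index filename is a placeholder, not a path from this repo's config):

```python
import faiss
import numpy as np

# read an index previously written by faiss.write_index()
index = faiss.read_index('example.index')  # placeholder filename

# one 128-dim query vector, matching DefaultRecipe.dim
query = np.random.rand(1, 128).astype('float32')

# search returns (distances, indexes), each shaped (n_queries, k)
distances, indexes = index.search(query, 5)
print(distances[0], indexes[0])
```

Since the default factory string is `'Flat'`, no training step is needed before `index.add()`; an IVF-style recipe would additionally require `index.train(feats)` first.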
```diff
diff --git a/megapixels/app/server/api.py b/megapixels/app/server/api.py
index cf8241bd..36563910 100644
--- a/megapixels/app/server/api.py
+++ b/megapixels/app/server/api.py
@@ -2,18 +2,23 @@
 import os
 import re
 import time
 import dlib
+import numpy as np
 from flask import Blueprint, request, jsonify
 from PIL import Image # todo: try to remove PIL dependency
 from app.processors import face_recognition
 from app.processors import face_detector
-from app.models.sql_factory import list_datasets, get_dataset, get_table
+from app.processors.faiss import load_faiss_databases
+from app.models.sql_factory import load_sql_datasets, list_datasets, get_dataset, get_table
+from app.utils.im_utils import pil2np
 
 sanitize_re = re.compile('[\W]+')
 valid_exts = ['.gif', '.jpg', '.jpeg', '.png']
 
 api = Blueprint('api', __name__)
 
+faiss_datasets = load_faiss_databases()
+
 @api.route('/')
 def index():
   return jsonify({ 'datasets': list_datasets() })
@@ -26,10 +31,15 @@ def show(name):
   else:
     return jsonify({ 'status': 404 })
 
-@api.route('/dataset/<name>/face', methods=['POST'])
+@api.route('/dataset/<name>/face/', methods=['POST'])
 def upload(name):
   start = time.time()
   dataset = get_dataset(name)
+  if name not in faiss_datasets:
+    return jsonify({
+      'error': 'invalid dataset'
+    })
+  faiss_dataset = faiss_datasets[name]
   file = request.files['query_img']
   fn = file.filename
   if fn.endswith('blob'):
@@ -40,22 +50,46 @@ def upload(name):
   if ext.lower() not in valid_exts:
     return jsonify({ 'error': 'not an image' })
 
-  img = Image.open(file.stream).convert('RGB')
+  im = Image.open(file.stream).convert('RGB')
+  im_np = pil2np(im)
 
   # Face detection
   detector = face_detector.DetectorDLIBHOG()
 
   # get detection as BBox object
-  bboxes = detector.detect(im, largest=True)
+  bboxes = detector.detect(im_np, largest=True)
   bbox = bboxes[0]
-  dim = im.shape[:2][::-1]
+  dim = im_np.shape[:2][::-1]
   bbox = bbox.to_dim(dim) # convert back to real dimensions
 
   # face recognition/vector
   recognition = face_recognition.RecognitionDLIB(gpu=-1)
+  vec = recognition.vec(im_np, bbox)
+
+  # print(vec)
+  query = np.array([ vec ]).astype('float32')
+
+  # query FAISS!
+  distances, indexes = faiss_dataset.search(query, 5)
+
+  if len(indexes) == 0:
+    print("weird, no results!")
+    return []
+
+  # get the results for this single query...
+  distances = distances[0]
+  indexes = indexes[0]
 
-  # print(vec.shape)
-  # results = db.search(vec, limit=limit)
+  if len(indexes) == 0:
+    print("no results!")
+    return []
+
+  lookup = {}
+  for _d, _i in zip(distances, indexes):
+    lookup[_i+1] = _d
+
+  print(distances)
+  print(indexes)
 
   # with the result we have an ID
   # query the sql dataset for the UUID etc here
@@ -63,12 +97,13 @@ def upload(name):
   query = {
     'timing': time.time() - start,
   }
-  results = []
+  results = [ dataset.get_identity(index) for index in indexes ]
   print(results)
 
   return jsonify({
-    'query': query,
     'results': results,
+    # 'distances': distances.tolist(),
+    # 'indexes': indexes.tolist(),
   })
 
 @api.route('/dataset/<name>/name', methods=['GET'])
diff --git a/megapixels/app/server/json_encoder.py b/megapixels/app/server/json_encoder.py
new file mode 100644
index 00000000..89af578a
--- /dev/null
+++ b/megapixels/app/server/json_encoder.py
@@ -0,0 +1,17 @@
+from sqlalchemy.ext.declarative import DeclarativeMeta
+from flask import json
+
+class AlchemyEncoder(json.JSONEncoder):
+  def default(self, o):
+    if isinstance(o.__class__, DeclarativeMeta):
+      data = {}
+      fields = o.__json__() if hasattr(o, '__json__') else dir(o)
+      for field in [f for f in fields if not f.startswith('_') and f not in ['metadata', 'query', 'query_class']]:
+        value = o.__getattribute__(field)
+        try:
+          json.dumps(value)
+          data[field] = value
+        except TypeError:
+          data[field] = None
+      return data
+    return json.JSONEncoder.default(self, o)
diff --git a/megapixels/commands/faiss/build_faiss.py b/megapixels/commands/faiss/build_faiss.py
index ec94c924..fc6b37ce 100644
--- a/megapixels/commands/faiss/build_faiss.py
+++ b/megapixels/commands/faiss/build_faiss.py
@@ -11,11 +11,7 @@ import numpy as np
 
 from app.utils.file_utils import load_recipe, load_csv_safe
 from app.settings import app_cfg as cfg
-
-class DefaultRecipe:
-  def __init__(self):
-    self.dim = 128
-    self.factory_type = 'Flat'
+from app.processors.faiss import build_all_faiss_databases
 
 @click.command()
 @click.pass_context
@@ -25,32 +21,4 @@ def cli(ctx):
   - uses the recipe above by default
   - however you can override this by adding a new recipe in faiss/recipes/{name}.json
   """
-  datasets = []
-  for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
-    name = os.path.basename(fn)
-    recipe_fn = os.path.join(cfg.DIR_FAISS_RECIPES, name + ".json")
-    if os.path.exists(recipe_fn):
-      build_faiss(name, load_recipe(recipe_fn))
-    else:
-      build_faiss(name, DefaultRecipe())
-
-def build_faiss(name, recipe):
-  vec_fn = os.path.join(cfg.DIR_FAISS_METADATA, name, "vecs.csv")
-  index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index")
-
-  index = faiss.index_factory(recipe.dim, recipe.factory_type)
-
-  keys, rows = load_csv_safe(vec_fn)
-  feats = np.array([ list(map(float, row[3].split(","))) for row in rows ]).astype('float32')
-  n, d = feats.shape
-
-  print("{}: training {} x {} dim vectors".format(name, n, d))
-  print(recipe.factory_type)
-
-  add_start = time.time()
-  index.add(feats)
-  add_end = time.time()
-  add_time = add_end - add_start
-  print("{}: add time: {:.1f}s".format(name, add_time))
-
-  faiss.write_index(index, index_fn)
+  build_all_faiss_databases()
```
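The new `AlchemyEncoder` is not wired up anywhere in this commit. Presumably it is meant to be registered on the Flask app so `jsonify` can serialize SQLAlchemy rows without the hand-written `toJSON` methods; a sketch, assuming Flask's classic `json_encoder` hook (removed in Flask 2.3):

```python
from flask import Flask
from app.server.json_encoder import AlchemyEncoder

app = Flask(__name__)

# route jsonify() through the custom encoder so SQLAlchemy model
# instances are serialized field by field instead of raising TypeError
app.json_encoder = AlchemyEncoder
```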

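Finally, a hypothetical client call against the updated endpoint (host, port, and dataset name are assumptions; the `query_img` field name comes from the handler):

```python
import requests

# POST an image to the face search endpoint added in api.py
url = 'http://localhost:5000/dataset/example/face/'
with open('face.jpg', 'rb') as f:
  resp = requests.post(url, files={'query_img': f})

print(resp.json())  # {'results': [...]}
```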