Diffstat (limited to 'megapixels/app')
-rw-r--r--  megapixels/app/models/sql_factory.py    64
-rw-r--r--  megapixels/app/processors/faiss.py      58
-rw-r--r--  megapixels/app/server/api.py             53
-rw-r--r--  megapixels/app/server/json_encoder.py    17
4 files changed, 177 insertions(+), 15 deletions(-)
diff --git a/megapixels/app/models/sql_factory.py b/megapixels/app/models/sql_factory.py
index e35c3e15..0f7e73a0 100644
--- a/megapixels/app/models/sql_factory.py
+++ b/megapixels/app/models/sql_factory.py
@@ -19,6 +19,7 @@ connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
datasets = {}
loaded = False
+Session = None
def list_datasets():
return [dataset.describe() for dataset in datasets.values()]
@@ -31,10 +32,11 @@ def get_table(name, table_name):
return dataset.get_table(table_name) if dataset else None
def load_sql_datasets(replace=False, base_model=None):
- global datasets, loaded
+ global datasets, loaded, Session
if loaded:
return datasets
- engine = create_engine(connection_url) if replace else None
+ engine = create_engine(connection_url)
+ Session = sessionmaker(bind=engine)
for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
dataset = load_sql_dataset(path, replace, engine, base_model)
datasets[dataset.name] = dataset
@@ -79,6 +81,27 @@ class SqlDataset:
'tables': list(self.tables.keys()),
}
+  def get_identity(self, id):
+    # take the first identity row whose image_id is at or above the given row id
+    table = self.get_table('identity_meta')
+    session = Session()
+    row = session.query(table).filter(table.image_id >= id).order_by(table.image_id.asc()).first()
+    return {
+      'uuid': self.select('uuids', id),
+      'identity': row.toJSON() if row else None,
+      'roi': self.select('roi', id),
+      'pose': self.select('pose', id),
+    }
+
+  def select(self, table, id):
+    # look up a single row by primary key and return it as a plain dict
+    table = self.get_table(table)
+    if not table:
+      return None
+    session = Session()
+    obj = session.query(table).filter(table.id == id).first()
+    return obj.toJSON() if obj else None
+
def get_table(self, type):
if type in self.tables:
return self.tables[type]
@@ -102,6 +125,11 @@ class SqlDataset:
__tablename__ = self.name + "_uuid"
id = Column(Integer, primary_key=True)
uuid = Column(String(36), nullable=False)
+ def toJSON(self):
+ return {
+ 'id': self.id,
+ 'uuid': self.uuid,
+ }
return UUID
# ==> roi.csv <==
@@ -118,6 +146,17 @@ class SqlDataset:
w = Column(Float, nullable=False)
x = Column(Float, nullable=False)
y = Column(Float, nullable=False)
+ def toJSON(self):
+ return {
+ 'id': self.id,
+ 'image_index': self.image_index,
+ 'image_height': self.image_height,
+ 'image_width': self.image_width,
+ 'w': self.w,
+ 'h': self.h,
+ 'x': self.x,
+ 'y': self.y,
+ }
return ROI
# ==> identity.csv <==
@@ -132,6 +171,15 @@ class SqlDataset:
gender = Column(String(1), nullable=False)
images = Column(Integer, nullable=False)
image_id = Column(Integer, nullable=False)
+ def toJSON(self):
+ return {
+ 'id': self.id,
+ 'image_id': self.image_id,
+ 'fullname': self.fullname,
+ 'images': self.images,
+ 'gender': self.gender,
+ 'description': self.description,
+ }
return Identity
# ==> pose.csv <==
@@ -145,8 +193,12 @@ class SqlDataset:
pitch = Column(Float, nullable=False)
roll = Column(Float, nullable=False)
yaw = Column(Float, nullable=False)
+ def toJSON(self):
+ return {
+ 'id': self.id,
+ 'image_id': self.image_id,
+ 'pitch': self.pitch,
+ 'roll': self.roll,
+ 'yaw': self.yaw,
+ }
return Pose
-
-
-# Session = sessionmaker(bind=engine)
-# session = Session()
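Note: a minimal sketch of how the new lookup path is meant to be used once the SQL datasets are loaded; the dataset name 'vggface2' is only a placeholder and not taken from this change.

    from app.models.sql_factory import load_sql_datasets, get_dataset

    load_sql_datasets()                  # creates the engine and the module-level Session
    dataset = get_dataset('vggface2')    # hypothetical dataset name
    record = dataset.get_identity(1)     # SQL row ids are 1-based
    # record -> { 'uuid': ..., 'identity': ..., 'roi': ..., 'pose': ... }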
diff --git a/megapixels/app/processors/faiss.py b/megapixels/app/processors/faiss.py
new file mode 100644
index 00000000..5156ad71
--- /dev/null
+++ b/megapixels/app/processors/faiss.py
@@ -0,0 +1,58 @@
+"""
+Index all of the FAISS datasets
+"""
+
+import os
+import glob
+import faiss
+import time
+import numpy as np
+
+from app.utils.file_utils import load_recipe, load_csv_safe
+from app.settings import app_cfg as cfg
+
+class DefaultRecipe:
+  """Fallback index recipe used when a dataset has no recipe JSON."""
+  def __init__(self):
+    self.dim = 128
+    self.factory_type = 'Flat'
+
+def build_all_faiss_databases():
+  # build one index per metadata folder, using its recipe JSON if present
+  for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+    name = os.path.basename(fn)
+    recipe_fn = os.path.join(cfg.DIR_FAISS_RECIPES, name + ".json")
+    if os.path.exists(recipe_fn):
+      build_faiss_database(name, load_recipe(recipe_fn))
+    else:
+      build_faiss_database(name, DefaultRecipe())
+
+def build_faiss_database(name, recipe):
+  vec_fn = os.path.join(cfg.DIR_FAISS_METADATA, name, "vecs.csv")
+  index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index")
+
+  index = faiss.index_factory(recipe.dim, recipe.factory_type)
+
+  # the fourth CSV column (row[3]) holds the comma-separated face vector
+  keys, rows = load_csv_safe(vec_fn)
+  feats = np.array([ list(map(float, row[3].split(","))) for row in rows ]).astype('float32')
+  n, d = feats.shape
+
+  print("{}: adding {} vectors of dim {} ({})".format(name, n, d, recipe.factory_type))
+
+  add_start = time.time()
+  index.add(feats)
+  add_time = time.time() - add_start
+  print("{}: add time: {:.1f}s".format(name, add_time))
+
+  faiss.write_index(index, index_fn)
+
+def load_faiss_databases():
+  # load every previously built index, keyed by dataset name
+  faiss_datasets = {}
+  for fn in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+    name = os.path.basename(fn)
+    index_fn = os.path.join(cfg.DIR_FAISS_INDEXES, name + ".index")
+    if os.path.exists(index_fn):
+      faiss_datasets[name] = faiss.read_index(index_fn)
+  return faiss_datasets
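Note: a rough usage sketch for the new module, with the index built offline and then loaded for querying; the random 128-dim query vector is purely illustrative and must match the recipe's dim.

    import numpy as np
    from app.processors.faiss import build_all_faiss_databases, load_faiss_databases

    build_all_faiss_databases()          # writes one .index file per metadata folder
    indexes = load_faiss_databases()     # { dataset_name: faiss.Index }

    query = np.random.rand(1, 128).astype('float32')   # illustrative query vector
    for name, index in indexes.items():
      distances, ids = index.search(query, 5)           # 5 nearest neighbors per query row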
diff --git a/megapixels/app/server/api.py b/megapixels/app/server/api.py
index cf8241bd..36563910 100644
--- a/megapixels/app/server/api.py
+++ b/megapixels/app/server/api.py
@@ -2,18 +2,23 @@ import os
import re
import time
import dlib
+import numpy as np
from flask import Blueprint, request, jsonify
from PIL import Image # todo: try to remove PIL dependency
from app.processors import face_recognition
from app.processors import face_detector
-from app.models.sql_factory import list_datasets, get_dataset, get_table
+from app.processors.faiss import load_faiss_databases
+from app.models.sql_factory import load_sql_datasets, list_datasets, get_dataset, get_table
+from app.utils.im_utils import pil2np
sanitize_re = re.compile('[\W]+')
valid_exts = ['.gif', '.jpg', '.jpeg', '.png']
api = Blueprint('api', __name__)
+faiss_datasets = load_faiss_databases()
+
@api.route('/')
def index():
return jsonify({ 'datasets': list_datasets() })
@@ -26,10 +31,15 @@ def show(name):
else:
return jsonify({ 'status': 404 })
-@api.route('/dataset/<name>/face', methods=['POST'])
+@api.route('/dataset/<name>/face/', methods=['POST'])
def upload(name):
start = time.time()
dataset = get_dataset(name)
+ if name not in faiss_datasets:
+ return jsonify({
+ 'error': 'invalid dataset'
+ })
+ faiss_dataset = faiss_datasets[name]
file = request.files['query_img']
fn = file.filename
if fn.endswith('blob'):
@@ -40,22 +50,46 @@ def upload(name):
if ext.lower() not in valid_exts:
return jsonify({ 'error': 'not an image' })
- img = Image.open(file.stream).convert('RGB')
+ im = Image.open(file.stream).convert('RGB')
+ im_np = pil2np(im)
# Face detection
detector = face_detector.DetectorDLIBHOG()
# get detection as BBox object
- bboxes = detector.detect(im, largest=True)
+ bboxes = detector.detect(im_np, largest=True)
bbox = bboxes[0]
- dim = im.shape[:2][::-1]
+ dim = im_np.shape[:2][::-1]
bbox = bbox.to_dim(dim) # convert back to real dimensions
# face recognition/vector
recognition = face_recognition.RecognitionDLIB(gpu=-1)
+  vec = recognition.vec(im_np, bbox)
+
+  # query FAISS for the 5 nearest neighbors of this face vector
+  query = np.array([ vec ]).astype('float32')
+  distances, indexes = faiss_dataset.search(query, 5)
+
+  # search() returns one row per query vector; keep the row for our single query
+  distances = distances[0]
+  indexes = indexes[0]
- # print(vec.shape)
- # results = db.search(vec, limit=limit)
+  if len(indexes) == 0:
+    return jsonify({ 'results': [] })
+
+  # FAISS positions are 0-based; the SQL row ids they map to start at 1
+  lookup = {}
+  for _d, _i in zip(distances, indexes):
+    lookup[_i + 1] = _d
+
# with the result we have an ID
# query the sql dataset for the UUID etc here
@@ -63,12 +97,13 @@ def upload(name):
query = {
'timing': time.time() - start,
}
- results = []
+  results = [ dataset.get_identity(int(index) + 1) for index in indexes ]  # +1: FAISS position -> 1-based SQL id
print(results)
return jsonify({
- 'query': query,
'results': results,
+ # 'distances': distances.tolist(),
+ # 'indexes': indexes.tolist(),
})
@api.route('/dataset/<name>/name', methods=['GET'])
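Note: a rough client-side illustration of the new face-search round trip; the host, port, and dataset name are placeholders, the blueprint is assumed to be mounted at the application root, and requests is not part of this change.

    import requests

    url = 'http://localhost:5000/dataset/vggface2/face/'   # hypothetical server and dataset
    with open('query.jpg', 'rb') as f:
      resp = requests.post(url, files={ 'query_img': ('query.jpg', f, 'image/jpeg') })
    print(resp.json())   # { 'results': [ { 'uuid': ..., 'identity': ..., 'roi': ..., 'pose': ... }, ... ] }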
diff --git a/megapixels/app/server/json_encoder.py b/megapixels/app/server/json_encoder.py
new file mode 100644
index 00000000..89af578a
--- /dev/null
+++ b/megapixels/app/server/json_encoder.py
@@ -0,0 +1,17 @@
+from sqlalchemy.ext.declarative import DeclarativeMeta
+from flask import json
+
+class AlchemyEncoder(json.JSONEncoder):
+  """Serialize SQLAlchemy declarative models by dumping their public fields."""
+  def default(self, o):
+    if isinstance(o.__class__, DeclarativeMeta):
+      data = {}
+      fields = o.__json__() if hasattr(o, '__json__') else dir(o)
+      for field in [f for f in fields if not f.startswith('_') and f not in ['metadata', 'query', 'query_class']]:
+        value = getattr(o, field)
+        try:
+          # keep the value only if it is itself JSON-serializable
+          json.dumps(value)
+          data[field] = value
+        except TypeError:
+          data[field] = None
+      return data
+    return json.JSONEncoder.default(self, o)
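Note: the encoder is added here but not registered anywhere in this diff; one way to wire it up, assuming a Flask version that still supports the app.json_encoder attribute (pre-2.2), would be:

    from flask import Flask
    from app.server.json_encoder import AlchemyEncoder

    app = Flask(__name__)
    app.json_encoder = AlchemyEncoder   # jsonify() can now serialize SQLAlchemy model instances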