Diffstat (limited to 'check/app/models/sql_factory.py')
| -rw-r--r-- | check/app/models/sql_factory.py | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/check/app/models/sql_factory.py b/check/app/models/sql_factory.py
new file mode 100644
index 0000000..5cfb36b
--- /dev/null
+++ b/check/app/models/sql_factory.py
@@ -0,0 +1,164 @@
+import os
+import glob
+import time
+import pandas as pd
+
+from sqlalchemy import create_engine, Table, Column, String, Integer, DateTime, Float, func
+from sqlalchemy.orm import sessionmaker, scoped_session
+from sqlalchemy.ext.declarative import declarative_base
+
+from app.utils.file_utils import load_recipe, load_csv_safe
+from app.settings import app_cfg as cfg
+
+connection_url = "mysql+mysqlconnector://{}:{}@{}/{}?charset=utf8mb4".format(
+    os.getenv("DB_USER"),
+    os.getenv("DB_PASS"),
+    os.getenv("DB_HOST"),
+    os.getenv("DB_NAME")
+)
+
+datasets = {}
+loaded = False
+Session = None
+
+def load_sql_datasets(replace=False, base_model=None):
+    global datasets, loaded, Session
+    if loaded:
+        return datasets
+    engine = create_engine(connection_url, encoding="utf-8", pool_recycle=3600)
+    # charset is handled by the ?charset=utf8mb4 arg in connection_url,
+    # so no SET NAMES / SET CHARACTER SET statements are needed here
+    Session = scoped_session(sessionmaker(bind=engine))
+    if base_model is not None:
+        # allow Flask-SQLAlchemy-style Model.query lookups on our models
+        base_model.query = Session.query_property()
+    for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+        dataset = load_sql_dataset(path, replace, engine, base_model)
+        datasets[dataset.name] = dataset
+    loaded = True
+    return datasets
+
+def load_sql_dataset(path, replace=False, engine=None, base_model=None):
+    name = os.path.basename(path)
+    dataset = SqlDataset(name, engine=engine, base_model=base_model)
+
+    for fn in glob.iglob(os.path.join(path, "*.csv")):
+        key = os.path.basename(fn).replace(".csv", "")
+        table = dataset.get_table(key)
+        if table is None:
+            continue
+        if replace:
+            print('loading dataset {}'.format(fn))
+            df = pd.read_csv(fn)
+            # sort the CSV columns, then rename them positionally to the model's
+            # column names; this also fixes columns named "index", a SQL reserved word
+            df = df.reindex(sorted(df.columns), axis=1)
+            columns = [column.name for column in table.__table__.columns]
+            df.columns = columns
+            df.to_sql(name=table.__tablename__, con=engine, if_exists='replace', index=False)
+    return dataset
+
+class SqlDataset:
+    """
+    Bridge between the facial information CSVs attached to each dataset and MySQL
+    - each dataset directory holds CSV files that load into these database models
+    - column names are fixed to work in SQL (index -> id)
+    - generic models can then fetch this info after running a FAISS query
+    """
+    def __init__(self, name, engine=None, base_model=None):
+        self.name = name
+        self.tables = {}
+        if engine is None:
+            engine = create_engine(connection_url, encoding="utf-8", pool_recycle=3600)
+        self.engine = engine
+        if base_model is None:
+            base_model = declarative_base(bind=engine)
+            if Session is not None:
+                base_model.query = Session.query_property()
+        self.base_model = base_model
+        # register the per-dataset models; the identity, face_roi, and face_pose
+        # models are built the same way and added to self.tables
+        self.make_file_record()
+
+    def make_file_record(self):
+        """
+        Build and register the per-dataset file_record model.
+        """
+        class FileRecord(self.base_model):
+            __tablename__ = self.name + "_file_record"
+            id = Column(Integer, primary_key=True)
+            ext = Column(String(3), nullable=False)
+            fn = Column(String(36), nullable=False)
+            identity_id = Column(Integer, nullable=False)
+            sha256 = Column(String(64), nullable=False)  # sha256 hex digests are 64 chars
+            def toJSON(self):
+                return {
+                    'id': self.id,
+                    'fn': self.fn,
+                    'ext': self.ext,
+                    'identity_id': self.identity_id,
+                    'sha256': self.sha256,
+                }
+        self.tables['file_record'] = FileRecord
+        return FileRecord
+
+    def get_table(self, key):
+        """
+        Look up a registered model class by key ('identity', 'file_record', ...).
+        """
+        return self.tables.get(key)
+
+    def describe(self):
+        """
+        List the available SQL tables for a given dataset.
+        """
+        return {
+            'name': self.name,
+            'tables': list(self.tables.keys()),
+        }
+
+    def get_identity(self, id):
+        """
+        Get an identity given a file record ID.
+        """
+        file_record_table = self.get_table('file_record')
+        file_record = file_record_table.query.filter(file_record_table.id == id).first()
+
+        if not file_record:
+            return None
+
+        identity_table = self.get_table('identity')
+        identity = identity_table.query.filter(identity_table.id == file_record.identity_id).first()
+
+        if not identity:
+            return None
+
+        return {
+            'file_record': file_record.toJSON(),
+            'identity': identity.toJSON(),
+            'face_roi': self.select('face_roi', id),
+            'face_pose': self.select('face_pose', id),
+        }
+
+    def search_name(self, q):
+        """
+        Find identities whose full name matches an ilike pattern.
+        """
+        table = self.get_table('identity')
+        identity_list = table.query.filter(table.fullname.ilike(q)).order_by(table.fullname.desc()).limit(15).all()
+        return identity_list
+
+    def search_description(self, q):
+        """
+        Find identities whose description matches an ilike pattern.
+        """
+        table = self.get_table('identity')
+        identity_list = table.query.filter(table.description.ilike(q)).order_by(table.description.desc()).limit(15).all()
+        return identity_list
+
+    def get_file_records_for_identities(self, identity_list):
+        """
+        Given a list of identities, map these to file records.
+        """
+        identities = []
+        file_record_table = self.get_table('file_record')
+        for row in identity_list:
+            file_record = file_record_table.query.filter(file_record_table.identity_id == row.id).first()
+            if file_record:
+                identities.append({
+                    'file_record': file_record.toJSON(),
+                    'identity': row.toJSON(),
+                })
+        return identities
+
+    def select(self, table, id):
+        """
+        Perform a generic select by primary key on one of this dataset's tables.
+        """
+        table = self.get_table(table)
+        if table is None:
+            return None
+        session = Session()
+        obj = session.query(table).filter(table.id == id).first()
+        # serialize before closing the session to avoid detached-instance errors
+        result = obj.toJSON() if obj else None
+        session.close()
+        return result

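One subtlety in load_sql_dataset worth calling out: the CSV columns are sorted alphabetically and then renamed positionally to the model's column names, which only lines up if the model declares its columns in the same order as the sorted CSV headers. A standalone toy illustration with hypothetical column names:

    import pandas as pd

    # "index" is a SQL reserved word, so the model calls this column "id"
    df = pd.DataFrame({'index': [0, 1], 'fullname': ['A', 'B']})
    df = df.reindex(sorted(df.columns), axis=1)  # column order: fullname, index
    df.columns = ['fullname', 'id']              # positional rename to the model's names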