diff options
| author | adamhrv <adam@ahprojects.com> | 2018-12-15 19:57:49 +0100 |
|---|---|---|
| committer | adamhrv <adam@ahprojects.com> | 2018-12-15 19:57:49 +0100 |
| commit | 82b2c0b5d6d7baccbe4d574d96e18fe2078047d7 (patch) | |
| tree | a8784b7ec2bc5a0451c252f66a6b786f3a2504f5 /megapixels/app/models | |
| parent | 8e978af21c2b29f678a09701afb3ec7d65d0a6ab (diff) | |
| parent | c5b02ffab8d388e8a2925e51736b902a48a95e71 (diff) | |
Merge branch 'master' of github.com:adamhrv/megapixels_dev
Diffstat (limited to 'megapixels/app/models')
| -rw-r--r-- | megapixels/app/models/sql_factory.py | 152 |
1 file changed, 152 insertions, 0 deletions
diff --git a/megapixels/app/models/sql_factory.py b/megapixels/app/models/sql_factory.py new file mode 100644 index 00000000..e35c3e15 --- /dev/null +++ b/megapixels/app/models/sql_factory.py @@ -0,0 +1,152 @@ +import os +import glob +import time +import pandas as pd + +from sqlalchemy import create_engine, Table, Column, String, Integer, DateTime, Float +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.declarative import declarative_base + +from app.utils.file_utils import load_recipe, load_csv_safe +from app.settings import app_cfg as cfg + +connection_url = "mysql+mysqldb://{}:{}@{}/{}".format( + os.getenv("DB_USER"), + os.getenv("DB_PASS"), + os.getenv("DB_HOST"), + os.getenv("DB_NAME") +) + +datasets = {} +loaded = False + +def list_datasets(): + return [dataset.describe() for dataset in datasets.values()] + +def get_dataset(name): + return datasets[name] if name in datasets else None + +def get_table(name, table_name): + dataset = get_dataset(name) + return dataset.get_table(table_name) if dataset else None + +def load_sql_datasets(replace=False, base_model=None): + global datasets, loaded + if loaded: + return datasets + engine = create_engine(connection_url) if replace else None + for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")): + dataset = load_sql_dataset(path, replace, engine, base_model) + datasets[dataset.name] = dataset + loaded = True + return datasets + +def load_sql_dataset(path, replace=False, engine=None, base_model=None): + name = os.path.basename(path) + dataset = SqlDataset(name, base_model=base_model) + + for fn in glob.iglob(os.path.join(path, "*.csv")): + key = os.path.basename(fn).replace(".csv", "") + table = dataset.get_table(key) + if table is None: + continue + if replace: + print('loading dataset {}'.format(fn)) + df = pd.read_csv(fn) + # fix columns that are named "index", a sql reserved word + df.columns = table.__table__.columns.keys() + df.to_sql(name=table.__tablename__, con=engine, 
if_exists='replace', index=False) + return dataset + +class SqlDataset: + """ + Bridge between the facial information CSVs connected to the datasets, and MySQL + - each dataset should have files that can be loaded into these database models + - names will be fixed to work in SQL (index -> id) + - we can then have more generic models for fetching this info after doing a FAISS query + """ + def __init__(self, name, engine=None, base_model=None): + self.name = name + self.tables = {} + if base_model is None: + self.engine = create_engine(connection_url) + base_model = declarative_base(engine) + self.base_model = base_model + + def describe(self): + return { + 'name': self.name, + 'tables': list(self.tables.keys()), + } + + def get_table(self, type): + if type in self.tables: + return self.tables[type] + elif type == 'uuids': + self.tables[type] = self.uuid_table() + elif type == 'roi': + self.tables[type] = self.roi_table() + elif type == 'identity_meta': + self.tables[type] = self.identity_table() + elif type == 'pose': + self.tables[type] = self.pose_table() + else: + return None + return self.tables[type] + + # ==> uuids.csv <== + # index,uuid + # 0,f03fd921-2d56-4e83-8115-f658d6a72287 + def uuid_table(self): + class UUID(self.base_model): + __tablename__ = self.name + "_uuid" + id = Column(Integer, primary_key=True) + uuid = Column(String(36), nullable=False) + return UUID + + # ==> roi.csv <== + # index,h,image_height,image_index,image_width,w,x,y + # 0,0.33000000000000007,250,0,250,0.32999999999999996,0.33666666666666667,0.35 + def roi_table(self): + class ROI(self.base_model): + __tablename__ = self.name + "_roi" + id = Column(Integer, primary_key=True) + h = Column(Float, nullable=False) + image_height = Column(Integer, nullable=False) + image_index = Column(Integer, nullable=False) + image_width = Column(Integer, nullable=False) + w = Column(Float, nullable=False) + x = Column(Float, nullable=False) + y = Column(Float, nullable=False) + return ROI + + # ==> 
identity.csv <== + # index,fullname,description,gender,images,image_index + # 0,A. J. Cook,Canadian actress,f,1,0 + def identity_table(self): + class Identity(self.base_model): + __tablename__ = self.name + "_identity" + id = Column(Integer, primary_key=True) + fullname = Column(String(36), nullable=False) + description = Column(String(36), nullable=False) + gender = Column(String(1), nullable=False) + images = Column(Integer, nullable=False) + image_id = Column(Integer, nullable=False) + return Identity + + # ==> pose.csv <== + # index,image_index,pitch,roll,yaw + # 0,0,11.16264458441435,10.415885631337728,22.99719032415318 + def pose_table(self): + class Pose(self.base_model): + __tablename__ = self.name + "_pose" + id = Column(Integer, primary_key=True) + image_id = Column(Integer, primary_key=True) + pitch = Column(Float, nullable=False) + roll = Column(Float, nullable=False) + yaw = Column(Float, nullable=False) + return Pose + + +# Session = sessionmaker(bind=engine) +# session = Session() |
