Diffstat (limited to 'megapixels/app/models/sql_factory.py')
-rw-r--r--  megapixels/app/models/sql_factory.py  61
1 file changed, 56 insertions(+), 5 deletions(-)
diff --git a/megapixels/app/models/sql_factory.py b/megapixels/app/models/sql_factory.py
index 525492f1..2a18d6af 100644
--- a/megapixels/app/models/sql_factory.py
+++ b/megapixels/app/models/sql_factory.py
@@ -1,9 +1,15 @@
 import os
+import glob
+import time
+import pandas as pd
 
 from sqlalchemy import create_engine, Table, Column, String, Integer, DateTime, Float
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
 
+from app.utils.file_utils import load_recipe, load_csv_safe
+from app.settings import app_cfg as cfg
+
 connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
     os.getenv("DB_USER"),
     os.getenv("DB_PASS"),
@@ -11,8 +17,49 @@ connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
     os.getenv("DB_NAME")
 )
 
-# Session = sessionmaker(bind=engine)
-# session = Session()
+datasets = {}
+loaded = False
+
+def list_datasets():
+    return [{
+        'name': name,
+        'tables': list(datasets[name].tables.keys()),
+    } for name in datasets.keys()]
+
+def get_dataset(name):
+    return datasets[name] if name in datasets else None
+
+def get_table(name, table_name):
+    dataset = get_dataset(name)
+    return dataset.get_table(table_name) if dataset else None
+
+def load_sql_datasets(replace=False, base_model=None):
+    global datasets, loaded
+    if loaded:
+        return datasets
+    engine = create_engine(connection_url) if replace else None
+    for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+        dataset = load_sql_dataset(path, replace, engine, base_model)
+        datasets[dataset.name] = dataset
+    loaded = True
+    return datasets
+
+def load_sql_dataset(path, replace=False, engine=None, base_model=None):
+    name = os.path.basename(path)
+    dataset = SqlDataset(name, base_model=base_model)
+
+    for fn in glob.iglob(os.path.join(path, "*.csv")):
+        key = os.path.basename(fn).replace(".csv", "")
+        table = dataset.get_table(key)
+        if table is None:
+            continue
+        if replace:
+            print('loading dataset {}'.format(fn))
+            df = pd.read_csv(fn)
+            # fix columns that are named "index", a sql reserved word
+            df.columns = table.__table__.columns.keys()
+            df.to_sql(name=table.__tablename__, con=engine, if_exists='replace', index=False)
+    return dataset
 
 class SqlDataset:
     """
@@ -21,18 +68,18 @@ class SqlDataset:
     - names will be fixed to work in SQL (index -> id)
    - we can then have more generic models for fetching this info after doing a FAISS query
     """
-    def __init__(self, name, base_model=None):
+    def __init__(self, name, engine=None, base_model=None):
         self.name = name
         self.tables = {}
         if base_model is None:
-            engine = create_engine(connection_url)
+            self.engine = create_engine(connection_url)
             base_model = declarative_base(engine)
         self.base_model = base_model
 
     def get_table(self, type):
         if type in self.tables:
             return self.tables[type]
-        elif type == 'uuid':
+        elif type == 'uuids':
             self.tables[type] = self.uuid_table()
         elif type == 'roi':
             self.tables[type] = self.roi_table()
@@ -96,3 +143,7 @@ class SqlDataset:
         roll = Column(Float, nullable=False)
         yaw = Column(Float, nullable=False)
         return Pose
+
+
+# Session = sessionmaker(bind=engine)
+# session = Session()
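Usage note (not part of the commit): the hunks above add a module-level dataset
registry with lazy, one-shot loading. A minimal sketch of how it might be driven
is below; it assumes the DB_USER/DB_PASS/DB_HOST/DB_NAME environment variables
are set, that cfg.DIR_FAISS_METADATA contains one subdirectory of CSV files per
dataset, and that the module imports as app.models.sql_factory (inferred from
the file path in the diff). The dataset name 'example_dataset' is hypothetical.

from app.models import sql_factory

# First call scans cfg.DIR_FAISS_METADATA and caches one SqlDataset per
# subdirectory; replace=True would also re-import each CSV into MySQL
# via pandas DataFrame.to_sql.
datasets = sql_factory.load_sql_datasets(replace=False)

# Enumerate the loaded datasets and the table keys built from CSV filenames.
for entry in sql_factory.list_datasets():
    print(entry['name'], entry['tables'])

# Fetch one declarative model class ('roi' is a key handled by
# SqlDataset.get_table); returns None if the dataset name is unknown.
Roi = sql_factory.get_table('example_dataset', 'roi')

Because the global `loaded` flag short-circuits load_sql_datasets, the CSV
import cost is paid at most once per process; subsequent calls only hit the
in-memory registry.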
