summary refs log tree commit diff
path: root/megapixels/app/models/sql_factory.py
diff options
context:
space:
mode:
Diffstat (limited to 'megapixels/app/models/sql_factory.py')
-rw-r--r-- megapixels/app/models/sql_factory.py | 61
1 files changed, 56 insertions, 5 deletions
diff --git a/megapixels/app/models/sql_factory.py b/megapixels/app/models/sql_factory.py
index 525492f1..2a18d6af 100644
--- a/megapixels/app/models/sql_factory.py
+++ b/megapixels/app/models/sql_factory.py
@@ -1,9 +1,15 @@
import os
+import glob
+import time
+import pandas as pd
from sqlalchemy import create_engine, Table, Column, String, Integer, DateTime, Float
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
+from app.utils.file_utils import load_recipe, load_csv_safe
+from app.settings import app_cfg as cfg
+
connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
os.getenv("DB_USER"),
os.getenv("DB_PASS"),
@@ -11,8 +17,49 @@ connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
os.getenv("DB_NAME")
)
-# Session = sessionmaker(bind=engine)
-# session = Session()
+datasets = {}
+loaded = False
+
+def list_datasets():
+ return [{
+ 'name': name,
+ 'tables': list(datasets[name].tables.keys()),
+ } for name in datasets.keys()]
+
+def get_dataset(name):
+ return datasets[name] if name in datasets else None
+
+def get_table(name, table_name):
+ dataset = get_dataset(name)
+ return dataset.get_table(table_name) if dataset else None
+
+def load_sql_datasets(replace=False, base_model=None):
+ global datasets, loaded
+ if loaded:
+ return datasets
+ engine = create_engine(connection_url) if replace else None
+ for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+ dataset = load_sql_dataset(path, replace, engine, base_model)
+ datasets[dataset.name] = dataset
+ loaded = True
+ return datasets
+
+def load_sql_dataset(path, replace=False, engine=None, base_model=None):
+ name = os.path.basename(path)
+ dataset = SqlDataset(name, base_model=base_model)
+
+ for fn in glob.iglob(os.path.join(path, "*.csv")):
+ key = os.path.basename(fn).replace(".csv", "")
+ table = dataset.get_table(key)
+ if table is None:
+ continue
+ if replace:
+ print('loading dataset {}'.format(fn))
+ df = pd.read_csv(fn)
+ # fix columns that are named "index", a sql reserved word
+ df.columns = table.__table__.columns.keys()
+ df.to_sql(name=table.__tablename__, con=engine, if_exists='replace', index=False)
+ return dataset
class SqlDataset:
"""
@@ -21,18 +68,18 @@ class SqlDataset:
- names will be fixed to work in SQL (index -> id)
- we can then have more generic models for fetching this info after doing a FAISS query
"""
- def __init__(self, name, base_model=None):
+ def __init__(self, name, engine=None, base_model=None):
self.name = name
self.tables = {}
if base_model is None:
- engine = create_engine(connection_url)
+ self.engine = create_engine(connection_url)
base_model = declarative_base(engine)
self.base_model = base_model
def get_table(self, type):
if type in self.tables:
return self.tables[type]
- elif type == 'uuid':
+ elif type == 'uuids':
self.tables[type] = self.uuid_table()
elif type == 'roi':
self.tables[type] = self.roi_table()
@@ -96,3 +143,7 @@ class SqlDataset:
roll = Column(Float, nullable=False)
yaw = Column(Float, nullable=False)
return Pose
+
+
+# Session = sessionmaker(bind=engine)
+# session = Session()