path: root/megapixels/app
author     Jules Laplace <julescarbon@gmail.com>  2018-12-14 18:10:27 +0100
committer  Jules Laplace <julescarbon@gmail.com>  2018-12-14 18:10:27 +0100
commit     38746f284b17400d4e2555509ea60df5912b824a (patch)
tree       6dea93f4ba348d12a58a761424ec5547697fcf1f /megapixels/app
parent     36b6082dfa768cbf35d40dc2c82706dfae0b687b (diff)
all the sql stuff communicating nicely
Diffstat (limited to 'megapixels/app')
-rw-r--r--  megapixels/app/models/sql_factory.py  61
-rw-r--r--  megapixels/app/server/api.py           72
-rw-r--r--  megapixels/app/server/api/image.py     40
-rw-r--r--  megapixels/app/server/create.py        23
4 files changed, 144 insertions, 52 deletions
diff --git a/megapixels/app/models/sql_factory.py b/megapixels/app/models/sql_factory.py
index 525492f1..2a18d6af 100644
--- a/megapixels/app/models/sql_factory.py
+++ b/megapixels/app/models/sql_factory.py
@@ -1,9 +1,15 @@
import os
+import glob
+import time
+import pandas as pd
from sqlalchemy import create_engine, Table, Column, String, Integer, DateTime, Float
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
+from app.utils.file_utils import load_recipe, load_csv_safe
+from app.settings import app_cfg as cfg
+
connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
os.getenv("DB_USER"),
os.getenv("DB_PASS"),
@@ -11,8 +17,49 @@ connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
os.getenv("DB_NAME")
)
-# Session = sessionmaker(bind=engine)
-# session = Session()
+datasets = {}
+loaded = False
+
+def list_datasets():
+ return [{
+ 'name': name,
+ 'tables': list(datasets[name].tables.keys()),
+ } for name in datasets.keys()]
+
+def get_dataset(name):
+ return datasets[name] if name in datasets else None
+
+def get_table(name, table_name):
+ dataset = get_dataset(name)
+ return dataset.get_table(table_name) if dataset else None
+
+def load_sql_datasets(replace=False, base_model=None):
+ global datasets, loaded
+ if loaded:
+ return datasets
+ engine = create_engine(connection_url) if replace else None
+ for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
+ dataset = load_sql_dataset(path, replace, engine, base_model)
+ datasets[dataset.name] = dataset
+ loaded = True
+ return datasets
+
+def load_sql_dataset(path, replace=False, engine=None, base_model=None):
+ name = os.path.basename(path)
+ dataset = SqlDataset(name, base_model=base_model)
+
+ for fn in glob.iglob(os.path.join(path, "*.csv")):
+ key = os.path.basename(fn).replace(".csv", "")
+ table = dataset.get_table(key)
+ if table is None:
+ continue
+ if replace:
+ print('loading dataset {}'.format(fn))
+ df = pd.read_csv(fn)
+ # fix columns that are named "index", a sql reserved word
+ df.columns = table.__table__.columns.keys()
+ df.to_sql(name=table.__tablename__, con=engine, if_exists='replace', index=False)
+ return dataset
class SqlDataset:
"""
@@ -21,18 +68,18 @@ class SqlDataset:
- names will be fixed to work in SQL (index -> id)
- we can then have more generic models for fetching this info after doing a FAISS query
"""
- def __init__(self, name, base_model=None):
+ def __init__(self, name, engine=None, base_model=None):
self.name = name
self.tables = {}
if base_model is None:
- engine = create_engine(connection_url)
+ self.engine = create_engine(connection_url)
base_model = declarative_base(engine)
self.base_model = base_model
def get_table(self, type):
if type in self.tables:
return self.tables[type]
- elif type == 'uuid':
+ elif type == 'uuids':
self.tables[type] = self.uuid_table()
elif type == 'roi':
self.tables[type] = self.roi_table()
@@ -96,3 +143,7 @@ class SqlDataset:
roll = Column(Float, nullable=False)
yaw = Column(Float, nullable=False)
return Pose
+
+
+# Session = sessionmaker(bind=engine)
+# session = Session()
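For reference, a minimal sketch of driving the new loader helpers outside the request cycle, assuming the DB_* environment variables are exported and cfg.DIR_FAISS_METADATA holds one sub-directory of CSVs per dataset; the 'megaface' dataset name and the 'uuids' table key are illustrative assumptions, not part of this commit:

# Standalone usage sketch for load_sql_datasets() / get_table().
# Assumes DB_USER/DB_PASS/DB_HOST/DB_NAME are set and the CSV metadata
# lives under cfg.DIR_FAISS_METADATA, as the loader above expects.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from app.models.sql_factory import connection_url, load_sql_datasets, get_table

load_sql_datasets(replace=True)        # replace=True re-imports every CSV into MySQL
Uuid = get_table('megaface', 'uuids')  # hypothetical dataset name / table key

if Uuid is not None:
    session = sessionmaker(bind=create_engine(connection_url))()
    print(session.query(Uuid).count())  # row count of the imported uuids table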
diff --git a/megapixels/app/server/api.py b/megapixels/app/server/api.py
new file mode 100644
index 00000000..e7db11f1
--- /dev/null
+++ b/megapixels/app/server/api.py
@@ -0,0 +1,72 @@
+from flask import Blueprint, jsonify
+
+from app.models.sql_factory import list_datasets, get_dataset, get_table
+
+# from jinja2 import TemplateNotFound
+
+# import os
+# import sys
+# import json
+# import time
+# import argparse
+# import cv2 as cv
+# import numpy as np
+# from datetime import datetime
+# from flask import Flask, request, render_template, jsonify
+# from PIL import Image # todo: try to remove PIL dependency
+# import re
+
+# sanitize_re = re.compile('[\W]+')
+# valid_exts = ['.gif', '.jpg', '.jpeg', '.png']
+
+# from dotenv import load_dotenv
+# load_dotenv()
+
+# from feature_extractor import FeatureExtractor
+
+# DEFAULT_LIMIT = 50
+
+api = Blueprint('api', __name__)
+
+@api.route('/')
+def index():
+ return jsonify({ 'datasets': list_datasets() })
+
+@api.route('/dataset/<dataset>/test', methods=['POST'])
+def test(dataset='test'):
+ dataset = get_dataset(dataset)
+ print('hiiiiii')
+ return jsonify({ 'test': 'OK', 'dataset': dataset })
+
+# @router.route('/<dataset>/face', methods=['POST'])
+# def upload(name):
+# file = request.files['query_img']
+# fn = file.filename
+# if fn.endswith('blob'):
+# fn = 'filename.jpg'
+
+# basename, ext = os.path.splitext(fn)
+# print("got {}, type {}".format(basename, ext))
+# if ext.lower() not in valid_exts:
+# return jsonify({ 'error': 'not an image' })
+
+# uploaded_fn = datetime.now().isoformat() + "_" + basename
+# uploaded_fn = sanitize_re.sub('', uploaded_fn)
+# uploaded_img_path = "static/uploaded/" + uploaded_fn + ext
+# uploaded_img_path = uploaded_img_path.lower()
+# print('query: {}'.format(uploaded_img_path))
+
+# img = Image.open(file.stream).convert('RGB')
+# # img.save(uploaded_img_path)
+# # vec = db.load_feature_vector_from_file(uploaded_img_path)
+# vec = fe.extract(img)
+# # print(vec.shape)
+
+# results = db.search(vec, limit=limit)
+# query = {
+# 'timing': time.time() - start,
+# }
+# print(results)
+# return jsonify({
+# 'results': results,
+# })
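A quick sketch of calling the new blueprint once create.py mounts it under /api; the localhost:5000 host and the 'megaface' dataset name are assumptions:

# Assumes the dev server is running on localhost:5000 with the blueprint
# registered under the /api prefix (see create.py below).
import requests

# GET /api/ lists every dataset discovered by load_sql_datasets().
print(requests.get('http://localhost:5000/api/').json())

# POST /api/dataset/<name>/test looks the dataset up; note that jsonify()
# cannot serialize the SqlDataset instance this route currently returns,
# so expect a 500 whenever the named dataset actually exists.
print(requests.post('http://localhost:5000/api/dataset/megaface/test').status_code)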
diff --git a/megapixels/app/server/api/image.py b/megapixels/app/server/api/image.py
deleted file mode 100644
index f2f4a4f9..00000000
--- a/megapixels/app/server/api/image.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from flask import Blueprint, render_template, abort
-# from jinja2 import TemplateNotFound
-
-router = Blueprint('image', __name__)
-
-@router.route('/<dataset>/test', methods=['POST'])
-def test(name):
- # dataset =
-@router.route('/<dataset>/face', methods=['POST'])
-def upload(name):
- file = request.files['query_img']
- fn = file.filename
- if fn.endswith('blob'):
- fn = 'filename.jpg'
-
- basename, ext = os.path.splitext(fn)
- print("got {}, type {}".format(basename, ext))
- if ext.lower() not in valid_exts:
- return jsonify({ 'error': 'not an image' })
-
- uploaded_fn = datetime.now().isoformat() + "_" + basename
- uploaded_fn = sanitize_re.sub('', uploaded_fn)
- uploaded_img_path = "static/uploaded/" + uploaded_fn + ext
- uploaded_img_path = uploaded_img_path.lower()
- print('query: {}'.format(uploaded_img_path))
-
- img = Image.open(file.stream).convert('RGB')
- # img.save(uploaded_img_path)
- # vec = db.load_feature_vector_from_file(uploaded_img_path)
- vec = fe.extract(img)
- # print(vec.shape)
-
- results = db.search(vec, limit=limit)
- query = {
- 'timing': time.time() - start,
- }
- print(results)
- return jsonify({
- 'results': results,
- })
diff --git a/megapixels/app/server/create.py b/megapixels/app/server/create.py
index 1119ee8f..9efed669 100644
--- a/megapixels/app/server/create.py
+++ b/megapixels/app/server/create.py
@@ -1,10 +1,8 @@
-from flask import Flask, Blueprint
+from flask import Flask, Blueprint, jsonify
from flask_sqlalchemy import SQLAlchemy
-from app.models.sql_factory import connection_url
+from app.models.sql_factory import connection_url, load_sql_datasets
-from app.server.api import router as api_router
-
-# from app.server.views.assets import assets
+from app.server.api import api
db = SQLAlchemy()
@@ -14,8 +12,10 @@ def create_app(script_info=None):
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
db.init_app(app)
- app.register_blueprint(api)
-
+ datasets = load_sql_datasets(replace=False, base_model=db.Model)
+
+ app.register_blueprint(api, url_prefix='/api')
+
@app.route('/', methods=['GET'])
def index():
return app.send_static_file('index.html')
@@ -24,4 +24,13 @@ def create_app(script_info=None):
def shell_context():
return { 'app': app, 'db': db }
+ @app.route("/site-map")
+ def site_map():
+ links = []
+ for rule in app.url_map.iter_rules():
+ # url = url_for(rule.endpoint, **(rule.defaults or {}))
+ # print(url)
+ links.append((rule.endpoint))
+ return(jsonify(links))
+
return app
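A minimal sketch of booting the factory by hand; the port and debug flag are assumptions, and this presumes the same DB_* environment and FAISS metadata directory so load_sql_datasets() succeeds at startup:

# Sketch: running the app factory directly.
from app.server.create import create_app

app = create_app()
app.run(port=5000, debug=True)  # /api/ lists datasets, /site-map lists every registered route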