diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-12-14 18:10:27 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-12-14 18:10:27 +0100 |
| commit | 38746f284b17400d4e2555509ea60df5912b824a (patch) | |
| tree | 6dea93f4ba348d12a58a761424ec5547697fcf1f /megapixels/app | |
| parent | 36b6082dfa768cbf35d40dc2c82706dfae0b687b (diff) | |
all the sql stuff communicating nicely
Diffstat (limited to 'megapixels/app')
| -rw-r--r-- | megapixels/app/models/sql_factory.py | 61 | ||||
| -rw-r--r-- | megapixels/app/server/api.py | 72 | ||||
| -rw-r--r-- | megapixels/app/server/api/image.py | 40 | ||||
| -rw-r--r-- | megapixels/app/server/create.py | 23 |
4 files changed, 144 insertions, 52 deletions
diff --git a/megapixels/app/models/sql_factory.py b/megapixels/app/models/sql_factory.py index 525492f1..2a18d6af 100644 --- a/megapixels/app/models/sql_factory.py +++ b/megapixels/app/models/sql_factory.py @@ -1,9 +1,15 @@ import os +import glob +import time +import pandas as pd from sqlalchemy import create_engine, Table, Column, String, Integer, DateTime, Float from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base +from app.utils.file_utils import load_recipe, load_csv_safe +from app.settings import app_cfg as cfg + connection_url = "mysql+mysqldb://{}:{}@{}/{}".format( os.getenv("DB_USER"), os.getenv("DB_PASS"), @@ -11,8 +17,49 @@ connection_url = "mysql+mysqldb://{}:{}@{}/{}".format( os.getenv("DB_NAME") ) -# Session = sessionmaker(bind=engine) -# session = Session() +datasets = {} +loaded = False + +def list_datasets(): + return [{ + 'name': name, + 'tables': list(datasets[name].tables.keys()), + } for name in datasets.keys()] + +def get_dataset(name): + return datasets[name] if name in datasets else None + +def get_table(name, table_name): + dataset = get_dataset(name) + return dataset.get_table(table_name) if dataset else None + +def load_sql_datasets(replace=False, base_model=None): + global datasets, loaded + if loaded: + return datasets + engine = create_engine(connection_url) if replace else None + for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")): + dataset = load_sql_dataset(path, replace, engine, base_model) + datasets[dataset.name] = dataset + loaded = True + return datasets + +def load_sql_dataset(path, replace=False, engine=None, base_model=None): + name = os.path.basename(path) + dataset = SqlDataset(name, base_model=base_model) + + for fn in glob.iglob(os.path.join(path, "*.csv")): + key = os.path.basename(fn).replace(".csv", "") + table = dataset.get_table(key) + if table is None: + continue + if replace: + print('loading dataset {}'.format(fn)) + df = pd.read_csv(fn) + # fix columns that are named "index", a sql reserved word + df.columns = table.__table__.columns.keys() + df.to_sql(name=table.__tablename__, con=engine, if_exists='replace', index=False) + return dataset class SqlDataset: """ @@ -21,18 +68,18 @@ class SqlDataset: - names will be fixed to work in SQL (index -> id) - we can then have more generic models for fetching this info after doing a FAISS query """ - def __init__(self, name, base_model=None): + def __init__(self, name, engine=None, base_model=None): self.name = name self.tables = {} if base_model is None: - engine = create_engine(connection_url) + self.engine = create_engine(connection_url) base_model = declarative_base(engine) self.base_model = base_model def get_table(self, type): if type in self.tables: return self.tables[type] - elif type == 'uuid': + elif type == 'uuids': self.tables[type] = self.uuid_table() elif type == 'roi': self.tables[type] = self.roi_table() @@ -96,3 +143,7 @@ class SqlDataset: roll = Column(Float, nullable=False) yaw = Column(Float, nullable=False) return Pose + + +# Session = sessionmaker(bind=engine) +# session = Session() diff --git a/megapixels/app/server/api.py b/megapixels/app/server/api.py new file mode 100644 index 00000000..e7db11f1 --- /dev/null +++ b/megapixels/app/server/api.py @@ -0,0 +1,72 @@ +from flask import Blueprint, jsonify + +from app.models.sql_factory import list_datasets, get_dataset, get_table + +# from jinja2 import TemplateNotFound + +# import os +# import sys +# import json +# import time +# import argparse +# import cv2 as cv +# import numpy as np +# from datetime import datetime +# from flask import Flask, request, render_template, jsonify +# from PIL import Image # todo: try to remove PIL dependency +# import re + +# sanitize_re = re.compile('[\W]+') +# valid_exts = ['.gif', '.jpg', '.jpeg', '.png'] + +# from dotenv import load_dotenv +# load_dotenv() + +# from feature_extractor import FeatureExtractor + +# DEFAULT_LIMIT = 50 + +api = Blueprint('api', __name__) + +@api.route('/') +def index(): + return jsonify({ 'datasets': list_datasets() }) + +@api.route('/dataset/<dataset>/test', methods=['POST']) +def test(dataset='test'): + dataset = get_dataset(dataset) + print('hiiiiii') + return jsonify({ 'test': 'OK', 'dataset': dataset }) + +# @router.route('/<dataset>/face', methods=['POST']) +# def upload(name): +# file = request.files['query_img'] +# fn = file.filename +# if fn.endswith('blob'): +# fn = 'filename.jpg' + +# basename, ext = os.path.splitext(fn) +# print("got {}, type {}".format(basename, ext)) +# if ext.lower() not in valid_exts: +# return jsonify({ 'error': 'not an image' }) + +# uploaded_fn = datetime.now().isoformat() + "_" + basename +# uploaded_fn = sanitize_re.sub('', uploaded_fn) +# uploaded_img_path = "static/uploaded/" + uploaded_fn + ext +# uploaded_img_path = uploaded_img_path.lower() +# print('query: {}'.format(uploaded_img_path)) + +# img = Image.open(file.stream).convert('RGB') +# # img.save(uploaded_img_path) +# # vec = db.load_feature_vector_from_file(uploaded_img_path) +# vec = fe.extract(img) +# # print(vec.shape) + +# results = db.search(vec, limit=limit) +# query = { +# 'timing': time.time() - start, +# } +# print(results) +# return jsonify({ +# 'results': results, +# }) diff --git a/megapixels/app/server/api/image.py b/megapixels/app/server/api/image.py deleted file mode 100644 index f2f4a4f9..00000000 --- a/megapixels/app/server/api/image.py +++ /dev/null @@ -1,40 +0,0 @@ -from flask import Blueprint, render_template, abort -# from jinja2 import TemplateNotFound - -router = Blueprint('image', __name__) - -@router.route('/<dataset>/test', methods=['POST']) -def test(name): - # dataset = -@router.route('/<dataset>/face', methods=['POST']) -def upload(name): - file = request.files['query_img'] - fn = file.filename - if fn.endswith('blob'): - fn = 'filename.jpg' - - basename, ext = os.path.splitext(fn) - print("got {}, type {}".format(basename, ext)) - if ext.lower() not in valid_exts: - return jsonify({ 'error': 'not an image' }) - - uploaded_fn = datetime.now().isoformat() + "_" + basename - uploaded_fn = sanitize_re.sub('', uploaded_fn) - uploaded_img_path = "static/uploaded/" + uploaded_fn + ext - uploaded_img_path = uploaded_img_path.lower() - print('query: {}'.format(uploaded_img_path)) - - img = Image.open(file.stream).convert('RGB') - # img.save(uploaded_img_path) - # vec = db.load_feature_vector_from_file(uploaded_img_path) - vec = fe.extract(img) - # print(vec.shape) - - results = db.search(vec, limit=limit) - query = { - 'timing': time.time() - start, - } - print(results) - return jsonify({ - 'results': results, - }) diff --git a/megapixels/app/server/create.py b/megapixels/app/server/create.py index 1119ee8f..9efed669 100644 --- a/megapixels/app/server/create.py +++ b/megapixels/app/server/create.py @@ -1,10 +1,8 @@ -from flask import Flask, Blueprint +from flask import Flask, Blueprint, jsonify from flask_sqlalchemy import SQLAlchemy -from app.models.sql_factory import connection_url +from app.models.sql_factory import connection_url, load_sql_datasets -from app.server.api import router as api_router - -# from app.server.views.assets import assets +from app.server.api import api db = SQLAlchemy() @@ -14,8 +12,10 @@ def create_app(script_info=None): app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False db.init_app(app) - app.register_blueprint(api) - + datasets = load_sql_datasets(replace=False, base_model=db.Model) + + app.register_blueprint(api, url_prefix='/api') + @app.route('/', methods=['GET']) def index(): return app.send_static_file('index.html') @@ -24,4 +24,13 @@ def create_app(script_info=None): def shell_context(): return { 'app': app, 'db': db } + @app.route("/site-map") + def site_map(): + links = [] + for rule in app.url_map.iter_rules(): + # url = url_for(rule.endpoint, **(rule.defaults or {})) + # print(url) + links.append((rule.endpoint)) + return(jsonify(links)) + return app |
