diff options
Diffstat (limited to 'check/app')
| -rw-r--r-- | check/app/models/sql_factory.py | 26 | ||||
| -rw-r--r-- | check/app/server/api.py | 124 | ||||
| -rw-r--r-- | check/app/server/create.py | 2 | ||||
| -rw-r--r-- | check/app/utils/file_utils.py | 10 |
4 files changed, 134 insertions, 28 deletions
diff --git a/check/app/models/sql_factory.py b/check/app/models/sql_factory.py index 1d32a68..ad27f62 100644 --- a/check/app/models/sql_factory.py +++ b/check/app/models/sql_factory.py @@ -32,24 +32,32 @@ class FileTable(Base): __tablename__ = 'files' id = Column(Integer, primary_key=True) sha256 = Column(String(64, convert_unicode=True), nullable=False) - phash = Column(BigInteger, nullable=False) + phash = Column(BigInteger, nullable=False, index=True) ext = Column(String(4, convert_unicode=True), nullable=False) + url = Column(String(255, convert_unicode=True), nullable=False) def toJSON(self): return { 'id': self.id, 'sha256': self.sha256, 'phash': self.phash, 'ext': self.ext, + 'url': self.url, } Base.metadata.create_all(engine) -def search_by_phash(phash, threshold=6): +def search_by_phash(phash, threshold=6, limit=1): """Search files for a particular phash""" connection = engine.connect() - cmd = 'SELECT files.*, BIT_COUNT(phash ^ :phash) as hamming_distance FROM files HAVING hamming_distance < :threshold ORDER BY hamming_distance ASC LIMIT 1' - matches = connection.execute(text(cmd), phash=phash, threshold=threshold).fetchall() - keys = ('id', 'sha256', 'phash', 'ext', 'score') + cmd = """ + SELECT files.*, BIT_COUNT(phash ^ :phash) + AS hamming_distance FROM files + HAVING hamming_distance < :threshold + ORDER BY hamming_distance ASC + LIMIT :limit + """ + matches = connection.execute(text(cmd), phash=phash, threshold=threshold, limit=limit).fetchall() + keys = ('id', 'sha256', 'phash', 'ext', 'url', 'score') results = [ dict(zip(keys, values)) for values in matches ] return results @@ -58,11 +66,9 @@ def search_by_hash(hash): match = session.query(FileTable).filter(FileTable.sha256 == hash) return match.first() -def add_phash(sha256, phash, ext): +def add_phash(sha256=None, phash=None, ext=None, url=None): """Add a file to the table""" - rec = FileTable( - sha256=sha256, phash=phash, ext=ext, - ) + rec = FileTable(sha256=sha256, phash=phash, ext=ext, url=url) session = Session() session.add(rec) session.commit() @@ -87,4 +93,4 @@ def add_phash_by_filename(path): hash = sha256(path) - add_phash(sha256=hash, phash=phash, ext=ext) + add_phash(sha256=hash, phash=phash, ext=ext, url=path) diff --git a/check/app/server/api.py b/check/app/server/api.py index c4f9f80..66a0dd1 100644 --- a/check/app/server/api.py +++ b/check/app/server/api.py @@ -1,17 +1,25 @@ +import io import os import re import time import numpy as np +import logging +import urllib.request from flask import Blueprint, request, jsonify from PIL import Image from app.models.sql_factory import search_by_phash, add_phash -from app.utils.im_utils import pil2np +from app.utils.im_utils import compute_phash_int +from app.utils.file_utils import sha256_stream sanitize_re = re.compile('[\W]+') valid_exts = ['.gif', '.jpg', '.jpeg', '.png'] -LIMIT = 9 +MATCH_THRESHOLD = 20 +MATCH_LIMIT = 10 + +SIMILAR_THRESHOLD = 20 +SIMILAR_LIMIT = 10 api = Blueprint('api', __name__) @@ -22,29 +30,111 @@ def index(): """ return jsonify({ 'status': 'ok' }) -@api.route('/v1/match/', methods=['POST']) -def upload(): +def get_params(default_threshold=MATCH_THRESHOLD, default_limit=MATCH_LIMIT): + try: + threshold = int(request.form.get('threshold') or default_threshold) + limit = int(request.form.get('limit') or default_limit) + except: + return jsonify({ + 'success': False, + 'match': False, + 'error': 'param_error' + }) + + if 'q' in request.files: + file = request.files['q'] + fn = file.filename + if fn.endswith('blob'): # FIX PNG IMAGES? + logging.debug('received a blob, assuming JPEG') + fn = 'filename.jpg' + + basename, ext = os.path.splitext(fn) + if ext.lower() not in valid_exts: + return jsonify({ + 'success': False, + 'match': False, + 'error': 'not_an_image' + }) + + raw = None + im = Image.open(file.stream).convert('RGB') + else: + url = request.form.get('url') + if not url: + return jsonify({ + 'success': False, + 'match': False, + 'error': 'no_image' + }) + basename, ext = os.path.splitext(url) + if ext.lower() not in valid_exts: + return jsonify({ + 'success': False, + 'match': False, + 'error': 'not_an_image' + }) + + remote_request = urllib.request.Request(url) + remote_response = urllib.request.urlopen(remote_request) + raw = remote_response.read() + im = Image.open(io.BytesIO(raw)).convert('RGB') + +@api.route('/v1/match', methods=['POST']) +def match(): """ Search by uploading an image """ start = time.time() - file = request.files['query_img'] - fn = file.filename - if fn.endswith('blob'): # FIX PNG IMAGES? - fn = 'filename.jpg' + threshold, limit, raw, im = get_params() - basename, ext = os.path.splitext(fn) - if ext.lower() not in valid_exts: - return jsonify({ - 'error': 'not_an_image' - }) + phash = compute_phash_int(im) + ext = ext[1:].lower() + + results = search_by_phash(phash=phash, threshold=threshold, limit=limit) + + if len(results) == 0: + if url: + # hash = sha256_stream(file) + hash = sha256_stream(io.BytesIO(raw)) + add_phash(sha256=hash, phash=phash, ext=ext, url=url) + match = False + else: + match = True + + logging.debug('query took {0:.2g} s.'.format(time.time() - start)) + + return jsonify({ + 'success': True, + 'match': match, + 'results': results, + 'timing': time.time() - start, + }) + +@api.route('/v1/similar', methods=['POST']) +def similar(): + """ + Search by uploading an image + """ + start = time.time() + + threshold, limit, raw, im = get_params(default_threshold=SIMILARITY_THRESHOLD, default_limit=SIMILARITY_LIMIT) - im = Image.open(file.stream).convert('RGB') phash = compute_phash_int(im) + ext = ext[1:].lower() + + results = search_by_phash(phash=phash, threshold=threshold, limit=limit) - threshold = request.args.get('threshold') || 6 + if len(results) == 0: + match = False + else: + match = True - res = search_by_phash(phash, threshold) + logging.debug('query took {0:.2g} s.'.format(time.time() - start)) - return jsonify({ 'res': res }) + return jsonify({ + 'success': True, + 'match': match, + 'results': results, + 'timing': time.time() - start, + }) diff --git a/check/app/server/create.py b/check/app/server/create.py index 788bfaa..6dab331 100644 --- a/check/app/server/create.py +++ b/check/app/server/create.py @@ -27,7 +27,7 @@ def create_app(script_info=None): """ functional pattern for creating the flask app """ - app = Flask(__name__, static_folder='static', static_url_path='') + app = Flask(__name__, static_folder='static', static_url_path='/static') app.config['SQLALCHEMY_DATABASE_URI'] = connection_url app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['CELERY_BROKER_URL'] = cfg.CELERY_BROKER_URL diff --git a/check/app/utils/file_utils.py b/check/app/utils/file_utils.py index 1ed1833..a185cf4 100644 --- a/check/app/utils/file_utils.py +++ b/check/app/utils/file_utils.py @@ -352,6 +352,16 @@ def sha256(fp_in, block_size=65536): sha256.update(block) return sha256.hexdigest() +def sha256_stream(stream, block_size=65536): + """Generates SHA256 hash for a file stream (from Flask) + :param fp_in: (FileStream) stream object + :param block_size: (int) byte size of block + :returns: (str) hash + """ + sha256 = hashlib.sha256() + for block in iter(lambda: stream.read(block_size), b''): + sha256.update(block) + return sha256.hexdigest() def sha256_tree(sha256): """Split hash into branches with tree-depth for faster file indexing |
