diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-04-25 18:29:46 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-04-25 18:29:46 +0200 |
| commit | 4d5c3d59f32b80638d82373d33a476652520e260 (patch) | |
| tree | 88edd56458963229511b54276586c236604504b6 /check/app | |
| parent | 4f4df4d4e38f8ce27dc7e471359f9f644ca74092 (diff) | |
test API
Diffstat (limited to 'check/app')
| -rw-r--r-- | check/app/models/sql_factory.py | 12 | ||||
| -rw-r--r-- | check/app/server/api.py | 59 | ||||
| -rw-r--r-- | check/app/utils/file_utils.py | 10 |
3 files changed, 71 insertions, 10 deletions
diff --git a/check/app/models/sql_factory.py b/check/app/models/sql_factory.py index 1d32a68..68c2e30 100644 --- a/check/app/models/sql_factory.py +++ b/check/app/models/sql_factory.py @@ -44,11 +44,17 @@ class FileTable(Base): Base.metadata.create_all(engine) -def search_by_phash(phash, threshold=6): +def search_by_phash(phash, threshold=6, limit=1): """Search files for a particular phash""" connection = engine.connect() - cmd = 'SELECT files.*, BIT_COUNT(phash ^ :phash) as hamming_distance FROM files HAVING hamming_distance < :threshold ORDER BY hamming_distance ASC LIMIT 1' - matches = connection.execute(text(cmd), phash=phash, threshold=threshold).fetchall() + cmd = """ + SELECT files.*, BIT_COUNT(phash ^ :phash) + AS hamming_distance FROM files + HAVING hamming_distance < :threshold + ORDER BY hamming_distance ASC + LIMIT :limit + """ + matches = connection.execute(text(cmd), phash=phash, threshold=threshold, limit=limit).fetchall() keys = ('id', 'sha256', 'phash', 'ext', 'score') results = [ dict(zip(keys, values)) for values in matches ] return results diff --git a/check/app/server/api.py b/check/app/server/api.py index c4f9f80..322d899 100644 --- a/check/app/server/api.py +++ b/check/app/server/api.py @@ -2,11 +2,13 @@ import os import re import time import numpy as np +import logging from flask import Blueprint, request, jsonify from PIL import Image from app.models.sql_factory import search_by_phash, add_phash -from app.utils.im_utils import pil2np +from app.utils.im_utils import compute_phash_int +from app.utils.file_utils import sha256_stream sanitize_re = re.compile('[\W]+') valid_exts = ['.gif', '.jpg', '.jpeg', '.png'] @@ -22,29 +24,72 @@ def index(): """ return jsonify({ 'status': 'ok' }) -@api.route('/v1/match/', methods=['POST']) -def upload(): +@api.route('/v1/match', methods=['POST']) +def match(): """ Search by uploading an image """ start = time.time() + logging.debug(start) - file = request.files['query_img'] + file = request.files['q'] fn = file.filename if fn.endswith('blob'): # FIX PNG IMAGES? fn = 'filename.jpg' + logging.debug(fn) basename, ext = os.path.splitext(fn) if ext.lower() not in valid_exts: return jsonify({ + 'success': False, + 'match': False, 'error': 'not_an_image' }) + ext = ext[1:].lower() + im = Image.open(file.stream).convert('RGB') phash = compute_phash_int(im) - threshold = request.args.get('threshold') || 6 + logging.debug(phash) + try: + threshold = int(request.args.get('threshold') or 6) + limit = int(request.args.get('limit') or 1) + add = str(request.args.get('add') or 'true') == 'true' + except: + return jsonify({ + 'success': False, + 'match': False, + 'error': 'param_error' + }) + + results = search_by_phash(phash=phash, threshold=threshold, limit=limit) - res = search_by_phash(phash, threshold) + if len(results) == 0: + if add: + hash = sha256_stream(file) + add_phash(sha256=hash, phash=phash, ext=ext) + if limit == 1: + return jsonify({ + 'success': True, + 'match': False, + }) + else: + return jsonify({ + 'success': True, + 'match': False, + 'results': [], + }) + + if limit > 1: + return jsonify({ + 'success': True, + 'match': True, + 'results': results, + }) - return jsonify({ 'res': res }) + return jsonify({ + 'success': True, + 'match': True, + 'closest_match': results[0], + }) diff --git a/check/app/utils/file_utils.py b/check/app/utils/file_utils.py index 1ed1833..a185cf4 100644 --- a/check/app/utils/file_utils.py +++ b/check/app/utils/file_utils.py @@ -352,6 +352,16 @@ def sha256(fp_in, block_size=65536): sha256.update(block) return sha256.hexdigest() +def sha256_stream(stream, block_size=65536): + """Generates SHA256 hash for a file stream (from Flask) + :param fp_in: (FileStream) stream object + :param block_size: (int) byte size of block + :returns: (str) hash + """ + sha256 = hashlib.sha256() + for block in iter(lambda: stream.read(block_size), b''): + sha256.update(block) + return sha256.hexdigest() def sha256_tree(sha256): """Split hash into branches with tree-depth for faster file indexing |
