summaryrefslogtreecommitdiff
path: root/check/app
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-04-25 18:29:46 +0200
committerJules Laplace <julescarbon@gmail.com>2019-04-25 18:29:46 +0200
commit4d5c3d59f32b80638d82373d33a476652520e260 (patch)
tree88edd56458963229511b54276586c236604504b6 /check/app
parent4f4df4d4e38f8ce27dc7e471359f9f644ca74092 (diff)
test API
Diffstat (limited to 'check/app')
-rw-r--r--check/app/models/sql_factory.py12
-rw-r--r--check/app/server/api.py59
-rw-r--r--check/app/utils/file_utils.py10
3 files changed, 71 insertions, 10 deletions
diff --git a/check/app/models/sql_factory.py b/check/app/models/sql_factory.py
index 1d32a68..68c2e30 100644
--- a/check/app/models/sql_factory.py
+++ b/check/app/models/sql_factory.py
@@ -44,11 +44,17 @@ class FileTable(Base):
Base.metadata.create_all(engine)
-def search_by_phash(phash, threshold=6):
+def search_by_phash(phash, threshold=6, limit=1):
"""Search files for a particular phash"""
connection = engine.connect()
- cmd = 'SELECT files.*, BIT_COUNT(phash ^ :phash) as hamming_distance FROM files HAVING hamming_distance < :threshold ORDER BY hamming_distance ASC LIMIT 1'
- matches = connection.execute(text(cmd), phash=phash, threshold=threshold).fetchall()
+ cmd = """
+ SELECT files.*, BIT_COUNT(phash ^ :phash)
+ AS hamming_distance FROM files
+ HAVING hamming_distance < :threshold
+ ORDER BY hamming_distance ASC
+ LIMIT :limit
+ """
+ matches = connection.execute(text(cmd), phash=phash, threshold=threshold, limit=limit).fetchall()
keys = ('id', 'sha256', 'phash', 'ext', 'score')
results = [ dict(zip(keys, values)) for values in matches ]
return results
diff --git a/check/app/server/api.py b/check/app/server/api.py
index c4f9f80..322d899 100644
--- a/check/app/server/api.py
+++ b/check/app/server/api.py
@@ -2,11 +2,13 @@ import os
import re
import time
import numpy as np
+import logging
from flask import Blueprint, request, jsonify
from PIL import Image
from app.models.sql_factory import search_by_phash, add_phash
-from app.utils.im_utils import pil2np
+from app.utils.im_utils import compute_phash_int
+from app.utils.file_utils import sha256_stream
sanitize_re = re.compile('[\W]+')
valid_exts = ['.gif', '.jpg', '.jpeg', '.png']
@@ -22,29 +24,72 @@ def index():
"""
return jsonify({ 'status': 'ok' })
-@api.route('/v1/match/', methods=['POST'])
-def upload():
+@api.route('/v1/match', methods=['POST'])
+def match():
"""
Search by uploading an image
"""
start = time.time()
+ logging.debug(start)
- file = request.files['query_img']
+ file = request.files['q']
fn = file.filename
if fn.endswith('blob'): # FIX PNG IMAGES?
fn = 'filename.jpg'
+ logging.debug(fn)
basename, ext = os.path.splitext(fn)
if ext.lower() not in valid_exts:
return jsonify({
+ 'success': False,
+ 'match': False,
'error': 'not_an_image'
})
+ ext = ext[1:].lower()
+
im = Image.open(file.stream).convert('RGB')
phash = compute_phash_int(im)
- threshold = request.args.get('threshold') || 6
+ logging.debug(phash)
+ try:
+ threshold = int(request.args.get('threshold') or 6)
+ limit = int(request.args.get('limit') or 1)
+ add = str(request.args.get('add') or 'true') == 'true'
+ except:
+ return jsonify({
+ 'success': False,
+ 'match': False,
+ 'error': 'param_error'
+ })
+
+ results = search_by_phash(phash=phash, threshold=threshold, limit=limit)
- res = search_by_phash(phash, threshold)
+ if len(results) == 0:
+ if add:
+ hash = sha256_stream(file)
+ add_phash(sha256=hash, phash=phash, ext=ext)
+ if limit == 1:
+ return jsonify({
+ 'success': True,
+ 'match': False,
+ })
+ else:
+ return jsonify({
+ 'success': True,
+ 'match': False,
+ 'results': [],
+ })
+
+ if limit > 1:
+ return jsonify({
+ 'success': True,
+ 'match': True,
+ 'results': results,
+ })
- return jsonify({ 'res': res })
+ return jsonify({
+ 'success': True,
+ 'match': True,
+ 'closest_match': results[0],
+ })
diff --git a/check/app/utils/file_utils.py b/check/app/utils/file_utils.py
index 1ed1833..a185cf4 100644
--- a/check/app/utils/file_utils.py
+++ b/check/app/utils/file_utils.py
@@ -352,6 +352,16 @@ def sha256(fp_in, block_size=65536):
sha256.update(block)
return sha256.hexdigest()
+def sha256_stream(stream, block_size=65536):
+ """Generates SHA256 hash for a file stream (from Flask)
+ :param fp_in: (FileStream) stream object
+ :param block_size: (int) byte size of block
+ :returns: (str) hash
+ """
+ sha256 = hashlib.sha256()
+ for block in iter(lambda: stream.read(block_size), b''):
+ sha256.update(block)
+ return sha256.hexdigest()
def sha256_tree(sha256):
"""Split hash into branches with tree-depth for faster file indexing