summary refs log tree commit diff
path: root/check/app
diff options
context:
space:
mode:
Diffstat (limited to 'check/app')
-rw-r--r-- check/app/models/sql_factory.py 26
-rw-r--r-- check/app/server/api.py 124
-rw-r--r-- check/app/server/create.py 2
-rw-r--r-- check/app/utils/file_utils.py 10
4 files changed, 134 insertions, 28 deletions
diff --git a/check/app/models/sql_factory.py b/check/app/models/sql_factory.py
index 1d32a68..ad27f62 100644
--- a/check/app/models/sql_factory.py
+++ b/check/app/models/sql_factory.py
@@ -32,24 +32,32 @@ class FileTable(Base):
__tablename__ = 'files'
id = Column(Integer, primary_key=True)
sha256 = Column(String(64, convert_unicode=True), nullable=False)
- phash = Column(BigInteger, nullable=False)
+ phash = Column(BigInteger, nullable=False, index=True)
ext = Column(String(4, convert_unicode=True), nullable=False)
+ url = Column(String(255, convert_unicode=True), nullable=False)
def toJSON(self):
return {
'id': self.id,
'sha256': self.sha256,
'phash': self.phash,
'ext': self.ext,
+ 'url': self.url,
}
Base.metadata.create_all(engine)
-def search_by_phash(phash, threshold=6):
+def search_by_phash(phash, threshold=6, limit=1):
"""Search files for a particular phash"""
connection = engine.connect()
- cmd = 'SELECT files.*, BIT_COUNT(phash ^ :phash) as hamming_distance FROM files HAVING hamming_distance < :threshold ORDER BY hamming_distance ASC LIMIT 1'
- matches = connection.execute(text(cmd), phash=phash, threshold=threshold).fetchall()
- keys = ('id', 'sha256', 'phash', 'ext', 'score')
+ cmd = """
+ SELECT files.*, BIT_COUNT(phash ^ :phash)
+ AS hamming_distance FROM files
+ HAVING hamming_distance < :threshold
+ ORDER BY hamming_distance ASC
+ LIMIT :limit
+ """
+ matches = connection.execute(text(cmd), phash=phash, threshold=threshold, limit=limit).fetchall()
+ keys = ('id', 'sha256', 'phash', 'ext', 'url', 'score')
results = [ dict(zip(keys, values)) for values in matches ]
return results
@@ -58,11 +66,9 @@ def search_by_hash(hash):
match = session.query(FileTable).filter(FileTable.sha256 == hash)
return match.first()
-def add_phash(sha256, phash, ext):
+def add_phash(sha256=None, phash=None, ext=None, url=None):
"""Add a file to the table"""
- rec = FileTable(
- sha256=sha256, phash=phash, ext=ext,
- )
+ rec = FileTable(sha256=sha256, phash=phash, ext=ext, url=url)
session = Session()
session.add(rec)
session.commit()
@@ -87,4 +93,4 @@ def add_phash_by_filename(path):
hash = sha256(path)
- add_phash(sha256=hash, phash=phash, ext=ext)
+ add_phash(sha256=hash, phash=phash, ext=ext, url=path)
diff --git a/check/app/server/api.py b/check/app/server/api.py
index c4f9f80..66a0dd1 100644
--- a/check/app/server/api.py
+++ b/check/app/server/api.py
@@ -1,17 +1,25 @@
+import io
import os
import re
import time
import numpy as np
+import logging
+import urllib.request
from flask import Blueprint, request, jsonify
from PIL import Image
from app.models.sql_factory import search_by_phash, add_phash
-from app.utils.im_utils import pil2np
+from app.utils.im_utils import compute_phash_int
+from app.utils.file_utils import sha256_stream
sanitize_re = re.compile('[\W]+')
valid_exts = ['.gif', '.jpg', '.jpeg', '.png']
-LIMIT = 9
+MATCH_THRESHOLD = 20
+MATCH_LIMIT = 10
+
+SIMILAR_THRESHOLD = 20
+SIMILAR_LIMIT = 10
api = Blueprint('api', __name__)
@@ -22,29 +30,111 @@ def index():
"""
return jsonify({ 'status': 'ok' })
-@api.route('/v1/match/', methods=['POST'])
-def upload():
+def get_params(default_threshold=MATCH_THRESHOLD, default_limit=MATCH_LIMIT):
+ try:
+ threshold = int(request.form.get('threshold') or default_threshold)
+ limit = int(request.form.get('limit') or default_limit)
+ except:
+ return jsonify({
+ 'success': False,
+ 'match': False,
+ 'error': 'param_error'
+ })
+
+ if 'q' in request.files:
+ file = request.files['q']
+ fn = file.filename
+ if fn.endswith('blob'): # FIX PNG IMAGES?
+ logging.debug('received a blob, assuming JPEG')
+ fn = 'filename.jpg'
+
+ basename, ext = os.path.splitext(fn)
+ if ext.lower() not in valid_exts:
+ return jsonify({
+ 'success': False,
+ 'match': False,
+ 'error': 'not_an_image'
+ })
+
+ raw = None
+ im = Image.open(file.stream).convert('RGB')
+ else:
+ url = request.form.get('url')
+ if not url:
+ return jsonify({
+ 'success': False,
+ 'match': False,
+ 'error': 'no_image'
+ })
+ basename, ext = os.path.splitext(url)
+ if ext.lower() not in valid_exts:
+ return jsonify({
+ 'success': False,
+ 'match': False,
+ 'error': 'not_an_image'
+ })
+
+ remote_request = urllib.request.Request(url)
+ remote_response = urllib.request.urlopen(remote_request)
+ raw = remote_response.read()
+ im = Image.open(io.BytesIO(raw)).convert('RGB')
+
+@api.route('/v1/match', methods=['POST'])
+def match():
"""
Search by uploading an image
"""
start = time.time()
- file = request.files['query_img']
- fn = file.filename
- if fn.endswith('blob'): # FIX PNG IMAGES?
- fn = 'filename.jpg'
+ threshold, limit, raw, im = get_params()
- basename, ext = os.path.splitext(fn)
- if ext.lower() not in valid_exts:
- return jsonify({
- 'error': 'not_an_image'
- })
+ phash = compute_phash_int(im)
+ ext = ext[1:].lower()
+
+ results = search_by_phash(phash=phash, threshold=threshold, limit=limit)
+
+ if len(results) == 0:
+ if url:
+ # hash = sha256_stream(file)
+ hash = sha256_stream(io.BytesIO(raw))
+ add_phash(sha256=hash, phash=phash, ext=ext, url=url)
+ match = False
+ else:
+ match = True
+
+ logging.debug('query took {0:.2g} s.'.format(time.time() - start))
+
+ return jsonify({
+ 'success': True,
+ 'match': match,
+ 'results': results,
+ 'timing': time.time() - start,
+ })
+
+@api.route('/v1/similar', methods=['POST'])
+def similar():
+ """
+ Search by uploading an image
+ """
+ start = time.time()
+
+ threshold, limit, raw, im = get_params(default_threshold=SIMILARITY_THRESHOLD, default_limit=SIMILARITY_LIMIT)
- im = Image.open(file.stream).convert('RGB')
phash = compute_phash_int(im)
+ ext = ext[1:].lower()
+
+ results = search_by_phash(phash=phash, threshold=threshold, limit=limit)
- threshold = request.args.get('threshold') || 6
+ if len(results) == 0:
+ match = False
+ else:
+ match = True
- res = search_by_phash(phash, threshold)
+ logging.debug('query took {0:.2g} s.'.format(time.time() - start))
- return jsonify({ 'res': res })
+ return jsonify({
+ 'success': True,
+ 'match': match,
+ 'results': results,
+ 'timing': time.time() - start,
+ })
diff --git a/check/app/server/create.py b/check/app/server/create.py
index 788bfaa..6dab331 100644
--- a/check/app/server/create.py
+++ b/check/app/server/create.py
@@ -27,7 +27,7 @@ def create_app(script_info=None):
"""
functional pattern for creating the flask app
"""
- app = Flask(__name__, static_folder='static', static_url_path='')
+ app = Flask(__name__, static_folder='static', static_url_path='/static')
app.config['SQLALCHEMY_DATABASE_URI'] = connection_url
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['CELERY_BROKER_URL'] = cfg.CELERY_BROKER_URL
diff --git a/check/app/utils/file_utils.py b/check/app/utils/file_utils.py
index 1ed1833..a185cf4 100644
--- a/check/app/utils/file_utils.py
+++ b/check/app/utils/file_utils.py
@@ -352,6 +352,16 @@ def sha256(fp_in, block_size=65536):
sha256.update(block)
return sha256.hexdigest()
+def sha256_stream(stream, block_size=65536):
+ """Generates SHA256 hash for a file stream (from Flask)
+ :param fp_in: (FileStream) stream object
+ :param block_size: (int) byte size of block
+ :returns: (str) hash
+ """
+ sha256 = hashlib.sha256()
+ for block in iter(lambda: stream.read(block_size), b''):
+ sha256.update(block)
+ return sha256.hexdigest()
def sha256_tree(sha256):
"""Split hash into branches with tree-depth for faster file indexing