1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
import glob
import os
import time
from urllib.parse import quote

import pandas as pd
from PIL import Image
from sqlalchemy import create_engine, Table, Column, String, Integer, BigInteger, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from app.settings import app_cfg as cfg
from app.settings.types import VALID_IMAGE_EXTENSIONS
from app.utils.im_utils import compute_phash_int
from app.utils.file_utils import sha256
# Build the MySQL connection URL from environment variables. The password is
# percent-encoded so characters such as "@", ":" or "/" cannot break URL
# parsing; SQLAlchemy decodes it again when it parses the URL. An unset
# DB_PASS is treated as an empty password.
connection_url = "mysql+mysqlconnector://{}:{}@{}/{}?charset=utf8mb4".format(
    os.getenv("DB_USER"),
    quote(os.getenv("DB_PASS", ""), safe=""),
    os.getenv("DB_HOST"),
    os.getenv("DB_NAME"),
)
# NOTE(review): not read in this section; presumably used elsewhere.
loaded = False
# encoding="utf-8" is omitted: it is SQLAlchemy's default and the parameter
# does not exist in SQLAlchemy 2.0. pool_recycle=3600 makes the pool replace
# connections older than one hour (presumably to stay under MySQL's
# wait_timeout — confirm).
engine = create_engine(connection_url, pool_recycle=3600)
# Session factory bound to the engine; each call opens a new ORM session.
Session = sessionmaker(bind=engine)
# Declarative base for the ORM models defined in this module.
Base = declarative_base()
class FileTable(Base):
    """Table for storing various hashes of images.

    Each row describes one file: its SHA-256 hex digest (unique), its
    perceptual hash stored as a BIGINT (indexed for lookups), its file
    extension, and the URL/path it was registered under.
    """
    __tablename__ = 'files'

    # convert_unicode=True is no longer passed to String: it has no effect
    # when the DBAPI already returns str (as Python 3 drivers do), was
    # deprecated in SQLAlchemy 1.3 and is not accepted by SQLAlchemy 2.0.
    id = Column(Integer, primary_key=True)
    sha256 = Column(String(64), nullable=False, unique=True)
    # NOTE(review): BIGINT is signed in MySQL; a phash with the top bit set
    # (>= 2**63) will not fit unless compute_phash_int returns signed
    # values — confirm.
    phash = Column(BigInteger, nullable=False, index=True)
    # Extensions longer than 4 characters are rejected by the column width.
    ext = Column(String(4), nullable=False)
    url = Column(String(255), nullable=False)

    def toJSON(self):
        """Return this row's column values as a plain dict."""
        return {
            'id': self.id,
            'sha256': self.sha256,
            'phash': self.phash,
            'ext': self.ext,
            'url': self.url,
        }

    def __repr__(self):
        return "{}(id={!r}, sha256={!r})".format(
            type(self).__name__, self.id, self.sha256
        )
# Create any tables declared on Base (currently `files`) that do not already
# exist; existing tables are left untouched.
Base.metadata.create_all(bind=engine, checkfirst=True)
def search_by_phash(phash, threshold=6, limit=1):
    """Search files for rows whose phash is close to ``phash``.

    Closeness is the Hamming distance between the stored hash and the query,
    computed in MySQL as ``BIT_COUNT(phash ^ :phash)``.

    Args:
        phash: integer perceptual hash to compare against.
        threshold: only rows whose Hamming distance is strictly less than
            this value are returned.
        limit: maximum number of rows to return.

    Returns:
        A list of dicts with keys ``id``, ``sha256``, ``phash``, ``ext``,
        ``url`` and ``score`` (the Hamming distance), closest match first.
    """
    # Columns are listed explicitly (not ``files.*``) so each value lines up
    # with ``keys`` regardless of the table's physical column order.
    cmd = """
        SELECT id, sha256, phash, ext, url,
               BIT_COUNT(phash ^ :phash) AS hamming_distance
        FROM files
        HAVING hamming_distance < :threshold
        ORDER BY hamming_distance ASC
        LIMIT :limit
    """
    # A parameter dict works on SQLAlchemy 1.3 through 2.0; keyword-argument
    # parameters to execute() were removed in 2.0.
    params = {"phash": phash, "threshold": threshold, "limit": limit}
    keys = ('id', 'sha256', 'phash', 'ext', 'url', 'score')
    # The with-block returns the connection to the pool even if the query
    # raises.
    with engine.connect() as connection:
        matches = connection.execute(text(cmd), params).fetchall()
    return [dict(zip(keys, row)) for row in matches]
def search_by_hash(hash):
    """Return the FileTable row whose sha256 equals ``hash``, or None.

    Args:
        hash: SHA-256 hex digest to look up.

    Returns:
        The matching FileTable instance (detached from its session), or None
        when no row matches.
    """
    session = Session()
    try:
        return session.query(FileTable).filter(FileTable.sha256 == hash).first()
    finally:
        # Release the session and its connection. FileTable has only plain
        # column attributes, all loaded by the query, so they stay readable
        # after close() detaches the returned instance.
        session.close()
def add_phash(sha256=None, phash=None, ext=None, url=None):
    """Add a file to the table.

    Args:
        sha256: SHA-256 hex digest of the file (must be unique).
        phash: integer perceptual hash.
        ext: file extension without the leading dot.
        url: location the file was registered under.

    Raises:
        Any database error from the insert (for example a duplicate sha256,
        which the unique constraint rejects, or a None value for one of the
        non-nullable columns) is re-raised after the transaction is rolled
        back.
    """
    rec = FileTable(sha256=sha256, phash=phash, ext=ext, url=url)
    session = Session()
    try:
        session.add(rec)
        # commit() flushes pending changes itself, so no separate flush().
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        # Return the connection to the pool; previously the session was never
        # closed.
        session.close()
def add_phash_by_filename(path):
    """Add a file by filename, computing all the attributes the table needs.

    Prints the path and returns without inserting when ``path`` is not a
    regular file or its extension is not in VALID_IMAGE_EXTENSIONS.
    Otherwise it computes the perceptual hash of the RGB image and the file's
    SHA-256, and stores them with the extension and ``path`` as the url.

    Args:
        path: filesystem path of the image to register.
    """
    print(path)
    # isfile() rather than exists(): a directory would pass exists() and then
    # fail inside Image.open().
    if not os.path.isfile(path):
        print("File does not exist")
        return
    ext = os.path.splitext(path)[1].strip('.')
    # NOTE(review): this check is case-sensitive, so "photo.JPG" is rejected
    # unless VALID_IMAGE_EXTENSIONS contains upper-case forms — confirm.
    if ext not in VALID_IMAGE_EXTENSIONS:
        print("Not an image file")
        return
    # The with-block closes the underlying file once the pixels have been
    # converted; previously the handle was never closed.
    with Image.open(path) as im:
        phash = compute_phash_int(im.convert('RGB'))
    digest = sha256(path)
    add_phash(sha256=digest, phash=phash, ext=ext, url=path)
|