summaryrefslogtreecommitdiff
path: root/check/app/models/sql_factory.py
blob: 68c2e30fc0f6275806c5016b6ae19b37b9709d7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import glob
import time
import pandas as pd

from PIL import Image

from sqlalchemy import create_engine, Table, Column, String, Integer, BigInteger, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from app.settings import app_cfg as cfg
from app.settings.types import VALID_IMAGE_EXTENSIONS

from app.utils.im_utils import compute_phash_int
from app.utils.file_utils import sha256

# SQLAlchemy URL for a MySQL database accessed through the mysql-connector
# driver. User, password, host and database name are read from the DB_USER,
# DB_PASS, DB_HOST and DB_NAME environment variables at import time.
# NOTE(review): os.getenv() returns None for an unset variable, which .format()
# renders as the literal text "None" in the URL -- confirm these variables are
# always set before this module is imported.
# NOTE(review): the password is inserted verbatim; characters such as '@' or
# '/' would need URL-escaping -- confirm the credentials never contain them.
connection_url = "mysql+mysqlconnector://{}:{}@{}/{}?charset=utf8mb4".format(
  os.getenv("DB_USER"),
  os.getenv("DB_PASS"),
  os.getenv("DB_HOST"),
  os.getenv("DB_NAME")
)

# NOTE(review): `loaded` is not read anywhere in the visible code -- confirm
# whether anything else still depends on it.
loaded = False
# Engine shared by every helper in this module. pool_recycle=3600 asks the
# pool to replace connections that are older than 3600 seconds.
engine = create_engine(connection_url, encoding="utf-8", pool_recycle=3600)
# Factory for ORM sessions bound to the shared engine.
Session = sessionmaker(bind=engine)
# Declarative base class that the ORM models below inherit from.
Base = declarative_base()

class FileTable(Base):
  """ORM model for the `files` table, which holds hashes of image files."""
  __tablename__ = 'files'

  # Integer primary key.
  id = Column(Integer, primary_key=True)
  # sha256 of the file, stored as a string of up to 64 characters.
  sha256 = Column(String(64, convert_unicode=True), nullable=False)
  # Perceptual hash, stored as a big integer.
  phash = Column(BigInteger, nullable=False)
  # File extension, up to 4 characters.
  ext = Column(String(4, convert_unicode=True), nullable=False)

  def toJSON(self):
    """Return this row as a plain dict with keys id, sha256, phash and ext."""
    field_names = ('id', 'sha256', 'phash', 'ext')
    return {name: getattr(self, name) for name in field_names}

# Create the tables declared on Base (the `files` table above) using the
# module-level engine. This runs as a side effect of importing this module.
Base.metadata.create_all(engine)

def search_by_phash(phash, threshold=6, limit=1):
  """Search files for perceptual hashes close to a particular phash.

  Closeness is the Hamming distance, computed in SQL as the bit count of
  `phash XOR stored phash`.

  Args:
    phash: integer perceptual hash to compare against the stored ones.
    threshold: only rows whose distance is strictly below this are returned.
    limit: maximum number of rows to return.

  Returns:
    A list of dicts with keys 'id', 'sha256', 'phash', 'ext' and 'score'
    (the Hamming distance), ordered from the closest match outwards.
  """
  cmd = """
    SELECT files.*, BIT_COUNT(phash ^ :phash) 
    AS hamming_distance FROM files 
    HAVING hamming_distance < :threshold 
    ORDER BY hamming_distance ASC 
    LIMIT :limit
  """
  params = {'phash': phash, 'threshold': threshold, 'limit': limit}
  # Use the connection as a context manager so it is handed back to the pool
  # even when the query raises; previously it was never closed.
  with engine.connect() as connection:
    matches = connection.execute(text(cmd), params).fetchall()
  # Column order follows the SELECT: the four table columns, then the distance.
  keys = ('id', 'sha256', 'phash', 'ext', 'score')
  return [dict(zip(keys, values)) for values in matches]

def search_by_hash(hash):
  """Look up a file record by its sha256 value.

  Args:
    hash: the sha256 string to match against the `sha256` column.

  Returns:
    The first matching FileTable row, or None when nothing matches. The row
    is detached from its session by the time it is returned; its already
    loaded column values remain readable.
  """
  session = Session()
  try:
    return session.query(FileTable).filter(FileTable.sha256 == hash).first()
  finally:
    # Release the session's connection; previously the session was left open.
    session.close()

def add_phash(sha256, phash, ext):
  """Add a file to the table.

  Args:
    sha256: sha256 string of the file.
    phash: integer perceptual hash of the image.
    ext: file extension without the leading dot.

  Raises:
    Whatever the commit raises; the transaction is rolled back first.
  """
  rec = FileTable(sha256=sha256, phash=phash, ext=ext)
  session = Session()
  try:
    session.add(rec)
    # commit() flushes pending changes itself, so no separate flush is needed.
    session.commit()
  except Exception:
    # Undo the failed transaction before letting the error reach the caller.
    session.rollback()
    raise
  finally:
    # Always return the connection to the pool.
    session.close()

def add_phash_by_filename(path):
  """Add a file by filename, getting all the necessary attributes.

  Prints the path, then computes the image's perceptual hash and the file's
  sha256 and stores them together with the extension via add_phash().

  Args:
    path: filesystem path of the image to register.

  Returns:
    None. A message is printed and nothing is stored when the path does not
    exist or its extension is not in VALID_IMAGE_EXTENSIONS.
  """
  print(path)
  if not os.path.exists(path):
    print("File does not exist")
    return

  # Extension of the final path component, without the leading dot.
  ext = os.path.splitext(os.path.basename(path))[1].strip('.')
  if ext not in VALID_IMAGE_EXTENSIONS:
    print("Not an image file")
    return

  # Close the underlying file as soon as the RGB copy has been made;
  # convert() returns a new image that no longer needs the open file.
  with Image.open(path) as source:
    im = source.convert('RGB')
  phash = compute_phash_int(im)

  file_hash = sha256(path)

  add_phash(sha256=file_hash, phash=phash, ext=ext)