author     Jules Laplace <julescarbon@gmail.com>   2019-04-29 01:36:27 +0200
committer  Jules Laplace <julescarbon@gmail.com>   2019-04-29 01:36:27 +0200
commit     86e34a1bc43d1995e20c52fa639412c46105d400 (patch)
tree       35623620556e6cfac2ca67e2b6f4f61cd2329e2a /check
parent     dbfaa9024c844dd5c14259c858564e32149afd87 (diff)

import csv (HEAD, master)
Diffstat (limited to 'check')
-rw-r--r--  check/app/server/api.py             28
-rw-r--r--  check/app/utils/file_utils.py        1
-rw-r--r--  check/app/utils/process_utils.py    60
-rw-r--r--  check/commands/phash/dedupe.py       5
-rw-r--r--  check/commands/phash/import_csv.py  54
5 files changed, 135 insertions(+), 13 deletions(-)
diff --git a/check/app/server/api.py b/check/app/server/api.py
index 324a63a..184e28d 100644
--- a/check/app/server/api.py
+++ b/check/app/server/api.py
@@ -30,6 +30,20 @@ def index():
"""
return jsonify({ 'status': 'ok' })
+def fetch_url(url):
+ if not url:
+ return None, 'no_image'
+ basename, ext = os.path.splitext(url)
+ if ext.lower() not in valid_exts:
+ return None, 'not_an_image'
+ ext = ext[1:].lower()
+
+ remote_request = urllib.request.Request(url)
+ remote_response = urllib.request.urlopen(remote_request)
+ raw = remote_response.read()
+ im = Image.open(io.BytesIO(raw)).convert('RGB')
+ return raw, im
+
def get_params(default_threshold=MATCH_THRESHOLD, default_limit=MATCH_LIMIT):
try:
threshold = int(request.form.get('threshold') or default_threshold)
@@ -58,17 +72,9 @@ def get_params(default_threshold=MATCH_THRESHOLD, default_limit=MATCH_LIMIT):
     # Fetch remote URL
     else:
         url = request.form.get('url')
-        if not url:
-            return None, 'no_image'
-        basename, ext = os.path.splitext(url)
-        if ext.lower() not in valid_exts:
-            return None, 'not_an_image'
-        ext = ext[1:].lower()
-
-        remote_request = urllib.request.Request(url)
-        remote_response = urllib.request.urlopen(remote_request)
-        raw = remote_response.read()
-        im = Image.open(io.BytesIO(raw)).convert('RGB')
+        raw, im = fetch_url(url)
+        if raw is None:
+            return raw, im  # im carries the error code ('no_image' or 'not_an_image')
     return (threshold, limit, url, ext, raw, im,), None
 
 @api.route('/v1/match', methods=['POST'])
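
A minimal sketch of how callers are expected to consume fetch_url's dual return
convention: (raw_bytes, PIL image) on success, (None, error_code) for a missing
or non-image URL, and a raised exception for network failures. The phash_for_url
helper below is hypothetical and only illustrates the pattern:

    from app.server.api import fetch_url
    from app.utils.im_utils import compute_phash_int

    def phash_for_url(url):
        try:
            raw, im = fetch_url(url)
        except Exception:
            return None          # network error, 404, etc.
        if raw is None:
            return None          # im holds 'no_image' or 'not_an_image'
        return compute_phash_int(im)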
diff --git a/check/app/utils/file_utils.py b/check/app/utils/file_utils.py
index a185cf4..5c9311e 100644
--- a/check/app/utils/file_utils.py
+++ b/check/app/utils/file_utils.py
@@ -247,7 +247,6 @@ def write_csv(data, fp_out, header=None):
     for k, v in data.items():
         fp.writerow('{},{}'.format(k, v))
 
-
 def write_serialized_items(items, fp_out, ensure_path=True, minify=True, sort_keys=True):
     """Writes serialized data
     :param items: (dict) a sha256 dict of MappingItems
diff --git a/check/app/utils/process_utils.py b/check/app/utils/process_utils.py
new file mode 100644
index 0000000..7f243ae
--- /dev/null
+++ b/check/app/utils/process_utils.py
@@ -0,0 +1,60 @@
+import os
+import pathos.pools as pp
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=3):
+    """
+    A parallel version of the map function with a progress bar.
+
+    Args:
+        array (array-like): An array to iterate over.
+        function (function): A python function to apply to the elements of array
+        n_jobs (int, default=16): The number of cores to use
+        use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
+            keyword arguments to function
+        front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
+            Useful for catching bugs
+    Returns:
+        [function(array[0]), function(array[1]), ...]
+    """
+    # We run the first few iterations serially to catch bugs
+    if front_num > 0:
+        front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
+    # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
+    if n_jobs == 1:
+        return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
+    # Assemble the workers
+    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
+        # Pass the elements of array into function
+        if use_kwargs:
+            futures = [pool.submit(function, **a) for a in array[front_num:]]
+        else:
+            futures = [pool.submit(function, a) for a in array[front_num:]]
+        kwargs = {
+            'total': len(futures),
+            'unit': 'it',
+            'unit_scale': True,
+            'leave': True
+        }
+        # Print out the progress as tasks complete
+        for f in tqdm(as_completed(futures), **kwargs):
+            pass
+    out = []
+    # Get the results from the futures.
+    for i, future in tqdm(enumerate(futures)):
+        try:
+            out.append(future.result())
+        except Exception as e:
+            out.append(e)
+    return front + out
+
+def parallelize(rows, func):
+    print("Processing {} items".format(len(rows)))
+    if hasattr(os, 'sched_getaffinity'):
+        processCount = len(os.sched_getaffinity(0))
+    else:
+        processCount = 4
+    print('processes {}'.format(processCount))
+    with pp.ProcessPool(processes=processCount) as pool:
+        pool.map(func, rows)
\ No newline at end of file
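
A short usage sketch of the two helpers above; the square function and the
n_jobs value are made up for illustration. parallel_process returns one result
(or the raised exception) per element, in submission order, running the first
front_num items serially to surface bugs early; parallelize maps for side
effects only and discards return values:

    from app.utils.process_utils import parallel_process, parallelize

    def square(x):
        return x * x

    if __name__ == '__main__':
        # Progress-bar map: results come back in submission order.
        results = parallel_process(list(range(100)), square, n_jobs=4, front_num=3)

        # Fire-and-forget map over all CPU cores reported by the OS.
        parallelize(list(range(100)), square)

parallelize goes through pathos, whose dill-based pickling can handle nested
functions, which is presumably why import_csv below can hand it the add_url
closure directly.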
diff --git a/check/commands/phash/dedupe.py b/check/commands/phash/dedupe.py
index 28266f4..6b8194b 100644
--- a/check/commands/phash/dedupe.py
+++ b/check/commands/phash/dedupe.py
@@ -17,7 +17,7 @@ from app.utils.file_utils import write_json, sha256
help="Input glob to search -- e.g. '../docs/images/*.jpg'")
@click.option('-o', '--output', 'opt_output_fn',
required=False,
- help="Input glob to search -- e.g. '../docs/images/*.jpg'")
+ help="Output filename")
@click.option('-t', '--threshold', 'opt_threshold',
required=True,
default=6,
@@ -36,10 +36,13 @@ def cli(ctx, opt_input_glob, opt_output_fn, opt_threshold):
         phash = compute_phash(im)
         if is_phash_new(fn, phash, seen, opt_threshold):
             hash = sha256(fn)
+            fpart, ext = os.path.splitext(fn)
+            ext = ext[1:]
             seen.append({
                 'sha256': hash,
                 'phash': phash,
                 'fn': fn,
+                'ext': ext,
             })
     if opt_output_fn:
         write_json(seen, opt_output_fn)
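
For illustration, one entry in the list that write_json now emits looks roughly
like this (all values below are made up):

    {
        'sha256': '9f86d081884c7d65...',    # sha256 of the file contents
        'phash': '83c3a1b2f0d41e07',        # perceptual hash from compute_phash
        'fn': '../docs/images/photo_001.jpg',
        'ext': 'jpg',                       # new field: extension without the dot
    }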
diff --git a/check/commands/phash/import_csv.py b/check/commands/phash/import_csv.py
new file mode 100644
index 0000000..5e09aa8
--- /dev/null
+++ b/check/commands/phash/import_csv.py
@@ -0,0 +1,54 @@
+"""
+Import a CSV of URLs
+"""
+
+import click
+import os
+import glob
+import io
+import random
+
+from PIL import Image
+
+from app.models.sql_factory import add_phash
+from app.utils.im_utils import compute_phash_int
+from app.utils.file_utils import load_csv, sha256_stream
+from app.utils.process_utils import parallelize
+from app.server.api import fetch_url
+
+@click.command()
+@click.option('-i', '--input', 'opt_input_fn',
+ required=True,
+ help="Input path to CSV")
+@click.option('-b', '--base_href', 'opt_base_href',
+ required=False,
+ default="",
+ help="Base href, default is empty string")
+@click.option('-e', '--field', 'opt_field',
+ required=False,
+ default="address",
+ help="Field in CSV containing URL")
+@click.pass_context
+def cli(ctx, opt_input_fn, opt_base_href, opt_field):
+ """
+ Import a folder of images, deduping them first
+ """
+ def add_url(url):
+ fname, ext = os.path.splitext(url)
+ if ext not in ['.gif', '.jpg', '.jpeg', '.png']:
+ return
+ ext = ext[1:]
+ try:
+ raw, im = fetch_url(url)
+ except:
+ # print('404 {}'.format(url))
+ return
+ print(url)
+ phash = compute_phash_int(im)
+ hash = sha256_stream(io.BytesIO(raw))
+ add_phash(sha256=hash, phash=phash, ext=ext, url=url)
+
+ rows = load_csv(opt_input_fn)
+ urls = [opt_base_href + row['address'] for row in rows]
+ random.shuffle(urls)
+ parallelize(urls, add_url)
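
For reference, a sketch of the input the command expects and how each row turns
into a fetched URL (the file name, base href, and paths below are made up; the
column read is whatever --field selects, 'address' by default):

    from app.utils.file_utils import load_csv

    # urls.csv (hypothetical):
    #   address
    #   /images/photo_001.jpg
    #   /images/photo_002.png
    rows = load_csv('urls.csv')   # -> [{'address': '/images/photo_001.jpg'}, ...]
    urls = ['https://example.com' + row['address'] for row in rows]
    # Each URL is then fetched, hashed (phash + sha256), and stored via add_phash,
    # with the work spread across worker processes by parallelize().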