From 86e34a1bc43d1995e20c52fa639412c46105d400 Mon Sep 17 00:00:00 2001
From: Jules Laplace
Date: Mon, 29 Apr 2019 01:36:27 +0200
Subject: import csv

---
 check/app/server/api.py            | 28 +++++++++++-------
 check/app/utils/file_utils.py      |  1 -
 check/app/utils/process_utils.py   | 70 ++++++++++++++++++++++++++++++++++++++
 check/commands/phash/dedupe.py     |  5 +++-
 check/commands/phash/import_csv.py | 54 ++++++++++++++++++++++++++++++++++
 5 files changed, 145 insertions(+), 13 deletions(-)
 create mode 100644 check/app/utils/process_utils.py
 create mode 100644 check/commands/phash/import_csv.py

diff --git a/check/app/server/api.py b/check/app/server/api.py
index 324a63a..184e28d 100644
--- a/check/app/server/api.py
+++ b/check/app/server/api.py
@@ -30,6 +30,20 @@ def index():
     """
     return jsonify({ 'status': 'ok' })
 
+def fetch_url(url):
+    if not url:
+        return None, 'no_image'
+    basename, ext = os.path.splitext(url)
+    if ext.lower() not in valid_exts:
+        return None, 'not_an_image'
+    ext = ext[1:].lower()
+
+    remote_request = urllib.request.Request(url)
+    remote_response = urllib.request.urlopen(remote_request)
+    raw = remote_response.read()
+    im = Image.open(io.BytesIO(raw)).convert('RGB')
+    return raw, im
+
 def get_params(default_threshold=MATCH_THRESHOLD, default_limit=MATCH_LIMIT):
     try:
         threshold = int(request.form.get('threshold') or default_threshold)
@@ -58,17 +72,9 @@ def get_params(default_threshold=MATCH_THRESHOLD, default_limit=MATCH_LIMIT):
     # Fetch remote URL
     else:
         url = request.form.get('url')
-        if not url:
-            return None, 'no_image'
-        basename, ext = os.path.splitext(url)
-        if ext.lower() not in valid_exts:
-            return None, 'not_an_image'
-        ext = ext[1:].lower()
-
-        remote_request = urllib.request.Request(url)
-        remote_response = urllib.request.urlopen(remote_request)
-        raw = remote_response.read()
-        im = Image.open(io.BytesIO(raw)).convert('RGB')
+        raw, im = fetch_url(url)
+        if raw is None:
+            return raw, im # error
     return (threshold, limit, url, ext, raw, im,), None
 
 @api.route('/v1/match', methods=['POST'])
diff --git a/check/app/utils/file_utils.py b/check/app/utils/file_utils.py
index a185cf4..5c9311e 100644
--- a/check/app/utils/file_utils.py
+++ b/check/app/utils/file_utils.py
@@ -247,7 +247,6 @@ def write_csv(data, fp_out, header=None):
   for k, v in data.items():
     fp.writerow('{},{}'.format(k, v))
 
-
 def write_serialized_items(items, fp_out, ensure_path=True, minify=True, sort_keys=True):
   """Writes serialized data
   :param items: (dict) a sha256 dict of MappingItems
diff --git a/check/app/utils/process_utils.py b/check/app/utils/process_utils.py
new file mode 100644
index 0000000..7f243ae
--- /dev/null
+++ b/check/app/utils/process_utils.py
@@ -0,0 +1,70 @@
+import os
+import pathos.pools as pp
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=3):
+    """
+    A parallel version of the map function with a progress bar.
+
+    Args:
+        array (array-like): An array to iterate over.
+        function (function): A python function to apply to the elements of array
+        n_jobs (int, default=16): The number of cores to use
+        use_kwargs (boolean, default=False): Whether to consider the elements of array
+            as dictionaries of keyword arguments to function
+        front_num (int, default=3): The number of iterations to run serially before
+            kicking off the parallel job. Useful for catching bugs
+    Returns:
+        [function(array[0]), function(array[1]), ...]
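+
+    Example (an illustrative sketch, `square` being a stand-in function):
+        >>> def square(x): return x * x
+        >>> parallel_process([1, 2, 3], square, n_jobs=1)
+        [1, 4, 9]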
+ """ + #We run the first few iterations serially to catch bugs + if front_num > 0: + front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]] + #If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. + if n_jobs==1: + return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] + #Assemble the workers + with ProcessPoolExecutor(max_workers=n_jobs) as pool: + #Pass the elements of array into function + if use_kwargs: + futures = [pool.submit(function, **a) for a in array[front_num:]] + else: + futures = [pool.submit(function, a) for a in array[front_num:]] + kwargs = { + 'total': len(futures), + 'unit': 'it', + 'unit_scale': True, + 'leave': True + } + #Print out the progress as tasks complete + for f in tqdm(as_completed(futures), **kwargs): + pass + out = [] + #Get the results from the futures. + for i, future in tqdm(enumerate(futures)): + try: + out.append(future.result()) + except Exception as e: + out.append(e) + return front + out + +def parallelize(rows, func): + print("Processing {} items".format(len(rows))) + if hasattr(os, 'sched_getaffinity'): + processCount = len(os.sched_getaffinity(0)) + else: + processCount = 4 + print('processes {}'.format(processCount)) + with pp.ProcessPool(processes=processCount) as pool: + pool.map(func, rows) \ No newline at end of file diff --git a/check/commands/phash/dedupe.py b/check/commands/phash/dedupe.py index 28266f4..6b8194b 100644 --- a/check/commands/phash/dedupe.py +++ b/check/commands/phash/dedupe.py @@ -17,7 +17,7 @@ from app.utils.file_utils import write_json, sha256 help="Input glob to search -- e.g. '../docs/images/*.jpg'") @click.option('-o', '--output', 'opt_output_fn', required=False, - help="Input glob to search -- e.g. 
+def parallelize(rows, func):
+    print("Processing {} items".format(len(rows)))
+    if hasattr(os, 'sched_getaffinity'):
+        processCount = len(os.sched_getaffinity(0))
+    else:
+        processCount = 4
+    print('processes {}'.format(processCount))
+    with pp.ProcessPool(processes=processCount) as pool:
+        pool.map(func, rows)
\ No newline at end of file
diff --git a/check/commands/phash/dedupe.py b/check/commands/phash/dedupe.py
index 28266f4..6b8194b 100644
--- a/check/commands/phash/dedupe.py
+++ b/check/commands/phash/dedupe.py
@@ -17,7 +17,7 @@ from app.utils.file_utils import write_json, sha256
   help="Input glob to search -- e.g. '../docs/images/*.jpg'")
 @click.option('-o', '--output', 'opt_output_fn',
   required=False,
-  help="Input glob to search -- e.g. '../docs/images/*.jpg'")
+  help="Output filename")
 @click.option('-t', '--threshold', 'opt_threshold',
   required=True,
   default=6,
@@ -36,10 +36,13 @@ def cli(ctx, opt_input_glob, opt_output_fn, opt_threshold):
         phash = compute_phash(im)
         if is_phash_new(fn, phash, seen, opt_threshold):
             hash = sha256(fn)
+            fpart, ext = os.path.splitext(fn)
+            ext = ext[1:]
             seen.append({
                 'sha256': hash,
                 'phash': phash,
                 'fn': fn,
+                'ext': ext,
             })
     if opt_output_fn:
         write_json(seen, opt_output_fn)
diff --git a/check/commands/phash/import_csv.py b/check/commands/phash/import_csv.py
new file mode 100644
index 0000000..5e09aa8
--- /dev/null
+++ b/check/commands/phash/import_csv.py
@@ -0,0 +1,54 @@
+"""
+Import a CSV of URLs
+"""
+
+import click
+import os
+import glob
+import io
+import random
+
+from PIL import Image
+
+from app.models.sql_factory import add_phash
+from app.utils.im_utils import compute_phash_int
+from app.utils.file_utils import load_csv, sha256_stream
+from app.utils.process_utils import parallelize
+from app.server.api import fetch_url
+
+@click.command()
+@click.option('-i', '--input', 'opt_input_fn',
+  required=True,
+  help="Input path to CSV")
+@click.option('-b', '--base_href', 'opt_base_href',
+  required=False,
+  default="",
+  help="Base href, default is empty string")
+@click.option('-e', '--field', 'opt_field',
+  required=False,
+  default="address",
+  help="Field in CSV containing URL")
+@click.pass_context
+def cli(ctx, opt_input_fn, opt_base_href, opt_field):
+    """
+    Import a CSV of image URLs, hashing each fetched image
+    """
+    def add_url(url):
+        fname, ext = os.path.splitext(url)
+        if ext.lower() not in ['.gif', '.jpg', '.jpeg', '.png']:
+            return
+        ext = ext[1:]
+        try:
+            raw, im = fetch_url(url)
+        except Exception:
+            # print('404 {}'.format(url))
+            return
+        print(url)
+        phash = compute_phash_int(im)
+        hash = sha256_stream(io.BytesIO(raw))
+        add_phash(sha256=hash, phash=phash, ext=ext, url=url)
+
+    rows = load_csv(opt_input_fn)
+    urls = [opt_base_href + row[opt_field] for row in rows]
+    random.shuffle(urls)
+    parallelize(urls, add_url)
-- 
cgit v1.2.3-70-g09d2