diff options
Diffstat (limited to 'check/app')
| -rw-r--r-- | check/app/server/api.py | 28 | ||||
| -rw-r--r-- | check/app/utils/file_utils.py | 1 | ||||
| -rw-r--r-- | check/app/utils/process_utils.py | 60 |
3 files changed, 77 insertions, 12 deletions
def fetch_url(url):
    """Download a remote image URL and decode it as an RGB PIL image.

    :param url: remote image URL; its file extension must be in ``valid_exts``
    :returns: ``(raw_bytes, pil_image)`` on success, or ``(None, error_code)``
        where ``error_code`` is ``'no_image'`` (empty url) or
        ``'not_an_image'`` (extension not allowed)
    """
    if not url:
        return None, 'no_image'
    # Validate by extension only -- nothing is fetched unless the URL looks like an image.
    _, ext = os.path.splitext(url)
    if ext.lower() not in valid_exts:
        return None, 'not_an_image'
    # NOTE(review): the pre-refactor inline code in get_params bound `ext`
    # (without the dot) and returned it to the caller; this helper computes it
    # but does not return it -- verify get_params still has `ext` in scope.
    ext = ext[1:].lower()

    remote_request = urllib.request.Request(url)
    # BUGFIX: close the HTTP response deterministically instead of leaking the
    # socket until GC; urlopen's return value is a context manager.
    with urllib.request.urlopen(remote_request) as remote_response:
        raw = remote_response.read()
    im = Image.open(io.BytesIO(raw)).convert('RGB')
    return raw, im
new file mode 100644 index 0000000..7f243ae --- /dev/null +++ b/check/app/utils/process_utils.py @@ -0,0 +1,60 @@ +import os +import pathos.pools as pp +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, as_completed + +def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=3): + """ + A parallel version of the map function with a progress bar. + + Args: + array (array-like): An array to iterate over. + function (function): A python function to apply to the elements of array + n_jobs (int, default=16): The number of cores to use + use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of + keyword arguments to function + front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. + Useful for catching bugs + Returns: + [function(array[0]), function(array[1]), ...] + """ + #We run the first few iterations serially to catch bugs + if front_num > 0: + front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]] + #If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. + if n_jobs==1: + return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] + #Assemble the workers + with ProcessPoolExecutor(max_workers=n_jobs) as pool: + #Pass the elements of array into function + if use_kwargs: + futures = [pool.submit(function, **a) for a in array[front_num:]] + else: + futures = [pool.submit(function, a) for a in array[front_num:]] + kwargs = { + 'total': len(futures), + 'unit': 'it', + 'unit_scale': True, + 'leave': True + } + #Print out the progress as tasks complete + for f in tqdm(as_completed(futures), **kwargs): + pass + out = [] + #Get the results from the futures. 
+ for i, future in tqdm(enumerate(futures)): + try: + out.append(future.result()) + except Exception as e: + out.append(e) + return front + out + +def parallelize(rows, func): + print("Processing {} items".format(len(rows))) + if hasattr(os, 'sched_getaffinity'): + processCount = len(os.sched_getaffinity(0)) + else: + processCount = 4 + print('processes {}'.format(processCount)) + with pp.ProcessPool(processes=processCount) as pool: + pool.map(func, rows)
\ No newline at end of file |
