diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2020-04-07 19:07:26 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-04-07 19:07:26 +0200 |
| commit | 762c2b723782f3859a0e85fcc3810058ca12e26e (patch) | |
| tree | f352e617b4dea4d4152142b4897e3390479a9a82 | |
| parent | 9cb1b1bded44273f7abf2ae3d4950e03046f3579 (diff) | |
rewrite solver, descend both trees at once until you find where they meet
| -rw-r--r-- | cli/app/thesaurus/api.py | 2 | ||||
| -rw-r--r-- | cli/commands/bridge/naive.py | 255 | ||||
| -rw-r--r-- | cli/commands/bridge/words.py | 373 |
3 files changed, 449 insertions, 181 deletions
diff --git a/cli/app/thesaurus/api.py b/cli/app/thesaurus/api.py index 467d5fb..89e5ad1 100644 --- a/cli/app/thesaurus/api.py +++ b/cli/app/thesaurus/api.py @@ -24,7 +24,7 @@ class Thesaurus: try: data = api_fn(word) except Exception as e: - print("Got HTTP error, sleeping for 5 seconds") + print("Got HTTP error for {word}, sleeping for 5 seconds") time.sleep(5) pass write_json(path, data) diff --git a/cli/commands/bridge/naive.py b/cli/commands/bridge/naive.py new file mode 100644 index 0000000..4014279 --- /dev/null +++ b/cli/commands/bridge/naive.py @@ -0,0 +1,255 @@ +""" +Find connections between two words +""" + +import click +import random +import simplejson as json +from tqdm import tqdm + +from app.thesaurus.api import Thesaurus + +@click.command() +@click.option('-a', '--a', 'opt_word_a', required=True, + help='Starting word') +@click.option('-b', '--b', 'opt_word_b', required=True, + help='Ending word') +@click.option('-oe', '--include_oe', 'opt_include_oe', is_flag=True, + help='Whether to include OE/archaic words') +@click.option('-sl', '--include_slang', 'opt_include_slang', is_flag=True, + help='Whether to include slang/colloquial words') +@click.option('-w', '--words_per_step', 'opt_words_per_step', default=20, + help='Number of words to check per step') +@click.option('-c', '--categories_per_word', 'opt_categories_per_word', default=3, + help='Number of categories to check per word') +@click.pass_context +def cli(ctx, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word): + """Find connections between two words + """ + thesaurus = Thesaurus() + print(f"Starting word: {opt_word_a}") + print(f"Ending word: {opt_word_b}") + + categories = thesaurus.search(opt_word_b)['categories'] + initial_tree = {} + initial_tree[opt_word_a] = 0 + initial_tree[opt_word_b] = 999 + target_word_count = 0 + # for cat in categories: + # catid = cat['catid'] + # initial_tree[catid] = 998 + # category_result = thesaurus.category(catid) + # print(initial_tree) + print(f"Potential target words: {target_word_count}") + + step = 0 + max_dist = 0 + marked = initial_tree.copy() + queue = [opt_word_a] + newqueue = [] + skip = {} + found = False + should_reset = False + + def reset(): + marked = initial_tree.copy() + queue = [opt_word_a] + newqueue = [] + + # First compute distance to each node to find a path + while len(queue): + step = step + 1 + print("") + print(f"Iteration step {step}, depth {max_dist}, {len(queue) + len(newqueue)} items in queue") + print(f"Words: {', '.join(queue[:7])} ...") + print("") + if step > 1: + print_chain(thesaurus, opt_word_a, queue, marked, skip, prompt_to_remove=False) + for word_q in tqdm(queue): + word_result = thesaurus.search(word_q) + # print(json.dumps(word_result, indent=2)) + categories = word_result['categories'] + if step > 1 and len(categories) > opt_categories_per_word: + categories = categories[:opt_categories_per_word] + for cat in categories: + catid = cat['catid'] + if catid in marked: + if marked[catid] > 990: + should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True) + if should_reset: + reset() + break + continue + if word_q in skip and catid in skip[word_q]: + continue + marked[catid] = marked[word_q] + 1 + category_result = thesaurus.category(catid) + # print(json.dumps(category_result, indent=2)) + for word_c in category_result['words']: + word_n = fix_word(word_c['word']) + years = word_c['years'].lower() + if not opt_include_oe and (('oe' in years and 'oe-' not in years) or 'arch' in years): + continue + if not opt_include_slang and 'slang' in years or 'colloq' in years or 'Scots' in years: + continue + if word_n in marked: + if marked[word_n] > 990: + print(f"Found {word_n} in {catid}") + should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True) + if should_reset: + reset() + break + continue + if catid in skip and word_n in skip[catid]: + continue + marked[word_n] = marked[catid] + 1 + max_dist = max(max_dist, marked[word_n]) + if word_n == opt_word_b: + thesaurus.search(word_n) + # print(queue) + should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True) + if should_reset: + reset() + break + else: + newqueue.append(word_n) + if should_reset: + break + if should_reset: + break + if should_reset: + should_reset = False + continue + if step > 1 and len(newqueue) > opt_words_per_step: + random.shuffle(newqueue) + queue = newqueue[:opt_words_per_step] + newqueue = newqueue[opt_words_per_step:] + else: + queue = [] + newqueue + if not found: + print(f"No path found, step {step} reached, {len(marked)} nodes checked") + return + +def fix_word(word_n): + if '<' in word_n or '/' in word_n or ',' in word_n: + word_n = word_n.split("<")[0] + word_n = word_n.split(",")[0] + word_n = word_n.split("/")[0] + return word_n.strip() + +def print_chain(thesaurus, opt_word_a, opt_words_b, marked, skip, prompt_to_remove=False): + """Follow the chain of shortest distance from the end back to the start""" + # print(opt_word_a) + if prompt_to_remove: + print("") + print("") + print("--------------- PATH FOUND ---------------") + print("") + word_n = opt_words_b[0] + dist = marked[word_n] + tries = 0 + depth_tries = 0 + chain = [] + skip_here = [] + cat_reverse = {} + while word_n != opt_word_a: + if tries > len(opt_words_b): + print(f"tries: {tries}, targets: {len(opt_words_b)}") + print("Too many tries to produce a chain...") + return False + next_word = "" + break_loop = False + word_result = thesaurus.search(word_n) + if word_n not in chain: + chain.append(word_n) + else: + depth_tries = 100 + # print(word_result['word']) + for cat in word_result['categories']: + catid = cat['catid'] + if (word_n in skip and catid in skip[word_n]) or catid in skip_here: + continue + if catid in marked and marked[catid] < dist: + dist = marked[catid] + # print(f"{dist}: {catid}") + cat_result = thesaurus.category(catid) + for word_c in cat_result['words']: + word_m = word_c['word'] + if '<' in word_m or '/' in word_m or ',' in word_m: + word_m = word_m.split("<")[0] + word_n = word_n.split(",")[0] + word_m = word_m.split("/")[0].strip() + # print(word_m) + if (catid in skip and word_m in skip[catid]) or word_m in skip_here: + continue + if word_m == opt_word_a or (word_m in marked and marked[word_m] < dist): + dist = marked[word_m] + # print(f"{dist}: {word_m}") + next_word = word_m + break_loop = True + break + if break_loop: + cat_name = cat_result['category'] + cat_reverse[cat_name] = catid + chain.append(cat_name) + break + if next_word == '': + if depth_tries < 100 and len(chain) > 2: + to_skip = chain[-1] + if to_skip in cat_reverse: + to_skip = cat_reverse[to_skip] + skip_here.append(to_skip) + to_skip = chain[-2] + if to_skip in cat_reverse: + to_skip = cat_reverse[to_skip] + skip_here.append(to_skip) + chain = chain[:-2] + word_n = chain[-1] + depth_tries += 1 + elif depth_tries >= 100: + # print(skip_here) + # print(f"{depth_tries} {tries}") + tries += 1 + if tries >= len(opt_words_b): + # print(f"tries: {tries}, targets: {len(opt_words_b)}") + # print("Too many tries to produce a chain...") + return False + word_n = opt_words_b[tries] + chain = [] + dist = marked[word_n] + depth_tries = 0 + continue + word_n = next_word + chain.append(opt_word_a) + chain = list(reversed(chain)) + for i, word in enumerate(chain): + if (i % 2) == 0: + print(f"{i+1} -> {word}") + else: + print(f"{i+1} => {word}") + if prompt_to_remove: + print("") + print("If you don't like this path, enter the IDs of words to remove separated by spaces, or Ctrl-C to exit.") + ids = input("Enter numbers > ") + ids = ids.split(" ") + for id in ids: + if len(id): + try: + id = int(id) + id -= 1 + word_a = chain[id] + # print(f"Removing {word_a}") + if word_a in cat_reverse: + word_a = cat_reverse[word_a] + word_b = chain[id-1] + # print(f"Connected upward to {word_b}") + if word_b in cat_reverse: + word_b = cat_reverse[word_b] + if word_a in skip: + skip[word_a].append(word_b) + else: + skip[word_a] = [word_b] + except Exception as e: + continue + return True + return False diff --git a/cli/commands/bridge/words.py b/cli/commands/bridge/words.py index 558f6e3..bcfd2c4 100644 --- a/cli/commands/bridge/words.py +++ b/cli/commands/bridge/words.py @@ -2,6 +2,8 @@ Find connections between two words """ +import sys +import time import click import random import simplejson as json @@ -24,193 +26,204 @@ from app.thesaurus.api import Thesaurus help='Number of categories to check per word') @click.pass_context def cli(ctx, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word): - """Find connections between two words + """ + Find connections between two words """ thesaurus = Thesaurus() + solver = TreeSolver(thesaurus, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word) print(f"Starting word: {opt_word_a}") print(f"Ending word: {opt_word_b}") - visited = set() - step = 0 - max_dist = 0 - marked = { opt_word_a: 0 } - queue = [opt_word_a] - newqueue = [] - skip = {} - found = False - should_reset = False - # First compute distance to each node to find a path - while len(queue): - step = step + 1 - print(f"Iteration step {step}, depth {max_dist}, {len(queue) + len(newqueue)} items in queue") - print(f"Words: {', '.join(queue[:7])} ...") - if step > 1: - print_chain(thesaurus, opt_word_a, queue, marked, skip, prompt_to_remove=False) - for word_q in tqdm(queue): - word_result = thesaurus.search(word_q) - # print(json.dumps(word_result, indent=2)) - categories = word_result['categories'] - if step > 1 and len(categories) > opt_categories_per_word: - categories = categories[:opt_categories_per_word] - for cat in categories: - catid = cat['catid'] - if catid in marked: + queue_a = [opt_word_a] + queue_b = [opt_word_b] + + while True: + queue_a = solver.build_tree(words=queue_a, tree=solver.tree_a, target=solver.tree_b) + if solver.should_reset: + queue_a = [ opt_word_a ] + queue_b = [ opt_word_b ] + solver.reset() + queue_b = solver.build_tree(words=queue_b, tree=solver.tree_b, target=solver.tree_a) + if solver.should_reset: + queue_a = [ opt_word_a ] + queue_b = [ opt_word_b ] + solver.reset() + print(f"[depth] {solver.max_dist} [queue a] {len(queue_a)} [queue b] {len(queue_b)} [skips] {len(solver.skips)}") + # print(solver.skips) + +class TreeSolver: + def __init__(self, thesaurus, word_a, word_b, include_oe, include_slang, words_per_step, categories_per_word): + self.thesaurus = thesaurus + self.word_a = word_a + self.word_b = word_b + self.include_oe = include_oe + self.include_slang = include_slang + self.words_per_step = words_per_step + self.categories_per_word = categories_per_word + self.skips = [] + self.max_dist = 0 + self.reset() + + def reset(self): + self.tree_a = { self.word_a: 0 } + self.tree_b = { self.word_b: 0 } + self.should_reset = False + + def build_tree(self, words=[], tree={}, target={}, depth=999): + next_queue = [] + if len(words) > self.words_per_step: + next_queue += words[self.words_per_step:] + words = words[:self.words_per_step] + for word in tqdm(words): + categories = self.thesaurus.search(word)['categories'] + count = 0 + for category in categories: + if count > self.categories_per_word: + break + catid = category['catid'] + if (word, str(catid),) in self.skips: + # print(f"Skip {word} {catid}") continue - if word_q in skip and catid in skip[word_q]: + if catid in tree: continue - marked[catid] = marked[word_q] + 1 - category_result = thesaurus.category(catid) - # print(json.dumps(category_result, indent=2)) - for word_c in category_result['words']: - word_n = word_c['word'] - years = word_c['years'].lower() - if not opt_include_oe and (('oe' in years and 'oe-' not in years) or 'arch' in years): - continue - if not opt_include_slang and 'slang' in years or 'colloq' in years or 'Scots' in years: - continue - if '<' in word_n or '/' in word_n or ',' in word_n: - word_n = word_n.split("<")[0] - word_n = word_n.split(",")[0] - word_n = word_n.split("/")[0].strip() - if word_n in marked: - continue - if catid in skip and word_n in skip[catid]: - continue - marked[word_n] = marked[catid] + 1 - max_dist = max(max_dist, marked[word_n]) - if word_n == opt_word_b: - thesaurus.search(word_n) - # print(queue) - should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True) - if should_reset: - marked = { opt_word_a: 0 } - queue = [opt_word_a] - newqueue = [] - break - else: - newqueue.append(word_n) - if should_reset: - break - if should_reset: - break - if should_reset: - should_reset = False - continue - if step > 1 and len(newqueue) > opt_words_per_step: - random.shuffle(newqueue) - queue = newqueue[:opt_words_per_step] - newqueue = newqueue[opt_words_per_step:] - else: - queue = [] + newqueue - if not found: - print(f"No path found, step {step} reached, {len(marked)} nodes checked") - return + tree[catid] = tree[word] + 1 + add_to_queue = self.process_category(catid, tree, target) + if self.should_reset: + return [] + if len(add_to_queue): + next_queue += add_to_queue + count += 1 + return next_queue -def print_chain(thesaurus, opt_word_a, opt_words_b, marked, skip, prompt_to_remove=False): - """Follow the chain of shortest distance from the end back to the start""" - # print(opt_word_a) - if prompt_to_remove: - print("") - print("--------------- PATH FOUND ---------------") - print("") - word_n = opt_words_b[0] - dist = marked[word_n] - tries = 0 - depth_tries = 0 - chain = [] - skip_here = [] - cat_reverse = {} - while word_n != opt_word_a: - if tries > len(opt_words_b): - return - next_word = "" - break_loop = False - word_result = thesaurus.search(word_n) - chain.append(word_result['word']) - # print(word_result['word']) - for cat in word_result['categories']: - catid = cat['catid'] - if (word_n in skip and catid in skip[word_n]) or catid in skip_here: + def process_category(self, catid, tree, target): + queue = [] + category_result = self.thesaurus.category(catid) + for category_word in category_result['words']: + word = self.fix_word(category_word['word']) + years = category_word['years'].lower() + if (catid, word,) in self.skips: continue - if catid in marked and marked[catid] < dist: - dist = marked[catid] - # print(f"{dist}: {catid}") - cat_result = thesaurus.category(catid) - for word_c in cat_result['words']: - word_m = word_c['word'] - if '<' in word_m or '/' in word_m or ',' in word_m: - word_m = word_m.split("<")[0] - word_n = word_n.split(",")[0] - word_m = word_m.split("/")[0].strip() - # print(word_m) - if (catid in skip and word_m in skip[catid]) or word_m in skip_here: - continue - if word_m == opt_word_a or (word_m in marked and marked[word_m] < dist): - dist = marked[word_m] - # print(f"{dist}: {word_m}") - next_word = word_m - break_loop = True - break - if break_loop: - cat_name = cat_result['category'] - cat_reverse[cat_name] = catid - chain.append(cat_name) - break - if next_word == '': - if depth_tries < 100 and len(chain) > 2: - to_skip = chain[-1] - if to_skip in cat_reverse: - to_skip = cat_reverse[to_skip] - skip_here.append(to_skip) - to_skip = chain[-2] - if to_skip in cat_reverse: - to_skip = cat_reverse[to_skip] - skip_here.append(to_skip) - chain.pop() - chain.pop() - word_n = chain[-1] - depth_tries += 1 - elif depth_tries >= 100: - print(f"{depth_tries} {tries}") - tries += 1 - if tries >= len(opt_words_b): - return - word_n = opt_words_b[tries] - chain = [] - dist = marked[word_n] - depth_tries = 0 - continue - word_n = next_word - chain.append(opt_word_a) - chain = list(reversed(chain)) - for i, word in enumerate(chain): - if (i % 2) == 0: - print(f"{i+1} -> {word}") + word = self.process_word(word, years, catid, tree, target) + if word: + queue.append(word) + if self.should_reset: + return + return queue + + def process_word(self, word, years, catid, tree, target): + if not self.include_oe and self.is_oe(years): + return None + if not self.include_slang and self.is_slang(years): + return None + if word not in tree: + tree[word] = tree[catid] + 1 + self.max_dist = max(self.max_dist, tree[word]) + if word in target: + self.make_chain(hinge=word, can_remove=True) + return word + if word in target: + self.make_chain(hinge=word, can_remove=True) + return None + + def make_chain(self, hinge, can_remove=True): + # tqdm.write(f"Making chain from {hinge}") + chain_a = self.descend_chain(hinge, self.tree_a) + chain_b = self.descend_chain(hinge, self.tree_b) + chain = list(reversed(chain_a)) + [hinge] + chain_b + self.display_chain(chain) + if can_remove: + tqdm.write("Enter a number to break the chain, enter to keep searching, or Ctrl-C to exit") + tqdm.write("") + index = input("> ").strip() + if index and self.is_integer(index): + item = chain[int(index)] + if item in chain_a: + self.add_skip(item, chain_a) + self.should_reset = True + if item in chain_b: + self.add_skip(item, chain_b) + self.should_reset = True + return True + return False + + def add_skip(self, item, chain): + index = chain.index(item) + if index == len(chain) - 1: + return + prev_item = chain[index + 1] + self.skips.append((prev_item, item)) + if self.is_integer(item): + tqdm.write(f"Removing: {prev_item} => {self.get_category_name(item)}") else: - print(f"{i+1} => {word}") - if prompt_to_remove: - print("") - print("If you don't like this path, enter the IDs of words to remove separated by spaces, or Ctrl-C to exit.") - ids = input("Enter numbers > ") - ids = ids.split(" ") - for id in ids: - if len(id): - try: - id = int(id) - id -= 1 - word_a = chain[id] - # print(f"Removing {word_a}") - if word_a in cat_reverse: - word_a = cat_reverse[word_a] - word_b = chain[id-1] - # print(f"Connected upward to {word_b}") - if word_b in cat_reverse: - word_b = cat_reverse[word_b] - if word_a in skip: - skip[word_a].append(word_b) - else: - skip[word_a] = [word_b] - except Exception as e: - continue - return True - return False + tqdm.write(f"Removing: {self.get_category_name(prev_item)} => {item}") + + def descend_chain(self, word, tree): + start_word = word + chain = [] + while word is not None: + match = None + if self.is_integer(word): + category_result = self.thesaurus.category(word) + for category_word in category_result['words']: + cat_word = self.fix_word(category_word['word']) + if cat_word != word and cat_word in tree and tree[cat_word] < tree[word]: + chain.append(cat_word) + match = cat_word + break + else: + categories = self.thesaurus.search(word)['categories'] + for category in categories: + catid = category['catid'] + if catid != word and catid in tree and tree[catid] < tree[word]: + chain.append(catid) + match = catid + break + if match is not None: + word = match + if tree[word] == 0: + break + else: + if self.is_integer(word): + tqdm.write(f"No match for: {self.get_category_name(word)}") + tqdm.write(f"Chain started with {start_word}") + self.display_chain(chain) + else: + tqdm.write(f"No match for: {word}") + tqdm.write(f"Chain started with {start_word}") + self.display_chain(chain) + return [] + return chain + + def display_chain(self, chain): + tqdm.write("") + for i, word in enumerate(chain): + if self.is_integer(word): + word = self.get_category_name(word) + tqdm.write(f"{i} -> {word}") + else: + tqdm.write(f"{i} => {word}") + tqdm.write("") + + def get_category_name(self, catid): + category = self.thesaurus.category(catid) + return category['category'] + + def is_integer(self, s): + try: + int(s) + return True + except Exception as e: + return False + + def is_oe(self, years): + return (('oe' in years and 'oe-' not in years) or 'arch' in years) + + def is_slang(self, years): + return 'slang' in years or 'colloq' in years or 'Scots' in years + + def fix_word(self, word): + if '<' in word or '/' in word or ',' in word: + word = word.split("<")[0] + word = word.split(",")[0] + word = word.split("/")[0] + return word.strip() |
