diff options
Diffstat (limited to 'cli/commands/bridge/naive.py')
| -rw-r--r-- | cli/commands/bridge/naive.py | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/cli/commands/bridge/naive.py b/cli/commands/bridge/naive.py new file mode 100644 index 0000000..4014279 --- /dev/null +++ b/cli/commands/bridge/naive.py @@ -0,0 +1,255 @@ +""" +Find connections between two words +""" + +import click +import random +import simplejson as json +from tqdm import tqdm + +from app.thesaurus.api import Thesaurus + +@click.command() +@click.option('-a', '--a', 'opt_word_a', required=True, + help='Starting word') +@click.option('-b', '--b', 'opt_word_b', required=True, + help='Ending word') +@click.option('-oe', '--include_oe', 'opt_include_oe', is_flag=True, + help='Whether to include OE/archaic words') +@click.option('-sl', '--include_slang', 'opt_include_slang', is_flag=True, + help='Whether to include slang/colloquial words') +@click.option('-w', '--words_per_step', 'opt_words_per_step', default=20, + help='Number of words to check per step') +@click.option('-c', '--categories_per_word', 'opt_categories_per_word', default=3, + help='Number of categories to check per word') +@click.pass_context +def cli(ctx, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word): + """Find connections between two words + """ + thesaurus = Thesaurus() + print(f"Starting word: {opt_word_a}") + print(f"Ending word: {opt_word_b}") + + categories = thesaurus.search(opt_word_b)['categories'] + initial_tree = {} + initial_tree[opt_word_a] = 0 + initial_tree[opt_word_b] = 999 + target_word_count = 0 + # for cat in categories: + # catid = cat['catid'] + # initial_tree[catid] = 998 + # category_result = thesaurus.category(catid) + # print(initial_tree) + print(f"Potential target words: {target_word_count}") + + step = 0 + max_dist = 0 + marked = initial_tree.copy() + queue = [opt_word_a] + newqueue = [] + skip = {} + found = False + should_reset = False + + def reset(): + marked = initial_tree.copy() + queue = [opt_word_a] + newqueue = [] + + # First compute distance to each node to find a path + while len(queue): + step = step + 1 + print("") + print(f"Iteration step {step}, depth {max_dist}, {len(queue) + len(newqueue)} items in queue") + print(f"Words: {', '.join(queue[:7])} ...") + print("") + if step > 1: + print_chain(thesaurus, opt_word_a, queue, marked, skip, prompt_to_remove=False) + for word_q in tqdm(queue): + word_result = thesaurus.search(word_q) + # print(json.dumps(word_result, indent=2)) + categories = word_result['categories'] + if step > 1 and len(categories) > opt_categories_per_word: + categories = categories[:opt_categories_per_word] + for cat in categories: + catid = cat['catid'] + if catid in marked: + if marked[catid] > 990: + should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True) + if should_reset: + reset() + break + continue + if word_q in skip and catid in skip[word_q]: + continue + marked[catid] = marked[word_q] + 1 + category_result = thesaurus.category(catid) + # print(json.dumps(category_result, indent=2)) + for word_c in category_result['words']: + word_n = fix_word(word_c['word']) + years = word_c['years'].lower() + if not opt_include_oe and (('oe' in years and 'oe-' not in years) or 'arch' in years): + continue + if not opt_include_slang and 'slang' in years or 'colloq' in years or 'Scots' in years: + continue + if word_n in marked: + if marked[word_n] > 990: + print(f"Found {word_n} in {catid}") + should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True) + if should_reset: + reset() + break + continue + if catid in skip and word_n in skip[catid]: + continue + marked[word_n] = marked[catid] + 1 + max_dist = max(max_dist, marked[word_n]) + if word_n == opt_word_b: + thesaurus.search(word_n) + # print(queue) + should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True) + if should_reset: + reset() + break + else: + newqueue.append(word_n) + if should_reset: + break + if should_reset: + break + if should_reset: + should_reset = False + continue + if step > 1 and len(newqueue) > opt_words_per_step: + random.shuffle(newqueue) + queue = newqueue[:opt_words_per_step] + newqueue = newqueue[opt_words_per_step:] + else: + queue = [] + newqueue + if not found: + print(f"No path found, step {step} reached, {len(marked)} nodes checked") + return + +def fix_word(word_n): + if '<' in word_n or '/' in word_n or ',' in word_n: + word_n = word_n.split("<")[0] + word_n = word_n.split(",")[0] + word_n = word_n.split("/")[0] + return word_n.strip() + +def print_chain(thesaurus, opt_word_a, opt_words_b, marked, skip, prompt_to_remove=False): + """Follow the chain of shortest distance from the end back to the start""" + # print(opt_word_a) + if prompt_to_remove: + print("") + print("") + print("--------------- PATH FOUND ---------------") + print("") + word_n = opt_words_b[0] + dist = marked[word_n] + tries = 0 + depth_tries = 0 + chain = [] + skip_here = [] + cat_reverse = {} + while word_n != opt_word_a: + if tries > len(opt_words_b): + print(f"tries: {tries}, targets: {len(opt_words_b)}") + print("Too many tries to produce a chain...") + return False + next_word = "" + break_loop = False + word_result = thesaurus.search(word_n) + if word_n not in chain: + chain.append(word_n) + else: + depth_tries = 100 + # print(word_result['word']) + for cat in word_result['categories']: + catid = cat['catid'] + if (word_n in skip and catid in skip[word_n]) or catid in skip_here: + continue + if catid in marked and marked[catid] < dist: + dist = marked[catid] + # print(f"{dist}: {catid}") + cat_result = thesaurus.category(catid) + for word_c in cat_result['words']: + word_m = word_c['word'] + if '<' in word_m or '/' in word_m or ',' in word_m: + word_m = word_m.split("<")[0] + word_n = word_n.split(",")[0] + word_m = word_m.split("/")[0].strip() + # print(word_m) + if (catid in skip and word_m in skip[catid]) or word_m in skip_here: + continue + if word_m == opt_word_a or (word_m in marked and marked[word_m] < dist): + dist = marked[word_m] + # print(f"{dist}: {word_m}") + next_word = word_m + break_loop = True + break + if break_loop: + cat_name = cat_result['category'] + cat_reverse[cat_name] = catid + chain.append(cat_name) + break + if next_word == '': + if depth_tries < 100 and len(chain) > 2: + to_skip = chain[-1] + if to_skip in cat_reverse: + to_skip = cat_reverse[to_skip] + skip_here.append(to_skip) + to_skip = chain[-2] + if to_skip in cat_reverse: + to_skip = cat_reverse[to_skip] + skip_here.append(to_skip) + chain = chain[:-2] + word_n = chain[-1] + depth_tries += 1 + elif depth_tries >= 100: + # print(skip_here) + # print(f"{depth_tries} {tries}") + tries += 1 + if tries >= len(opt_words_b): + # print(f"tries: {tries}, targets: {len(opt_words_b)}") + # print("Too many tries to produce a chain...") + return False + word_n = opt_words_b[tries] + chain = [] + dist = marked[word_n] + depth_tries = 0 + continue + word_n = next_word + chain.append(opt_word_a) + chain = list(reversed(chain)) + for i, word in enumerate(chain): + if (i % 2) == 0: + print(f"{i+1} -> {word}") + else: + print(f"{i+1} => {word}") + if prompt_to_remove: + print("") + print("If you don't like this path, enter the IDs of words to remove separated by spaces, or Ctrl-C to exit.") + ids = input("Enter numbers > ") + ids = ids.split(" ") + for id in ids: + if len(id): + try: + id = int(id) + id -= 1 + word_a = chain[id] + # print(f"Removing {word_a}") + if word_a in cat_reverse: + word_a = cat_reverse[word_a] + word_b = chain[id-1] + # print(f"Connected upward to {word_b}") + if word_b in cat_reverse: + word_b = cat_reverse[word_b] + if word_a in skip: + skip[word_a].append(word_b) + else: + skip[word_a] = [word_b] + except Exception as e: + continue + return True + return False |
