summaryrefslogtreecommitdiff
path: root/cli/commands/bridge/naive.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/commands/bridge/naive.py')
-rw-r--r--cli/commands/bridge/naive.py255
1 files changed, 255 insertions, 0 deletions
diff --git a/cli/commands/bridge/naive.py b/cli/commands/bridge/naive.py
new file mode 100644
index 0000000..4014279
--- /dev/null
+++ b/cli/commands/bridge/naive.py
@@ -0,0 +1,255 @@
+"""
+Find connections between two words
+"""
+
+import click
+import random
+import simplejson as json
+from tqdm import tqdm
+
+from app.thesaurus.api import Thesaurus
+
+@click.command()
+@click.option('-a', '--a', 'opt_word_a', required=True,
+  help='Starting word')
+@click.option('-b', '--b', 'opt_word_b', required=True,
+  help='Ending word')
+@click.option('-oe', '--include_oe', 'opt_include_oe', is_flag=True,
+  help='Whether to include OE/archaic words')
+@click.option('-sl', '--include_slang', 'opt_include_slang', is_flag=True,
+  help='Whether to include slang/colloquial words')
+@click.option('-w', '--words_per_step', 'opt_words_per_step', default=20,
+  help='Number of words to check per step')
+@click.option('-c', '--categories_per_word', 'opt_categories_per_word', default=3,
+  help='Number of categories to check per word')
+@click.pass_context
+def cli(ctx, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word):
+    """Find connections between two words
+    """
+    # NOTE: the docstring above doubles as the click help text, so the longer
+    # description is kept in comments.
+    #
+    # The search alternates word -> category -> word, starting at opt_word_a.
+    # `marked` maps every visited word and category id to its step distance
+    # from the start; the target word is pre-seeded with the sentinel 999 so
+    # that running into an entry > 990 means the target has been reached.
+    # `skip` maps an item to the neighbours the user rejected at the prompt in
+    # print_chain(); it survives restarts of the search.
+    thesaurus = Thesaurus()
+    print(f"Starting word: {opt_word_a}")
+    print(f"Ending word: {opt_word_b}")
+
+    # The target is looked up once up front. The categories are not used
+    # further: seeding them as extra targets is currently disabled, which is
+    # why the reported target count is always 0.
+    categories = thesaurus.search(opt_word_b)['categories']
+    initial_tree = {opt_word_a: 0, opt_word_b: 999}
+    target_word_count = 0
+    print(f"Potential target words: {target_word_count}")
+
+    step = 0
+    max_dist = 0
+    marked = initial_tree.copy()
+    queue = [opt_word_a]
+    newqueue = []
+    skip = {}
+    found = False
+    should_reset = False
+
+    def reset():
+        # Restart the search from the starting word. `nonlocal` is required:
+        # without it these assignments only create locals and nothing resets.
+        # `step`, `max_dist` and `skip` are deliberately left untouched.
+        nonlocal marked, queue, newqueue
+        marked = initial_tree.copy()
+        queue = [opt_word_a]
+        newqueue = []
+
+    # First compute distance to each node to find a path
+    while queue:
+        step += 1
+        print("")
+        print(f"Iteration step {step}, depth {max_dist}, {len(queue) + len(newqueue)} items in queue")
+        print(f"Words: {', '.join(queue[:7])} ...")
+        print("")
+        if step > 1:
+            # Progress display: a chain from the current batch back to the start.
+            print_chain(thesaurus, opt_word_a, queue, marked, skip, prompt_to_remove=False)
+        for word_q in tqdm(queue):
+            word_result = thesaurus.search(word_q)
+            categories = word_result['categories']
+            # Only the very first step explores every category of a word.
+            if step > 1 and len(categories) > opt_categories_per_word:
+                categories = categories[:opt_categories_per_word]
+            for cat in categories:
+                catid = cat['catid']
+                if catid in marked:
+                    if marked[catid] > 990:
+                        should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True)
+                        if should_reset:
+                            found = True
+                            reset()
+                            break
+                    continue
+                if word_q in skip and catid in skip[word_q]:
+                    continue
+                marked[catid] = marked[word_q] + 1
+                category_result = thesaurus.category(catid)
+                for word_c in category_result['words']:
+                    word_n = fix_word(word_c['word'])
+                    years = word_c['years'].lower()
+                    if not opt_include_oe and (('oe' in years and 'oe-' not in years) or 'arch' in years):
+                        continue
+                    # Parenthesised so the flag governs all three markers, and
+                    # 'scots' is lower-case because `years` was lower-cased.
+                    if not opt_include_slang and ('slang' in years or 'colloq' in years or 'scots' in years):
+                        continue
+                    if word_n in marked:
+                        # The target word is always in `marked` (sentinel 999),
+                        # so reaching it is detected here.
+                        if marked[word_n] > 990:
+                            print(f"Found {word_n} in {catid}")
+                            should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True)
+                            if should_reset:
+                                found = True
+                                reset()
+                                break
+                        continue
+                    if catid in skip and word_n in skip[catid]:
+                        continue
+                    marked[word_n] = marked[catid] + 1
+                    max_dist = max(max_dist, marked[word_n])
+                    newqueue.append(word_n)
+                if should_reset:
+                    break
+            if should_reset:
+                break
+        if should_reset:
+            # A path was shown and the user answered the prompt: start over
+            # with the updated skip list.
+            should_reset = False
+            continue
+        if step > 1 and len(newqueue) > opt_words_per_step:
+            # Too many candidates: take a random batch, keep the rest as backlog.
+            random.shuffle(newqueue)
+            queue = newqueue[:opt_words_per_step]
+            newqueue = newqueue[opt_words_per_step:]
+        else:
+            # Hand over the whole backlog and start a fresh one, so processed
+            # words are not queued again and the loop ends once nothing new
+            # turns up.
+            queue = newqueue
+            newqueue = []
+    if not found:
+        print(f"No path found, step {step} reached, {len(marked)} nodes checked")
+
+def fix_word(word_n):
+    """Return the headword: the text before the first '<', ',' or '/', trimmed."""
+    headword = word_n
+    # Cut at each separator in turn; a separator that is absent leaves the
+    # text unchanged, so no up-front membership test is needed.
+    for separator in ("<", ",", "/"):
+        headword = headword.split(separator)[0]
+    return headword.strip()
+
+def print_chain(thesaurus, opt_word_a, opt_words_b, marked, skip, prompt_to_remove=False):
+    """Follow the chain of shortest distance from the end back to the start.
+
+    Starting at the first word of `opt_words_b`, repeatedly step to a
+    category and then a word whose distance in `marked` is smaller, until
+    `opt_word_a` is reached, then print the chain start-first (words with
+    "->", categories with "=>").
+
+    Args:
+        thesaurus: object providing `search(word)` and `category(catid)`.
+        opt_word_a: the starting word of the search (end of the backward walk).
+        opt_words_b: candidate end words; later entries are tried when no
+            chain can be built from an earlier one.
+        marked: mapping of word / category id to distance from the start.
+        skip: mapping of item to neighbours that must not follow it on the
+            backward walk; updated in place from the user's answer.
+        prompt_to_remove: when true, print a "PATH FOUND" banner and ask the
+            user which links of the printed chain to reject.
+
+    Returns:
+        True when a chain was printed and the user was prompted; False when
+        no chain could be built or when `prompt_to_remove` is false.
+    """
+    if prompt_to_remove:
+        print("")
+        print("")
+        print("--------------- PATH FOUND ---------------")
+        print("")
+    word_n = opt_words_b[0]
+    dist = marked[word_n]
+    tries = 0
+    depth_tries = 0
+    chain = []
+    skip_here = []    # items ruled out during this call only
+    cat_reverse = {}  # category display name -> category id
+    while word_n != opt_word_a:
+        if tries > len(opt_words_b):
+            print(f"tries: {tries}, targets: {len(opt_words_b)}")
+            print("Too many tries to produce a chain...")
+            return False
+        next_word = ""
+        break_loop = False
+        word_result = thesaurus.search(word_n)
+        if word_n not in chain:
+            chain.append(word_n)
+        else:
+            # Seeing a word twice means we already backtracked to it; do not
+            # backtrack again from here.
+            depth_tries = 100
+        for cat in word_result['categories']:
+            catid = cat['catid']
+            if (word_n in skip and catid in skip[word_n]) or catid in skip_here:
+                continue
+            if catid in marked and marked[catid] < dist:
+                dist = marked[catid]
+                cat_result = thesaurus.category(catid)
+                for word_c in cat_result['words']:
+                    # Normalise exactly like the forward search does, so the
+                    # candidate matches the keys stored in `marked`.
+                    word_m = fix_word(word_c['word'])
+                    if (catid in skip and word_m in skip[catid]) or word_m in skip_here:
+                        continue
+                    if word_m == opt_word_a or (word_m in marked and marked[word_m] < dist):
+                        dist = marked[word_m]
+                        next_word = word_m
+                        break_loop = True
+                        break
+                if break_loop:
+                    cat_name = cat_result['category']
+                    cat_reverse[cat_name] = catid
+                    chain.append(cat_name)
+                    break
+        if next_word == '':
+            if depth_tries < 100 and len(chain) > 2:
+                # Dead end: rule out the last word and the category that led
+                # to it, then retry from the word before them.
+                to_skip = chain[-1]
+                if to_skip in cat_reverse:
+                    to_skip = cat_reverse[to_skip]
+                skip_here.append(to_skip)
+                to_skip = chain[-2]
+                if to_skip in cat_reverse:
+                    to_skip = cat_reverse[to_skip]
+                skip_here.append(to_skip)
+                chain = chain[:-2]
+                word_n = chain[-1]
+                depth_tries += 1
+            elif depth_tries >= 100:
+                # Give up on this end word and move on to the next candidate.
+                tries += 1
+                if tries >= len(opt_words_b):
+                    return False
+                word_n = opt_words_b[tries]
+                chain = []
+                dist = marked[word_n]
+                depth_tries = 0
+            continue
+        word_n = next_word
+    chain.append(opt_word_a)
+    chain = list(reversed(chain))
+    for i, word in enumerate(chain):
+        if (i % 2) == 0:
+            print(f"{i+1} -> {word}")
+        else:
+            print(f"{i+1} => {word}")
+    if prompt_to_remove:
+        print("")
+        print("If you don't like this path, enter the IDs of words to remove separated by spaces, or Ctrl-C to exit.")
+        answer = input("Enter numbers > ")
+        for token in answer.split():
+            try:
+                index = int(token) - 1
+            except ValueError:
+                # Not a number: ignore this entry.
+                continue
+            # Entry 1 is the start word and has no predecessor; anything
+            # outside the printed range is ignored rather than wrapping
+            # around through negative indexing.
+            if not 1 <= index < len(chain):
+                continue
+            removed = chain[index]
+            removed = cat_reverse.get(removed, removed)
+            previous = chain[index - 1]
+            previous = cat_reverse.get(previous, previous)
+            skip.setdefault(removed, []).append(previous)
+        return True
+    return False