summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-04-07 19:07:26 +0200
committerJules Laplace <julescarbon@gmail.com>2020-04-07 19:07:26 +0200
commit762c2b723782f3859a0e85fcc3810058ca12e26e (patch)
treef352e617b4dea4d4152142b4897e3390479a9a82
parent9cb1b1bded44273f7abf2ae3d4950e03046f3579 (diff)
rewrite solver, descend both trees at once until you find where they meet
-rw-r--r--cli/app/thesaurus/api.py2
-rw-r--r--cli/commands/bridge/naive.py255
-rw-r--r--cli/commands/bridge/words.py373
3 files changed, 449 insertions, 181 deletions
diff --git a/cli/app/thesaurus/api.py b/cli/app/thesaurus/api.py
index 467d5fb..89e5ad1 100644
--- a/cli/app/thesaurus/api.py
+++ b/cli/app/thesaurus/api.py
@@ -24,7 +24,7 @@ class Thesaurus:
try:
data = api_fn(word)
except Exception as e:
- print("Got HTTP error, sleeping for 5 seconds")
+ print("Got HTTP error for {word}, sleeping for 5 seconds")
time.sleep(5)
pass
write_json(path, data)
diff --git a/cli/commands/bridge/naive.py b/cli/commands/bridge/naive.py
new file mode 100644
index 0000000..4014279
--- /dev/null
+++ b/cli/commands/bridge/naive.py
@@ -0,0 +1,255 @@
+"""
+Find connections between two words
+"""
+
+import click
+import random
+import simplejson as json
+from tqdm import tqdm
+
+from app.thesaurus.api import Thesaurus
+
+@click.command()
+@click.option('-a', '--a', 'opt_word_a', required=True,
+ help='Starting word')
+@click.option('-b', '--b', 'opt_word_b', required=True,
+ help='Ending word')
+@click.option('-oe', '--include_oe', 'opt_include_oe', is_flag=True,
+ help='Whether to include OE/archaic words')
+@click.option('-sl', '--include_slang', 'opt_include_slang', is_flag=True,
+ help='Whether to include slang/colloquial words')
+@click.option('-w', '--words_per_step', 'opt_words_per_step', default=20,
+ help='Number of words to check per step')
+@click.option('-c', '--categories_per_word', 'opt_categories_per_word', default=3,
+ help='Number of categories to check per word')
+@click.pass_context
+def cli(ctx, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word):
+ """Find connections between two words
+ """
+ thesaurus = Thesaurus()
+ print(f"Starting word: {opt_word_a}")
+ print(f"Ending word: {opt_word_b}")
+
+ categories = thesaurus.search(opt_word_b)['categories']
+ initial_tree = {}
+ initial_tree[opt_word_a] = 0
+ initial_tree[opt_word_b] = 999
+ target_word_count = 0
+ # for cat in categories:
+ # catid = cat['catid']
+ # initial_tree[catid] = 998
+ # category_result = thesaurus.category(catid)
+ # print(initial_tree)
+ print(f"Potential target words: {target_word_count}")
+
+ step = 0
+ max_dist = 0
+ marked = initial_tree.copy()
+ queue = [opt_word_a]
+ newqueue = []
+ skip = {}
+ found = False
+ should_reset = False
+
+ def reset():
+ marked = initial_tree.copy()
+ queue = [opt_word_a]
+ newqueue = []
+
+ # First compute distance to each node to find a path
+ while len(queue):
+ step = step + 1
+ print("")
+ print(f"Iteration step {step}, depth {max_dist}, {len(queue) + len(newqueue)} items in queue")
+ print(f"Words: {', '.join(queue[:7])} ...")
+ print("")
+ if step > 1:
+ print_chain(thesaurus, opt_word_a, queue, marked, skip, prompt_to_remove=False)
+ for word_q in tqdm(queue):
+ word_result = thesaurus.search(word_q)
+ # print(json.dumps(word_result, indent=2))
+ categories = word_result['categories']
+ if step > 1 and len(categories) > opt_categories_per_word:
+ categories = categories[:opt_categories_per_word]
+ for cat in categories:
+ catid = cat['catid']
+ if catid in marked:
+ if marked[catid] > 990:
+ should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True)
+ if should_reset:
+ reset()
+ break
+ continue
+ if word_q in skip and catid in skip[word_q]:
+ continue
+ marked[catid] = marked[word_q] + 1
+ category_result = thesaurus.category(catid)
+ # print(json.dumps(category_result, indent=2))
+ for word_c in category_result['words']:
+ word_n = fix_word(word_c['word'])
+ years = word_c['years'].lower()
+ if not opt_include_oe and (('oe' in years and 'oe-' not in years) or 'arch' in years):
+ continue
+ if not opt_include_slang and 'slang' in years or 'colloq' in years or 'Scots' in years:
+ continue
+ if word_n in marked:
+ if marked[word_n] > 990:
+ print(f"Found {word_n} in {catid}")
+ should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True)
+ if should_reset:
+ reset()
+ break
+ continue
+ if catid in skip and word_n in skip[catid]:
+ continue
+ marked[word_n] = marked[catid] + 1
+ max_dist = max(max_dist, marked[word_n])
+ if word_n == opt_word_b:
+ thesaurus.search(word_n)
+ # print(queue)
+ should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True)
+ if should_reset:
+ reset()
+ break
+ else:
+ newqueue.append(word_n)
+ if should_reset:
+ break
+ if should_reset:
+ break
+ if should_reset:
+ should_reset = False
+ continue
+ if step > 1 and len(newqueue) > opt_words_per_step:
+ random.shuffle(newqueue)
+ queue = newqueue[:opt_words_per_step]
+ newqueue = newqueue[opt_words_per_step:]
+ else:
+ queue = [] + newqueue
+ if not found:
+ print(f"No path found, step {step} reached, {len(marked)} nodes checked")
+ return
+
+def fix_word(word_n):
+ if '<' in word_n or '/' in word_n or ',' in word_n:
+ word_n = word_n.split("<")[0]
+ word_n = word_n.split(",")[0]
+ word_n = word_n.split("/")[0]
+ return word_n.strip()
+
+def print_chain(thesaurus, opt_word_a, opt_words_b, marked, skip, prompt_to_remove=False):
+ """Follow the chain of shortest distance from the end back to the start"""
+ # print(opt_word_a)
+ if prompt_to_remove:
+ print("")
+ print("")
+ print("--------------- PATH FOUND ---------------")
+ print("")
+ word_n = opt_words_b[0]
+ dist = marked[word_n]
+ tries = 0
+ depth_tries = 0
+ chain = []
+ skip_here = []
+ cat_reverse = {}
+ while word_n != opt_word_a:
+ if tries > len(opt_words_b):
+ print(f"tries: {tries}, targets: {len(opt_words_b)}")
+ print("Too many tries to produce a chain...")
+ return False
+ next_word = ""
+ break_loop = False
+ word_result = thesaurus.search(word_n)
+ if word_n not in chain:
+ chain.append(word_n)
+ else:
+ depth_tries = 100
+ # print(word_result['word'])
+ for cat in word_result['categories']:
+ catid = cat['catid']
+ if (word_n in skip and catid in skip[word_n]) or catid in skip_here:
+ continue
+ if catid in marked and marked[catid] < dist:
+ dist = marked[catid]
+ # print(f"{dist}: {catid}")
+ cat_result = thesaurus.category(catid)
+ for word_c in cat_result['words']:
+ word_m = word_c['word']
+ if '<' in word_m or '/' in word_m or ',' in word_m:
+ word_m = word_m.split("<")[0]
+ word_n = word_n.split(",")[0]
+ word_m = word_m.split("/")[0].strip()
+ # print(word_m)
+ if (catid in skip and word_m in skip[catid]) or word_m in skip_here:
+ continue
+ if word_m == opt_word_a or (word_m in marked and marked[word_m] < dist):
+ dist = marked[word_m]
+ # print(f"{dist}: {word_m}")
+ next_word = word_m
+ break_loop = True
+ break
+ if break_loop:
+ cat_name = cat_result['category']
+ cat_reverse[cat_name] = catid
+ chain.append(cat_name)
+ break
+ if next_word == '':
+ if depth_tries < 100 and len(chain) > 2:
+ to_skip = chain[-1]
+ if to_skip in cat_reverse:
+ to_skip = cat_reverse[to_skip]
+ skip_here.append(to_skip)
+ to_skip = chain[-2]
+ if to_skip in cat_reverse:
+ to_skip = cat_reverse[to_skip]
+ skip_here.append(to_skip)
+ chain = chain[:-2]
+ word_n = chain[-1]
+ depth_tries += 1
+ elif depth_tries >= 100:
+ # print(skip_here)
+ # print(f"{depth_tries} {tries}")
+ tries += 1
+ if tries >= len(opt_words_b):
+ # print(f"tries: {tries}, targets: {len(opt_words_b)}")
+ # print("Too many tries to produce a chain...")
+ return False
+ word_n = opt_words_b[tries]
+ chain = []
+ dist = marked[word_n]
+ depth_tries = 0
+ continue
+ word_n = next_word
+ chain.append(opt_word_a)
+ chain = list(reversed(chain))
+ for i, word in enumerate(chain):
+ if (i % 2) == 0:
+ print(f"{i+1} -> {word}")
+ else:
+ print(f"{i+1} => {word}")
+ if prompt_to_remove:
+ print("")
+ print("If you don't like this path, enter the IDs of words to remove separated by spaces, or Ctrl-C to exit.")
+ ids = input("Enter numbers > ")
+ ids = ids.split(" ")
+ for id in ids:
+ if len(id):
+ try:
+ id = int(id)
+ id -= 1
+ word_a = chain[id]
+ # print(f"Removing {word_a}")
+ if word_a in cat_reverse:
+ word_a = cat_reverse[word_a]
+ word_b = chain[id-1]
+ # print(f"Connected upward to {word_b}")
+ if word_b in cat_reverse:
+ word_b = cat_reverse[word_b]
+ if word_a in skip:
+ skip[word_a].append(word_b)
+ else:
+ skip[word_a] = [word_b]
+ except Exception as e:
+ continue
+ return True
+ return False
diff --git a/cli/commands/bridge/words.py b/cli/commands/bridge/words.py
index 558f6e3..bcfd2c4 100644
--- a/cli/commands/bridge/words.py
+++ b/cli/commands/bridge/words.py
@@ -2,6 +2,8 @@
Find connections between two words
"""
+import sys
+import time
import click
import random
import simplejson as json
@@ -24,193 +26,204 @@ from app.thesaurus.api import Thesaurus
help='Number of categories to check per word')
@click.pass_context
def cli(ctx, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word):
- """Find connections between two words
+ """
+ Find connections between two words
"""
thesaurus = Thesaurus()
+ solver = TreeSolver(thesaurus, opt_word_a, opt_word_b, opt_include_oe, opt_include_slang, opt_words_per_step, opt_categories_per_word)
print(f"Starting word: {opt_word_a}")
print(f"Ending word: {opt_word_b}")
- visited = set()
- step = 0
- max_dist = 0
- marked = { opt_word_a: 0 }
- queue = [opt_word_a]
- newqueue = []
- skip = {}
- found = False
- should_reset = False
- # First compute distance to each node to find a path
- while len(queue):
- step = step + 1
- print(f"Iteration step {step}, depth {max_dist}, {len(queue) + len(newqueue)} items in queue")
- print(f"Words: {', '.join(queue[:7])} ...")
- if step > 1:
- print_chain(thesaurus, opt_word_a, queue, marked, skip, prompt_to_remove=False)
- for word_q in tqdm(queue):
- word_result = thesaurus.search(word_q)
- # print(json.dumps(word_result, indent=2))
- categories = word_result['categories']
- if step > 1 and len(categories) > opt_categories_per_word:
- categories = categories[:opt_categories_per_word]
- for cat in categories:
- catid = cat['catid']
- if catid in marked:
+ queue_a = [opt_word_a]
+ queue_b = [opt_word_b]
+
+ while True:
+ queue_a = solver.build_tree(words=queue_a, tree=solver.tree_a, target=solver.tree_b)
+ if solver.should_reset:
+ queue_a = [ opt_word_a ]
+ queue_b = [ opt_word_b ]
+ solver.reset()
+ queue_b = solver.build_tree(words=queue_b, tree=solver.tree_b, target=solver.tree_a)
+ if solver.should_reset:
+ queue_a = [ opt_word_a ]
+ queue_b = [ opt_word_b ]
+ solver.reset()
+ print(f"[depth] {solver.max_dist} [queue a] {len(queue_a)} [queue b] {len(queue_b)} [skips] {len(solver.skips)}")
+ # print(solver.skips)
+
+class TreeSolver:
+ def __init__(self, thesaurus, word_a, word_b, include_oe, include_slang, words_per_step, categories_per_word):
+ self.thesaurus = thesaurus
+ self.word_a = word_a
+ self.word_b = word_b
+ self.include_oe = include_oe
+ self.include_slang = include_slang
+ self.words_per_step = words_per_step
+ self.categories_per_word = categories_per_word
+ self.skips = []
+ self.max_dist = 0
+ self.reset()
+
+ def reset(self):
+ self.tree_a = { self.word_a: 0 }
+ self.tree_b = { self.word_b: 0 }
+ self.should_reset = False
+
+ def build_tree(self, words=[], tree={}, target={}, depth=999):
+ next_queue = []
+ if len(words) > self.words_per_step:
+ next_queue += words[self.words_per_step:]
+ words = words[:self.words_per_step]
+ for word in tqdm(words):
+ categories = self.thesaurus.search(word)['categories']
+ count = 0
+ for category in categories:
+ if count > self.categories_per_word:
+ break
+ catid = category['catid']
+ if (word, str(catid),) in self.skips:
+ # print(f"Skip {word} {catid}")
continue
- if word_q in skip and catid in skip[word_q]:
+ if catid in tree:
continue
- marked[catid] = marked[word_q] + 1
- category_result = thesaurus.category(catid)
- # print(json.dumps(category_result, indent=2))
- for word_c in category_result['words']:
- word_n = word_c['word']
- years = word_c['years'].lower()
- if not opt_include_oe and (('oe' in years and 'oe-' not in years) or 'arch' in years):
- continue
- if not opt_include_slang and 'slang' in years or 'colloq' in years or 'Scots' in years:
- continue
- if '<' in word_n or '/' in word_n or ',' in word_n:
- word_n = word_n.split("<")[0]
- word_n = word_n.split(",")[0]
- word_n = word_n.split("/")[0].strip()
- if word_n in marked:
- continue
- if catid in skip and word_n in skip[catid]:
- continue
- marked[word_n] = marked[catid] + 1
- max_dist = max(max_dist, marked[word_n])
- if word_n == opt_word_b:
- thesaurus.search(word_n)
- # print(queue)
- should_reset = print_chain(thesaurus, opt_word_a, [opt_word_b], marked, skip, prompt_to_remove=True)
- if should_reset:
- marked = { opt_word_a: 0 }
- queue = [opt_word_a]
- newqueue = []
- break
- else:
- newqueue.append(word_n)
- if should_reset:
- break
- if should_reset:
- break
- if should_reset:
- should_reset = False
- continue
- if step > 1 and len(newqueue) > opt_words_per_step:
- random.shuffle(newqueue)
- queue = newqueue[:opt_words_per_step]
- newqueue = newqueue[opt_words_per_step:]
- else:
- queue = [] + newqueue
- if not found:
- print(f"No path found, step {step} reached, {len(marked)} nodes checked")
- return
+ tree[catid] = tree[word] + 1
+ add_to_queue = self.process_category(catid, tree, target)
+ if self.should_reset:
+ return []
+ if len(add_to_queue):
+ next_queue += add_to_queue
+ count += 1
+ return next_queue
-def print_chain(thesaurus, opt_word_a, opt_words_b, marked, skip, prompt_to_remove=False):
- """Follow the chain of shortest distance from the end back to the start"""
- # print(opt_word_a)
- if prompt_to_remove:
- print("")
- print("--------------- PATH FOUND ---------------")
- print("")
- word_n = opt_words_b[0]
- dist = marked[word_n]
- tries = 0
- depth_tries = 0
- chain = []
- skip_here = []
- cat_reverse = {}
- while word_n != opt_word_a:
- if tries > len(opt_words_b):
- return
- next_word = ""
- break_loop = False
- word_result = thesaurus.search(word_n)
- chain.append(word_result['word'])
- # print(word_result['word'])
- for cat in word_result['categories']:
- catid = cat['catid']
- if (word_n in skip and catid in skip[word_n]) or catid in skip_here:
+ def process_category(self, catid, tree, target):
+ queue = []
+ category_result = self.thesaurus.category(catid)
+ for category_word in category_result['words']:
+ word = self.fix_word(category_word['word'])
+ years = category_word['years'].lower()
+ if (catid, word,) in self.skips:
continue
- if catid in marked and marked[catid] < dist:
- dist = marked[catid]
- # print(f"{dist}: {catid}")
- cat_result = thesaurus.category(catid)
- for word_c in cat_result['words']:
- word_m = word_c['word']
- if '<' in word_m or '/' in word_m or ',' in word_m:
- word_m = word_m.split("<")[0]
- word_n = word_n.split(",")[0]
- word_m = word_m.split("/")[0].strip()
- # print(word_m)
- if (catid in skip and word_m in skip[catid]) or word_m in skip_here:
- continue
- if word_m == opt_word_a or (word_m in marked and marked[word_m] < dist):
- dist = marked[word_m]
- # print(f"{dist}: {word_m}")
- next_word = word_m
- break_loop = True
- break
- if break_loop:
- cat_name = cat_result['category']
- cat_reverse[cat_name] = catid
- chain.append(cat_name)
- break
- if next_word == '':
- if depth_tries < 100 and len(chain) > 2:
- to_skip = chain[-1]
- if to_skip in cat_reverse:
- to_skip = cat_reverse[to_skip]
- skip_here.append(to_skip)
- to_skip = chain[-2]
- if to_skip in cat_reverse:
- to_skip = cat_reverse[to_skip]
- skip_here.append(to_skip)
- chain.pop()
- chain.pop()
- word_n = chain[-1]
- depth_tries += 1
- elif depth_tries >= 100:
- print(f"{depth_tries} {tries}")
- tries += 1
- if tries >= len(opt_words_b):
- return
- word_n = opt_words_b[tries]
- chain = []
- dist = marked[word_n]
- depth_tries = 0
- continue
- word_n = next_word
- chain.append(opt_word_a)
- chain = list(reversed(chain))
- for i, word in enumerate(chain):
- if (i % 2) == 0:
- print(f"{i+1} -> {word}")
+ word = self.process_word(word, years, catid, tree, target)
+ if word:
+ queue.append(word)
+ if self.should_reset:
+ return
+ return queue
+
+ def process_word(self, word, years, catid, tree, target):
+ if not self.include_oe and self.is_oe(years):
+ return None
+ if not self.include_slang and self.is_slang(years):
+ return None
+ if word not in tree:
+ tree[word] = tree[catid] + 1
+ self.max_dist = max(self.max_dist, tree[word])
+ if word in target:
+ self.make_chain(hinge=word, can_remove=True)
+ return word
+ if word in target:
+ self.make_chain(hinge=word, can_remove=True)
+ return None
+
+ def make_chain(self, hinge, can_remove=True):
+ # tqdm.write(f"Making chain from {hinge}")
+ chain_a = self.descend_chain(hinge, self.tree_a)
+ chain_b = self.descend_chain(hinge, self.tree_b)
+ chain = list(reversed(chain_a)) + [hinge] + chain_b
+ self.display_chain(chain)
+ if can_remove:
+ tqdm.write("Enter a number to break the chain, enter to keep searching, or Ctrl-C to exit")
+ tqdm.write("")
+ index = input("> ").strip()
+ if index and self.is_integer(index):
+ item = chain[int(index)]
+ if item in chain_a:
+ self.add_skip(item, chain_a)
+ self.should_reset = True
+ if item in chain_b:
+ self.add_skip(item, chain_b)
+ self.should_reset = True
+ return True
+ return False
+
+ def add_skip(self, item, chain):
+ index = chain.index(item)
+ if index == len(chain) - 1:
+ return
+ prev_item = chain[index + 1]
+ self.skips.append((prev_item, item))
+ if self.is_integer(item):
+ tqdm.write(f"Removing: {prev_item} => {self.get_category_name(item)}")
else:
- print(f"{i+1} => {word}")
- if prompt_to_remove:
- print("")
- print("If you don't like this path, enter the IDs of words to remove separated by spaces, or Ctrl-C to exit.")
- ids = input("Enter numbers > ")
- ids = ids.split(" ")
- for id in ids:
- if len(id):
- try:
- id = int(id)
- id -= 1
- word_a = chain[id]
- # print(f"Removing {word_a}")
- if word_a in cat_reverse:
- word_a = cat_reverse[word_a]
- word_b = chain[id-1]
- # print(f"Connected upward to {word_b}")
- if word_b in cat_reverse:
- word_b = cat_reverse[word_b]
- if word_a in skip:
- skip[word_a].append(word_b)
- else:
- skip[word_a] = [word_b]
- except Exception as e:
- continue
- return True
- return False
+ tqdm.write(f"Removing: {self.get_category_name(prev_item)} => {item}")
+
+ def descend_chain(self, word, tree):
+ start_word = word
+ chain = []
+ while word is not None:
+ match = None
+ if self.is_integer(word):
+ category_result = self.thesaurus.category(word)
+ for category_word in category_result['words']:
+ cat_word = self.fix_word(category_word['word'])
+ if cat_word != word and cat_word in tree and tree[cat_word] < tree[word]:
+ chain.append(cat_word)
+ match = cat_word
+ break
+ else:
+ categories = self.thesaurus.search(word)['categories']
+ for category in categories:
+ catid = category['catid']
+ if catid != word and catid in tree and tree[catid] < tree[word]:
+ chain.append(catid)
+ match = catid
+ break
+ if match is not None:
+ word = match
+ if tree[word] == 0:
+ break
+ else:
+ if self.is_integer(word):
+ tqdm.write(f"No match for: {self.get_category_name(word)}")
+ tqdm.write(f"Chain started with {start_word}")
+ self.display_chain(chain)
+ else:
+ tqdm.write(f"No match for: {word}")
+ tqdm.write(f"Chain started with {start_word}")
+ self.display_chain(chain)
+ return []
+ return chain
+
+ def display_chain(self, chain):
+ tqdm.write("")
+ for i, word in enumerate(chain):
+ if self.is_integer(word):
+ word = self.get_category_name(word)
+ tqdm.write(f"{i} -> {word}")
+ else:
+ tqdm.write(f"{i} => {word}")
+ tqdm.write("")
+
+ def get_category_name(self, catid):
+ category = self.thesaurus.category(catid)
+ return category['category']
+
+ def is_integer(self, s):
+ try:
+ int(s)
+ return True
+ except Exception as e:
+ return False
+
+ def is_oe(self, years):
+ return (('oe' in years and 'oe-' not in years) or 'arch' in years)
+
+ def is_slang(self, years):
+ return 'slang' in years or 'colloq' in years or 'Scots' in years
+
+ def fix_word(self, word):
+ if '<' in word or '/' in word or ',' in word:
+ word = word.split("<")[0]
+ word = word.split(",")[0]
+ word = word.split("/")[0]
+ return word.strip()