From c1ea0ed3871923d957712ede6105d353d8751595 Mon Sep 17 00:00:00 2001 From: scooterbreak Date: Sun, 28 Jan 2024 14:28:51 -0800 Subject: [PATCH 1/2] initial commit for hints feature --- .gitignore | 1 + apis/sprints_api.py | 28 ++++++++ frontend/static/js/pages/play.js | 34 +++++++++ frontend/templates/play.html | 6 ++ get_embeddings.sh | 11 +++ wsbot/__init__.py | 0 wsbot/embeddings.py | 17 +++++ wsbot/graph.py | 78 +++++++++++++++++++++ wsbot/search.py | 114 +++++++++++++++++++++++++++++++ 9 files changed, 289 insertions(+) create mode 100755 get_embeddings.sh create mode 100644 wsbot/__init__.py create mode 100644 wsbot/embeddings.py create mode 100644 wsbot/graph.py create mode 100644 wsbot/search.py diff --git a/.gitignore b/.gitignore index 331c6cc5..8c0e0faf 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ config/prod.json node_modules frontend/static/js-build +data/ .DS_Store .vscode \ No newline at end of file diff --git a/apis/sprints_api.py b/apis/sprints_api.py index 2dc8d247..2314f397 100644 --- a/apis/sprints_api.py +++ b/apis/sprints_api.py @@ -3,14 +3,22 @@ import json import datetime +import wsbot from app.db import get_db from pymysql.cursors import DictCursor from util.decorators import check_admin, check_request_json from wikispeedruns import prompts +from wsbot.search import GreedySearch, BeamSearch +from wsbot.embeddings import LocalEmbeddings +from wsbot.graph import APIGraph, SQLGraph sprint_api = Blueprint('sprints', __name__, url_prefix='/api/sprints') + # this script doesn't work + # !./get_embeddings.sh +embeddings_provider = LocalEmbeddings("data/wiki2vec.txt") +graph_provider = APIGraph() ### Prompt Management Endpoints @@ -168,3 +176,23 @@ def check_duplicate_prompt(): res = prompts.check_for_sprint_duplicates(start, end) return jsonify(res) +# get the next hint +@sprint_api.get('/hint') +# @check_request_json({"start": str, "end": str}) +def get_hint(): + start = request.args.get('start') + end = request.args.get('end') + + 
print(start) + print(end) + + if (start is None or end is None): return "Invalid Request", 400 + + # which algorithm to use? + # greedy = GreedySearch(embeddings_provider, graph_provider) + # path = greedy.search(start, end) + + beam = BeamSearch(embeddings_provider, graph_provider) + path = beam.search(start, end) + + return path diff --git a/frontend/static/js/pages/play.js b/frontend/static/js/pages/play.js index e183ad86..34ef43ed 100644 --- a/frontend/static/js/pages/play.js +++ b/frontend/static/js/pages/play.js @@ -231,6 +231,40 @@ let app = new Vue({ }, + async getHint(start, end) { + document.getElementById('hint').innerText = "Getting hint..." + const searchParams = new URLSearchParams({ + start: start, + end: end + }) + console.log("HI") + console.log("/api/sprints/hint?" + searchParams) + + let hint + try{ + const response = await fetch("/api/sprints/hint?" + searchParams) + let tmpData = await response.json() + hint = tmpData[1] + if(!hint){ + throw err; + } + } + catch(err){ + document.getElementById('hint').innerText = "Sorry, couldn't find a hint!" + } + + // const response = await fetch("/api/sprints/hint?" + searchParams) + // let tmpData = await response.json() + // hint = tmpData[1] + // console.log(hint) + if(!hint){ + document.getElementById('hint').innerText = "Sorry, couldn't find a hint!" + }else{ + document.getElementById('hint').innerText = hint + } + return hint + }, + async start() { this.countdownTime = (Date.now() - this.startTime) / 1000; diff --git a/frontend/templates/play.html b/frontend/templates/play.html index 2d4d1d18..eea38c05 100644 --- a/frontend/templates/play.html +++ b/frontend/templates/play.html @@ -57,6 +57,12 @@ Current Article
[[currentArticle]] +
+ +
+
diff --git a/get_embeddings.sh b/get_embeddings.sh new file mode 100755 index 00000000..b112828d --- /dev/null +++ b/get_embeddings.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +EMBEDDINGS_FILE="data/wiki2vec.txt" +if [[ ! -f $EMBEDDINGS_FILE ]]; then + mkdir -p data + wget "http://wikipedia2vec.s3.amazonaws.com/models/en/2018-04-20/enwiki_20180420_100d.txt.bz2" -O $EMBEDDINGS_FILE.bz2 + bunzip2 $EMBEDDINGS_FILE.bz2 +else + echo "\"$EMBEDDINGS_FILE\" already exists! Skipping..." + +fi \ No newline at end of file diff --git a/wsbot/__init__.py b/wsbot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/wsbot/embeddings.py b/wsbot/embeddings.py new file mode 100644 index 00000000..5f3ce9c5 --- /dev/null +++ b/wsbot/embeddings.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from wikipedia2vec import Wikipedia2Vec + +class EmbeddingsProvider(ABC): + @abstractmethod + def get_embedding(self, article: str): + pass + + def get_embeddings(self, articles): + return [self.get_embedding(a) for a in articles] + +class LocalEmbeddings(EmbeddingsProvider): + def __init__(self, filename: str): + self.wiki2vec = Wikipedia2Vec.load_text(filename) + + def get_embedding(self, article: str): + return self.wiki2vec.get_entity_vector(article) diff --git a/wsbot/graph.py b/wsbot/graph.py new file mode 100644 index 00000000..4f83e522 --- /dev/null +++ b/wsbot/graph.py @@ -0,0 +1,78 @@ +from abc import ABC, abstractmethod +from wikipedia2vec import Wikipedia2Vec + +import pymysql +from pymysql.cursors import DictCursor + +import requests + +# TODO make these context providers? 
+class GraphProvider(ABC): + ''' + Provide the outgoing links and other operations on the Wikipedia graph + ''' + + @abstractmethod + def get_links(self, article): + pass + + def get_links_batch(self, articles): + return [self.get_links(a) for a in articles] + + +class APIGraph(GraphProvider): + ''' + Graph queries served by the public Wikipedia API + ''' + URL = "https://en.wikipedia.org/w/api.php" + PARAMS = { + "action": "query", + "format": "json", + "prop": "links", + "pllimit": "max" + } + + def __init__(self): + pass + + def _links_from_resp(self, resp): + links = list(resp["query"]["pages"].values())[0]["links"] + links = [link["title"] for link in links] + return list(filter(lambda title: ":" not in title, links)) + + def get_links(self, article): + resp = requests.get(self.URL, params={**self.PARAMS, "titles": article}).json() + return self._links_from_resp(resp) + + def get_links_batch(self, articles): + # TODO figure out what happens if this returns too much + resp = requests.get(self.URL, params={**self.PARAMS, "titles": "|".join(articles)}).json() + return self._links_from_resp(resp) + + +class SQLGraph(GraphProvider): + ''' + Graph queries served by the custom wikipedia speedruns SQL database graph + ''' + def __init__(self, host, user, password, database): + self.db = pymysql.connect(host=host, user=user, password=password, database=database) + self.cursor = self.db.cursor(cursor=DictCursor) + + def get_links(self, article): + id_query = "SELECT * FROM articleid WHERE name=%s" + edge_query = """ + SELECT a.name FROM edgeidarticleid AS e + JOIN articleid AS a + ON e.dest = a.articleID + WHERE e.src = %s + """ + self.cursor.execute(id_query, article) + row = self.cursor.fetchone() + if row is None: return None + + self.cursor.execute(edge_query, row["articleID"]) + + return [row["name"] for row in self.cursor.fetchall()] + + # TODO write a query that does this properly + #def get_links_batch(self, articles): diff --git 
a/wsbot/search.py b/wsbot/search.py new file mode 100644 index 00000000..74909aaa --- /dev/null +++ b/wsbot/search.py @@ -0,0 +1,114 @@ +import scipy +from scipy.spatial import distance + +# TODO base class + +class MaxIterationsException(Exception): + pass + +class PathNotFoundException(Exception): + pass + + +class GreedySearch: + def __init__(self, embedding_provider, graph_provider, max_iterations=20): + self.embeddings = embedding_provider + self.graph = graph_provider + self.max_iterations = max_iterations + + + def search(self, start: str, end: str): + # Greedily searches the wikipedia graph + cur = start + end_v = self.embeddings.get_embedding(end) + + ret = [start, ] + + for i in range(self.max_iterations): + min_dist = 2 + next_article = "" + + for link in self.graph.get_links(cur): + if link in ret: + continue + + if (link == end): + #print(f"Found link in {cur}!") + ret.append(link) + return ret + + try: + cur_v = self.embeddings.get_embedding(link) + except KeyError: + continue + + dist = distance.cosine(cur_v, end_v) + + if dist <= min_dist: + next_article = link + min_dist = dist + + if next_article == "": + raise PathNotFoundException(f"GreedySearch: could not find path, current: {ret}") + + ret.append(next_article) + cur = next_article + + raise MaxIterationsException(f"GreedySearch: Max iterations {self.max_iterations} reached, current path: {ret}") + + +class BeamSearch: + def __init__(self, embedding_provider, graph_provider, max_iterations=20, width=10): + self.embeddings = embedding_provider + self.graph = graph_provider + self.max_iterations = max_iterations + self.width = width + + def _get_path(self, end, parent): + ret = [] + cur = end + while(parent[cur] != cur): + ret.append(cur) + cur = parent[cur] + + ret.append(cur) + return list(reversed(ret)) + + + def search(self, start: str, end: str): + # Define distance metric + # TODO customizable + end_v = self.embeddings.get_embedding(end) + def get_dist(article): + try: + cur_v = 
self.embeddings.get_embedding(link) + except KeyError: + return 100 + return distance.cosine(cur_v, end_v) + + # Greedily searches the wikipedia graph + cur_set = [start] + # Keeps track of parent articles, also serves as visitor set + parent = {start: start} + + for i in range(self.max_iterations): + next_set = [] + for article in cur_set: + outgoing = self.graph.get_links(article) + for link in outgoing: + if link in parent: + continue + parent[link] = article + next_set.append((get_dist(link), link)) + + if link == end: + return self._get_path(link, parent) + + cur_set = [article for (_, article) in sorted(next_set)] + cur_set = cur_set[:self.width] + print(f"Articles in iteration {i}: ", cur_set) + + raise MaxIterationsException(f"BeamSearch: Max iterations {self.max_iterations} reached") + +# TODO probabilistic search (for random results) +# TODO other heuristics From aca87b137715aaf0bbc2ebfcbf40a42c32fdbdfb Mon Sep 17 00:00:00 2001 From: scooterbreak Date: Wed, 19 Jun 2024 14:49:24 -0700 Subject: [PATCH 2/2] added hints function based on greedy approach --- apis/sprints_api.py | 13 +++-------- frontend/static/js/pages/play.js | 17 ++++++++------ wsbot/search.py | 40 +++++++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 18 deletions(-) diff --git a/apis/sprints_api.py b/apis/sprints_api.py index 2314f397..083de855 100644 --- a/apis/sprints_api.py +++ b/apis/sprints_api.py @@ -183,16 +183,9 @@ def get_hint(): start = request.args.get('start') end = request.args.get('end') - print(start) - print(end) - if (start is None or end is None): return "Invalid Request", 400 - # which algorithm to use? 
- # greedy = GreedySearch(embeddings_provider, graph_provider) - # path = greedy.search(start, end) - - beam = BeamSearch(embeddings_provider, graph_provider) - path = beam.search(start, end) + greedy_search = GreedySearch(embeddings_provider, graph_provider) + hint = greedy_search.get_next_greedy_link(start, end) - return path + return [hint] diff --git a/frontend/static/js/pages/play.js b/frontend/static/js/pages/play.js index 34ef43ed..72845971 100644 --- a/frontend/static/js/pages/play.js +++ b/frontend/static/js/pages/play.js @@ -197,6 +197,8 @@ let app = new Vue({ pageCallback: function(page, loadTime) { window.scrollTo(0, 0); + // makes sure hint box is reset after each link is clicked + document.getElementById('hint').innerText = "" this.hidePreview(); if (this.isScroll) { document.getElementById("wikipedia-frame").scrollTo(0, 0); @@ -237,14 +239,19 @@ let app = new Vue({ start: start, end: end }) - console.log("HI") console.log("/api/sprints/hint?" + searchParams) let hint try{ const response = await fetch("/api/sprints/hint?" + searchParams) - let tmpData = await response.json() - hint = tmpData[1] + let tmpData; + try { + tmpData = await response.json() + } + catch (err){ + console.log(err) + } + hint = tmpData[0] if(!hint){ throw err; } @@ -253,10 +260,6 @@ let app = new Vue({ document.getElementById('hint').innerText = "Sorry, couldn't find a hint!" } - // const response = await fetch("/api/sprints/hint?" + searchParams) - // let tmpData = await response.json() - // hint = tmpData[1] - // console.log(hint) if(!hint){ document.getElementById('hint').innerText = "Sorry, couldn't find a hint!" 
}else{ diff --git a/wsbot/search.py b/wsbot/search.py index 74909aaa..e3b88d9a 100644 --- a/wsbot/search.py +++ b/wsbot/search.py @@ -16,9 +16,47 @@ def __init__(self, embedding_provider, graph_provider, max_iterations=20): self.graph = graph_provider self.max_iterations = max_iterations + # Returns the next link to go to based on a greedy approach + def get_next_greedy_link(self, start: str, end: str): + min_dist = 2 + next_article = "" + end_v = self.embeddings.get_embedding(end) + + for link in self.graph.get_links(start): + if (link == end): + return link + try: + cur_v = self.embeddings.get_embedding(link) + except KeyError: + continue + dist = distance.cosine(cur_v, end_v) + print(dist) + if dist <= min_dist: + next_article = link + min_dist = dist + + if next_article == "": + raise PathNotFoundException(f"GreedySearch: could not find next link from {start}") + return next_article + def search(self, start: str, end: str): # Greedily searches the wikipedia graph + + # Replace with this code to use the get_next_greedy_link helper function. + # Currently the original implementation is uncommented. + # cur = start + # ret = [start, ] + + # for i in range(self.max_iterations): + # next_article = self.get_next_greedy_link(cur, end) + # ret.append(next_article) + # if(next_article == end): + # return ret + # cur = next_article + + # raise MaxIterationsException(f"GreedySearch: Max iterations {self.max_iterations} reached, current path: {ret}") + cur = start end_v = self.embeddings.get_embedding(end) @@ -50,7 +88,7 @@ def search(self, start: str, end: str): if next_article == "": raise PathNotFoundException(f"GreedySearch: could not find path, current: {ret}") - + ret.append(next_article) cur = next_article