diff --git a/.gitignore b/.gitignore
index 09953c70..3ed5b685 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ config/prod.json
node_modules
frontend/static/js-build
+data/
static/js-build
.DS_Store
diff --git a/apis/sprints_api.py b/apis/sprints_api.py
index 2dc8d247..083de855 100644
--- a/apis/sprints_api.py
+++ b/apis/sprints_api.py
@@ -3,14 +3,22 @@
import json
import datetime
+import wsbot
from app.db import get_db
from pymysql.cursors import DictCursor
from util.decorators import check_admin, check_request_json
from wikispeedruns import prompts
+from wsbot.search import GreedySearch, BeamSearch
+from wsbot.embeddings import LocalEmbeddings
+from wsbot.graph import APIGraph, SQLGraph
sprint_api = Blueprint('sprints', __name__, url_prefix='/api/sprints')
+# The embeddings file is not checked in (data/ is gitignored); run ./get_embeddings.sh to download it.
+embeddings_provider = LocalEmbeddings("data/wiki2vec.txt")
+graph_provider = APIGraph()
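+# Note: APIGraph queries the public Wikipedia API on every hint request. As a
+# sketch only (connection values below are hypothetical, not the project's
+# actual config), the local SQL graph could be used instead:
+#   graph_provider = SQLGraph(host="localhost", user="user", password="password", database="wikipedia_speedruns")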
### Prompt Management Endpoints
@@ -168,3 +176,16 @@ def check_duplicate_prompt():
res = prompts.check_for_sprint_duplicates(start, end)
return jsonify(res)
+# get the next hint
+@sprint_api.get('/hint')
+# @check_request_json({"start": str, "end": str})
+def get_hint():
+ start = request.args.get('start')
+ end = request.args.get('end')
+
+    if start is None or end is None:
+        return "Invalid Request", 400
+
+ greedy_search = GreedySearch(embeddings_provider, graph_provider)
+ hint = greedy_search.get_next_greedy_link(start, end)
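+    # Possible hardening (a sketch, not part of this change): GreedySearch raises
+    # PathNotFoundException when it cannot make progress, and the embeddings
+    # lookup raises KeyError for unknown titles; catching those here and
+    # returning e.g. ("No hint found", 404) would avoid an unhandled 500.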
+
+    return jsonify([hint])
diff --git a/frontend/static/js/pages/play.js b/frontend/static/js/pages/play.js
index d4a75cd1..39ea606d 100644
--- a/frontend/static/js/pages/play.js
+++ b/frontend/static/js/pages/play.js
@@ -222,6 +222,8 @@ let app = new Vue({
pageCallback: function(page, loadTime) {
window.scrollTo(0, 0);
+ // makes sure hint box is reset after each link is clicked
+ document.getElementById('hint').innerText = ""
this.hidePreview();
if (this.isScroll) {
document.getElementById("wikipedia-frame").scrollTo(0, 0);
@@ -262,6 +264,41 @@ let app = new Vue({
},
+        async getHint(start, end) {
+            document.getElementById('hint').innerText = "Getting hint...";
+            const searchParams = new URLSearchParams({
+                start: start,
+                end: end
+            });
+
+            let hint = null;
+            try {
+                const response = await fetch("/api/sprints/hint?" + searchParams);
+                // The endpoint returns a one-element JSON array containing the suggested article
+                const data = await response.json();
+                hint = data && data[0];
+            } catch (err) {
+                console.error(err);
+            }
+
+            if (!hint) {
+                document.getElementById('hint').innerText = "Sorry, couldn't find a hint!";
+            } else {
+                document.getElementById('hint').innerText = hint;
+            }
+            return hint;
+        },
+
async start() {
this.countdownTime = (Date.now() - this.startTime) / 1000;
diff --git a/frontend/templates/play.html b/frontend/templates/play.html
index b25d317d..26485b74 100644
--- a/frontend/templates/play.html
+++ b/frontend/templates/play.html
@@ -61,6 +61,12 @@
+
diff --git a/get_embeddings.sh b/get_embeddings.sh
new file mode 100755
index 00000000..b112828d
--- /dev/null
+++ b/get_embeddings.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+EMBEDDINGS_FILE="data/wiki2vec.txt"
+if [[ ! -f $EMBEDDINGS_FILE ]]; then
+    mkdir -p data
+    wget "http://wikipedia2vec.s3.amazonaws.com/models/en/2018-04-20/enwiki_20180420_100d.txt.bz2" -O "$EMBEDDINGS_FILE.bz2"
+    bunzip2 "$EMBEDDINGS_FILE.bz2"
+else
+    echo "\"$EMBEDDINGS_FILE\" already exists! Skipping..."
+fi
diff --git a/wsbot/__init__.py b/wsbot/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/wsbot/embeddings.py b/wsbot/embeddings.py
new file mode 100644
index 00000000..5f3ce9c5
--- /dev/null
+++ b/wsbot/embeddings.py
@@ -0,0 +1,17 @@
+from abc import ABC, abstractmethod
+from wikipedia2vec import Wikipedia2Vec
+
+class EmbeddingsProvider(ABC):
+ @abstractmethod
+ def get_embedding(self, article: str):
+ pass
+
+ def get_embeddings(self, articles):
+        return [self.get_embedding(a) for a in articles]
+
+class LocalEmbeddings(EmbeddingsProvider):
+ def __init__(self, filename: str):
+ self.wiki2vec = Wikipedia2Vec.load_text(filename)
+
+ def get_embedding(self, article: str):
+ return self.wiki2vec.get_entity_vector(article)
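+
+# Note: wikipedia2vec's get_entity_vector raises KeyError when the given title
+# has no entity vector; wsbot.search relies on catching that KeyError.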
diff --git a/wsbot/graph.py b/wsbot/graph.py
new file mode 100644
index 00000000..4f83e522
--- /dev/null
+++ b/wsbot/graph.py
@@ -0,0 +1,78 @@
+from abc import ABC, abstractmethod
+
+import pymysql
+from pymysql.cursors import DictCursor
+
+import requests
+
+# TODO: make these context providers?
+class GraphProvider(ABC):
+ '''
+ Provide the outgoing links and other operations on the Wikipedia graph
+ '''
+
+ @abstractmethod
+ def get_links(self, article):
+ pass
+
+ def get_links_batch(self, articles):
+ return [self.get_links(a) for a in articles]
+
+
+class APIGraph(GraphProvider):
+ '''
+ Graph queries served by the public Wikipedia API
+ '''
+ URL = "https://en.wikipedia.org/w/api.php"
+ PARAMS = {
+ "action": "query",
+ "format": "json",
+ "prop": "links",
+ "pllimit": "max"
+ }
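+    # Example of the request this produces (title value is illustrative):
+    #   https://en.wikipedia.org/w/api.php?action=query&format=json&prop=links&pllimit=max&titles=Apple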
+
+ def __init__(self):
+ pass
+
+ def _links_from_resp(self, resp):
+ links = list(resp["query"]["pages"].values())[0]["links"]
+ links = [link["title"] for link in links]
+ return list(filter(lambda title: ":" not in title, links))
+
+ def get_links(self, article):
+ resp = requests.get(self.URL, params={**self.PARAMS, "titles": article}).json()
+ return self._links_from_resp(resp)
+
+ def get_links_batch(self, articles):
+ # TODO figure out what happens if this returns too much
+        resp = requests.get(self.URL, params={**self.PARAMS, "titles": "|".join(articles)}).json()
+ return self._links_from_resp(resp)
+
+
+class SQLGraph(GraphProvider):
+ '''
+ Graph queries served by the custom wikipedia speedruns SQL database graph
+ '''
+ def __init__(self, host, user, password, database):
+ self.db = pymysql.connect(host=host, user=user, password=password, database=database)
+ self.cursor = self.db.cursor(cursor=DictCursor)
+
+ def get_links(self, article):
+ id_query = "SELECT * FROM articleid WHERE name=%s"
+ edge_query = """
+ SELECT a.name FROM edgeidarticleid AS e
+ JOIN articleid AS a
+ ON e.dest = a.articleID
+ WHERE e.src = %s
+ """
+        self.cursor.execute(id_query, (article,))
+        row = self.cursor.fetchone()
+        if row is None:
+            return None
+        article_id = row["articleID"]
+
+ self.cursor.execute(edge_query, article_id)
+
+ return [row["name"] for row in self.cursor.fetchall()]
+
+ # TODO write a query that does this properly
+ #def get_links_batch(self, articles):
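+    # A possible batched query (sketch only, assuming the schema used above),
+    # with results grouped by src_name in Python:
+    #   SELECT src.name AS src_name, dst.name AS dst_name
+    #   FROM edgeidarticleid AS e
+    #   JOIN articleid AS src ON e.src = src.articleID
+    #   JOIN articleid AS dst ON e.dest = dst.articleID
+    #   WHERE src.name IN (%s, %s, ...)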
diff --git a/wsbot/search.py b/wsbot/search.py
new file mode 100644
index 00000000..e3b88d9a
--- /dev/null
+++ b/wsbot/search.py
@@ -0,0 +1,152 @@
+from scipy.spatial import distance
+
+# TODO base class
+
+class MaxIterationsException(Exception):
+ pass
+
+class PathNotFoundException(Exception):
+ pass
+
+
+class GreedySearch:
+ def __init__(self, embedding_provider, graph_provider, max_iterations=20):
+ self.embeddings = embedding_provider
+ self.graph = graph_provider
+ self.max_iterations = max_iterations
+
+ # Returns the next link to go to based on a greedy approach
+ def get_next_greedy_link(self, start: str, end: str):
+ min_dist = 2
+ next_article = ""
+ end_v = self.embeddings.get_embedding(end)
+
+ for link in self.graph.get_links(start):
+ if (link == end):
+ return link
+ try:
+ cur_v = self.embeddings.get_embedding(link)
+ except KeyError:
+ continue
+ dist = distance.cosine(cur_v, end_v)
+ if dist <= min_dist:
+ next_article = link
+ min_dist = dist
+
+ if next_article == "":
+            raise PathNotFoundException(f"GreedySearch: could not find a link from '{start}' towards '{end}'")
+ return next_article
+
+
+ def search(self, start: str, end: str):
+ # Greedily searches the wikipedia graph
+
+        # Alternative implementation using the get_next_greedy_link helper
+        # (kept commented out; the inline implementation below is the one in use):
+        # cur = start
+        # ret = [start, ]
+        #
+        # for i in range(self.max_iterations):
+        #     next_article = self.get_next_greedy_link(cur, end)
+        #     ret.append(next_article)
+        #     if next_article == end:
+        #         return ret
+        #     cur = next_article
+        #
+        # raise MaxIterationsException(f"GreedySearch: Max iterations {self.max_iterations} reached, current path: {ret}")
+
+ cur = start
+ end_v = self.embeddings.get_embedding(end)
+
+ ret = [start, ]
+
+ for i in range(self.max_iterations):
+ min_dist = 2
+ next_article = ""
+
+ for link in self.graph.get_links(cur):
+ if link in ret:
+ continue
+
+ if (link == end):
+ #print(f"Found link in {cur}!")
+ ret.append(link)
+ return ret
+
+ try:
+ cur_v = self.embeddings.get_embedding(link)
+ except KeyError:
+ continue
+
+ dist = distance.cosine(cur_v, end_v)
+
+ if dist <= min_dist:
+ next_article = link
+ min_dist = dist
+
+ if next_article == "":
+ raise PathNotFoundException(f"GreedySearch: could not find path, current: {ret}")
+
+ ret.append(next_article)
+ cur = next_article
+
+ raise MaxIterationsException(f"GreedySearch: Max iterations {self.max_iterations} reached, current path: {ret}")
+
+
+class BeamSearch:
+ def __init__(self, embedding_provider, graph_provider, max_iterations=20, width=10):
+ self.embeddings = embedding_provider
+ self.graph = graph_provider
+ self.max_iterations = max_iterations
+ self.width = width
+
+ def _get_path(self, end, parent):
+ ret = []
+ cur = end
+ while(parent[cur] != cur):
+ ret.append(cur)
+ cur = parent[cur]
+
+ ret.append(cur)
+ return list(reversed(ret))
+
+
+ def search(self, start: str, end: str):
+ # Define distance metric
+ # TODO customizable
+ end_v = self.embeddings.get_embedding(end)
+ def get_dist(article):
+ try:
+                cur_v = self.embeddings.get_embedding(article)
+ except KeyError:
+ return 100
+ return distance.cosine(cur_v, end_v)
+
+ # Greedily searches the wikipedia graph
+ cur_set = [start]
+ # Keeps track of parent articles, also serves as visitor set
+ parent = {start: start}
+
+ for i in range(self.max_iterations):
+ next_set = []
+ for article in cur_set:
+ outgoing = self.graph.get_links(article)
+ for link in outgoing:
+ if link in parent:
+ continue
+ parent[link] = article
+ next_set.append((get_dist(link), link))
+
+ if link == end:
+ return self._get_path(link, parent)
+
+ cur_set = [article for (_, article) in sorted(next_set)]
+ cur_set = cur_set[:self.width]
+ print(f"Articles in iteration {i}: ", cur_set)
+
+ raise MaxIterationsException(f"BeamSearch: Max iterations {self.max_iterations} reached")
+
+# TODO probabilistic search (for random results)
+# TODO other heuristics
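+
+# Usage sketch (hypothetical article titles; requires the downloaded embeddings file):
+#   from wsbot.embeddings import LocalEmbeddings
+#   from wsbot.graph import APIGraph
+#   bot = GreedySearch(LocalEmbeddings("data/wiki2vec.txt"), APIGraph())
+#   print(bot.search("Apple", "Physics"))  # list of article titles from start to end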