diff --git a/.gitignore b/.gitignore
index 09953c70..3ed5b685 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ config/prod.json
 node_modules
 frontend/static/js-build
+data/
 static/js-build
 .DS_Store
diff --git a/apis/sprints_api.py b/apis/sprints_api.py
index 2dc8d247..083de855 100644
--- a/apis/sprints_api.py
+++ b/apis/sprints_api.py
@@ -3,14 +3,22 @@
 import json
 import datetime
+import wsbot
 from app.db import get_db
 from pymysql.cursors import DictCursor
 from util.decorators import check_admin, check_request_json
 from wikispeedruns import prompts
+from wsbot.search import GreedySearch, BeamSearch
+from wsbot.embeddings import LocalEmbeddings
+from wsbot.graph import APIGraph, SQLGraph
 
 sprint_api = Blueprint('sprints', __name__, url_prefix='/api/sprints')
 
+# The embeddings file below is downloaded by get_embeddings.sh;
+# run that script once before starting the server.
+embeddings_provider = LocalEmbeddings("data/wiki2vec.txt")
+graph_provider = APIGraph()
 
 ### Prompt Management Endpoints
@@ -168,3 +176,16 @@ def check_duplicate_prompt():
     res = prompts.check_for_sprint_duplicates(start, end)
     return jsonify(res)
 
+# Get the next hint: a single greedy step from `start` towards `end`
+@sprint_api.get('/hint')
+# @check_request_json({"start": str, "end": str})
+def get_hint():
+    start = request.args.get('start')
+    end = request.args.get('end')
+
+    if start is None or end is None:
+        return "Invalid Request", 400
+
+    greedy_search = GreedySearch(embeddings_provider, graph_provider)
+    hint = greedy_search.get_next_greedy_link(start, end)
+
+    return jsonify([hint])
diff --git a/frontend/static/js/pages/play.js b/frontend/static/js/pages/play.js
index d4a75cd1..39ea606d 100644
--- a/frontend/static/js/pages/play.js
+++ b/frontend/static/js/pages/play.js
@@ -222,6 +222,8 @@ let app = new Vue({
         pageCallback: function(page, loadTime) {
             window.scrollTo(0, 0);
+            // make sure the hint box is reset after each link is clicked
+            document.getElementById('hint').innerText = "";
             this.hidePreview();
             if (this.isScroll) {
                 document.getElementById("wikipedia-frame").scrollTo(0, 0);
@@ -262,6 +264,41 @@
         },
 
+        // Fetch a hint (the next link suggested by the server's greedy search)
+        // and display it in the hint box.
+        async getHint(start, end) {
+            const hintBox = document.getElementById('hint');
+            hintBox.innerText = "Getting hint...";
+
+            const searchParams = new URLSearchParams({
+                start: start,
+                end: end
+            });
+
+            let hint;
+            try {
+                const response = await fetch("/api/sprints/hint?" + searchParams);
+                // The endpoint returns a one-element JSON array, e.g. ["Some article"]
+                const data = await response.json();
+                hint = data[0];
+            } catch (err) {
+                console.log(err);
+            }
+
+            hintBox.innerText = hint || "Sorry, couldn't find a hint!";
+            return hint;
+        },
+
         async start() {
             this.countdownTime = (Date.now() - this.startTime) / 1000;
diff --git a/frontend/templates/play.html b/frontend/templates/play.html
index b25d317d..26485b74 100644
--- a/frontend/templates/play.html
+++ b/frontend/templates/play.html
@@ -61,6 +61,12 @@
+        <!-- hint UI markup (not preserved in this excerpt): a control that calls
+             getHint() and an element with id="hint" that the hint text is written into -->
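For reviewers: once the server is running, the new endpoint can be smoke-tested directly. A minimal sketch (the localhost URL and article titles are placeholders, not part of this change; the one-element list shape matches `get_hint()` above):

```python
# Hypothetical manual check of the new hint endpoint.
# Assumes a local dev server at http://localhost:5000; titles are examples only.
import requests

resp = requests.get(
    "http://localhost:5000/api/sprints/hint",
    params={"start": "Fruit", "end": "Philosophy"},
)
resp.raise_for_status()
hint = resp.json()[0]  # get_hint() returns a one-element JSON list
print("Suggested next article:", hint)
```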
diff --git a/get_embeddings.sh b/get_embeddings.sh
new file mode 100755
index 00000000..b112828d
--- /dev/null
+++ b/get_embeddings.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+EMBEDDINGS_FILE="data/wiki2vec.txt"
+if [[ ! -f "$EMBEDDINGS_FILE" ]]; then
+    mkdir -p data
+    wget "http://wikipedia2vec.s3.amazonaws.com/models/en/2018-04-20/enwiki_20180420_100d.txt.bz2" -O "$EMBEDDINGS_FILE.bz2"
+    bunzip2 "$EMBEDDINGS_FILE.bz2"
+else
+    echo "\"$EMBEDDINGS_FILE\" already exists! Skipping..."
+fi
diff --git a/wsbot/__init__.py b/wsbot/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/wsbot/embeddings.py b/wsbot/embeddings.py
new file mode 100644
index 00000000..5f3ce9c5
--- /dev/null
+++ b/wsbot/embeddings.py
@@ -0,0 +1,17 @@
+from abc import ABC, abstractmethod
+
+from wikipedia2vec import Wikipedia2Vec
+
+
+class EmbeddingsProvider(ABC):
+    @abstractmethod
+    def get_embedding(self, article: str):
+        pass
+
+    def get_embeddings(self, articles):
+        return [self.get_embedding(a) for a in articles]
+
+
+class LocalEmbeddings(EmbeddingsProvider):
+    def __init__(self, filename: str):
+        self.wiki2vec = Wikipedia2Vec.load_text(filename)
+
+    def get_embedding(self, article: str):
+        return self.wiki2vec.get_entity_vector(article)
diff --git a/wsbot/graph.py b/wsbot/graph.py
new file mode 100644
index 00000000..4f83e522
--- /dev/null
+++ b/wsbot/graph.py
@@ -0,0 +1,78 @@
+from abc import ABC, abstractmethod
+from wikipedia2vec import Wikipedia2Vec
+
+import pymysql
+from pymysql.cursors import DictCursor
+
+import requests
+
+
+# TODO make these context providers?
+class GraphProvider(ABC):
+    '''
+    Provide the outgoing links and other operations on the Wikipedia graph
+    '''
+
+    @abstractmethod
+    def get_links(self, article):
+        pass
+
+    def get_links_batch(self, articles):
+        return [self.get_links(a) for a in articles]
+
+
+class APIGraph(GraphProvider):
+    '''
+    Graph queries served by the public Wikipedia API
+    '''
+    URL = "https://en.wikipedia.org/w/api.php"
+    PARAMS = {
+        "action": "query",
+        "format": "json",
+        "prop": "links",
+        "pllimit": "max"
+    }
+
+    def __init__(self):
+        pass
+
+    def _links_from_resp(self, resp):
+        links = list(resp["query"]["pages"].values())[0]["links"]
+        links = [link["title"] for link in links]
+        # Drop namespaced pages such as "Category:..." or "Template:..."
+        return list(filter(lambda title: ":" not in title, links))
+
+    def get_links(self, article):
+        resp = requests.get(self.URL, params={**self.PARAMS, "titles": article}).json()
+        return self._links_from_resp(resp)
+
+    def get_links_batch(self, articles):
+        # TODO figure out what happens if this returns too much
+        resp = requests.get(self.URL, params={**self.PARAMS, "titles": "|".join(articles)}).json()
+        return self._links_from_resp(resp)
+
+
+class SQLGraph(GraphProvider):
+    '''
+    Graph queries served by the custom wikipedia speedruns SQL database graph
+    '''
+    def __init__(self, host, user, password, database):
+        self.db = pymysql.connect(host=host, user=user, password=password, database=database)
+        self.cursor = self.db.cursor(cursor=DictCursor)
+
+    def get_links(self, article):
+        id_query = "SELECT * FROM articleid WHERE name=%s"
+        edge_query = """
+            SELECT a.name FROM edgeidarticleid AS e
+            JOIN articleid AS a
+                ON e.dest = a.articleID
+            WHERE e.src = %s
+        """
+
+        self.cursor.execute(id_query, article)
+        row = self.cursor.fetchone()
+        if row is None:
+            return None
+        article_id = row["articleID"]
+
+        self.cursor.execute(edge_query, article_id)
+        return [row["name"] for row in self.cursor.fetchall()]
+
+    # TODO write a query that does this properly
+    #def get_links_batch(self, articles):
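The search code that follows ranks candidate links by cosine distance between wikipedia2vec entity vectors, using the providers above. A minimal sketch of that ranking step (article titles are made up; assumes `data/wiki2vec.txt` has been downloaded by get_embeddings.sh):

```python
# Sketch of the distance heuristic used by the searches below.
# Titles are illustrative; requires data/wiki2vec.txt from get_embeddings.sh.
from scipy.spatial import distance
from wsbot.embeddings import LocalEmbeddings

embeddings = LocalEmbeddings("data/wiki2vec.txt")
target = embeddings.get_embedding("Philosophy")

for candidate in ["Plato", "Potato"]:
    try:
        d = distance.cosine(embeddings.get_embedding(candidate), target)
    except KeyError:
        continue  # titles without an embedding are skipped, as in the searches
    print(candidate, d)  # smaller distance = closer to the target article
```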
diff --git a/wsbot/search.py b/wsbot/search.py
new file mode 100644
index 00000000..e3b88d9a
--- /dev/null
+++ b/wsbot/search.py
@@ -0,0 +1,152 @@
+import scipy
+from scipy.spatial import distance
+
+# TODO base class
+
+class MaxIterationsException(Exception):
+    pass
+
+class PathNotFoundException(Exception):
+    pass
+
+
+class GreedySearch:
+    def __init__(self, embedding_provider, graph_provider, max_iterations=20):
+        self.embeddings = embedding_provider
+        self.graph = graph_provider
+        self.max_iterations = max_iterations
+
+    # Returns the next link to go to based on a greedy approach
+    def get_next_greedy_link(self, start: str, end: str):
+        min_dist = 2  # cosine distance is bounded by 2
+        next_article = ""
+        end_v = self.embeddings.get_embedding(end)
+
+        for link in self.graph.get_links(start):
+            if link == end:
+                return link
+            try:
+                cur_v = self.embeddings.get_embedding(link)
+            except KeyError:
+                # no embedding for this article, skip it
+                continue
+            dist = distance.cosine(cur_v, end_v)
+            if dist <= min_dist:
+                next_article = link
+                min_dist = dist
+
+        if next_article == "":
+            raise PathNotFoundException(
+                f"GreedySearch: could not find a link from '{start}' towards '{end}'")
+        return next_article
+
+    def search(self, start: str, end: str):
+        # Greedily searches the wikipedia graph
+
+        # Replace the implementation below with this code to use the
+        # get_next_greedy_link helper; the original implementation is
+        # currently the one that runs.
+        # cur = start
+        # ret = [start, ]
+        #
+        # for i in range(self.max_iterations):
+        #     next_article = self.get_next_greedy_link(cur, end)
+        #     ret.append(next_article)
+        #     if next_article == end:
+        #         return ret
+        #     cur = next_article
+        #
+        # raise MaxIterationsException(
+        #     f"GreedySearch: Max iterations {self.max_iterations} reached, current path: {ret}")
+
+        cur = start
+        end_v = self.embeddings.get_embedding(end)
+
+        ret = [start, ]
+
+        for i in range(self.max_iterations):
+            min_dist = 2
+            next_article = ""
+
+            for link in self.graph.get_links(cur):
+                # don't revisit articles already on the path
+                if link in ret:
+                    continue
+
+                if link == end:
+                    #print(f"Found link in {cur}!")
+                    ret.append(link)
+                    return ret
+
+                try:
+                    cur_v = self.embeddings.get_embedding(link)
+                except KeyError:
+                    continue
+
+                dist = distance.cosine(cur_v, end_v)
+
+                if dist <= min_dist:
+                    next_article = link
+                    min_dist = dist
+
+            if next_article == "":
+                raise PathNotFoundException(f"GreedySearch: could not find path, current: {ret}")
+
+            ret.append(next_article)
+            cur = next_article
+
+        raise MaxIterationsException(f"GreedySearch: Max iterations {self.max_iterations} reached, current path: {ret}")
+
+class BeamSearch:
+    def __init__(self, embedding_provider, graph_provider, max_iterations=20, width=10):
+        self.embeddings = embedding_provider
+        self.graph = graph_provider
+        self.max_iterations = max_iterations
+        self.width = width
+
+    def _get_path(self, end, parent):
+        # Walk the parent pointers back from `end` to the start article
+        # (which is its own parent) and return the path in order.
+        ret = []
+        cur = end
+        while parent[cur] != cur:
+            ret.append(cur)
+            cur = parent[cur]
+
+        ret.append(cur)
+        return list(reversed(ret))
+
+    def search(self, start: str, end: str):
+        # Define distance metric
+        # TODO customizable
+        end_v = self.embeddings.get_embedding(end)
+        def get_dist(article):
+            try:
+                cur_v = self.embeddings.get_embedding(article)
+            except KeyError:
+                # no embedding: rank this article behind everything else
+                return 100
+            return distance.cosine(cur_v, end_v)
+
+        # Beam-searches the wikipedia graph, keeping the `width` closest
+        # articles at each iteration.
+        cur_set = [start]
+        # Keeps track of parent articles, also serves as the visited set
+        parent = {start: start}
+
+        for i in range(self.max_iterations):
+            next_set = []
+            for article in cur_set:
+                outgoing = self.graph.get_links(article)
+                for link in outgoing:
+                    if link in parent:
+                        continue
+                    parent[link] = article
+                    next_set.append((get_dist(link), link))
+
+                    if link == end:
+                        return self._get_path(link, parent)
+
+            cur_set = [article for (_, article) in sorted(next_set)]
+            cur_set = cur_set[:self.width]
+            print(f"Articles in iteration {i}: ", cur_set)
+
+        raise MaxIterationsException(f"BeamSearch: Max iterations {self.max_iterations} reached")
+
+# TODO probabilistic search (for random results)
+# TODO other heuristics
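For context, here is a rough sketch of how the new wsbot pieces compose outside of Flask (article titles are placeholders; `APIGraph` needs network access and the searches may raise the exceptions defined above):

```python
# Standalone usage sketch of the wsbot package; not part of the diff.
# Titles are placeholders; requires data/wiki2vec.txt and network access.
from wsbot.embeddings import LocalEmbeddings
from wsbot.graph import APIGraph
from wsbot.search import (
    GreedySearch, BeamSearch, PathNotFoundException, MaxIterationsException
)

embeddings = LocalEmbeddings("data/wiki2vec.txt")
graph = APIGraph()

try:
    print("Greedy path:", GreedySearch(embeddings, graph).search("Fruit", "Philosophy"))
except (PathNotFoundException, MaxIterationsException) as e:
    print("Greedy search failed:", e)

# BeamSearch keeps the `width` closest candidates per iteration instead of one.
print("Beam path:", BeamSearch(embeddings, graph, width=5).search("Fruit", "Philosophy"))
```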