diff --git a/src/app.py b/src/app.py index c07e5f8d..b3314577 100644 --- a/src/app.py +++ b/src/app.py @@ -1,6 +1,8 @@ from contextlib import asynccontextmanager import os +import requests # type: ignore + import uvicorn # type: ignore from pathlib import Path from fastapi import ( @@ -44,6 +46,8 @@ BackendStatusDatasetsSchema, AgentSchema, ServerSchema, + LoginSchema, + TokenSchema, ) @@ -70,6 +74,24 @@ def with_plugins() -> Iterable[PluginManager]: plugins.cleanup() +def get_new_tokens(refresh_token): + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": db.config.openid_client_id, + "client_secret": db.config.openid_secret, + } + url = "http://mquery-keycloak-1:8080/auth/realms/myrealm/protocol/openid-connect/token" + try: + response: requests.Response = requests.post(url=url, data=data) + token_data = response.json() + new_refresh_token = token_data["refresh_token"] + new_token = token_data["access_token"] + return new_token, new_refresh_token + except requests.exceptions.RequestException: + return None, None + + class User: def __init__(self, token: Optional[Dict]) -> None: self.__token = token @@ -124,8 +146,11 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User: token_json = jwt.decode( token, public_key, algorithms=["RS256"], audience="account" # type: ignore ) + except jwt.ExpiredSignatureError: + # token expired so user is anonymous + return User(None) except jwt.InvalidTokenError: - # Invalid token means invalid signature, issuer, or just expired. + # Invalid token means invalid signature, issuer. raise unauthorized return User(token_json) @@ -584,6 +609,35 @@ def server() -> ServerSchema: ) +@app.post("/api/login", response_model=LoginSchema, tags=["stable"]) +async def login(request: Request, response: Response) -> LoginSchema: + token = await request.json() + if token["refresh_token"]: + response.set_cookie( + key="refresh_token", + value=token["refresh_token"], + httponly=True, + max_age=1800, + ) + return LoginSchema(status="OK") + return LoginSchema(status="Bad Token") + + +@app.post("/api/token/refresh", response_model=TokenSchema) +def refresh_token(request: Request, response: Response) -> TokenSchema: + refresh_token_value = request.cookies.get("refresh_token") + if refresh_token_value: + new_token, new_refresh_token = get_new_tokens(refresh_token_value) + response.set_cookie( + key="refresh_token", + value=new_refresh_token, + httponly=True, + max_age=1800, + ) + return TokenSchema(token=new_token) + return TokenSchema(token=None) + + @app.get("/query/{path}", include_in_schema=False) def serve_index(path: str) -> FileResponse: return FileResponse(Path(__file__).parent / "mqueryfront/dist/index.html") diff --git a/src/mqueryfront/src/App.js b/src/mqueryfront/src/App.js index a5e93dd8..4b41e12f 100644 --- a/src/mqueryfront/src/App.js +++ b/src/mqueryfront/src/App.js @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from "react"; +import React, { useState, useRef, useEffect } from "react"; import { Routes, Route } from "react-router-dom"; import Navigation from "./Navigation"; import QueryPage from "./query/QueryPage"; @@ -9,6 +9,7 @@ import AboutPage from "./about/AboutPage"; import AuthPage from "./auth/AuthPage"; import api, { parseJWT } from "./api"; import "./App.css"; +import { refreshAccesToken, storeTokenData, clearTokenData } from "./utils"; function getCurrentTokenOrNull() { // This function handles missing and corrupted token in the same way. @@ -21,20 +22,32 @@ function getCurrentTokenOrNull() { function App() { const [config, setConfig] = useState(null); + const tokenIntervalRef = useRef(null); useEffect(() => { api.get("/server").then((response) => { setConfig(response.data); }); + tokenIntervalRef.current = setInterval(() => { + refreshAccesToken(); + }, 900000); // refresh token every 15 minutes just in case user was idle. + return () => clearInterval(tokenIntervalRef.current); }, []); - - const login = (rawToken) => { - localStorage.setItem("rawToken", rawToken); - window.location.href = "/"; + const login = async (token_data) => { + token_data.not_before_policy = token_data["not-before-policy"]; + delete token_data["not-before-policy"]; + const response = await api.post("/login", token_data); + storeTokenData(token_data["access_token"]); + const location_href = localStorage.getItem("currentLocation"); + if (location_href) { + window.location.href = location_href; + } else { + window.location.href = "/"; + } }; const logout = () => { - localStorage.removeItem("rawToken"); + clearTokenData(tokenIntervalRef.current); if (config !== null) { const logout_url = new URL(config["openid_url"] + "/logout"); logout_url.searchParams.append( diff --git a/src/mqueryfront/src/api.js b/src/mqueryfront/src/api.js index 67cdd3c8..7fa719a8 100644 --- a/src/mqueryfront/src/api.js +++ b/src/mqueryfront/src/api.js @@ -1,4 +1,5 @@ import axios from "axios"; +import { refreshAccesToken, tokenExpired } from "./utils"; export const api_url = "/api"; @@ -8,7 +9,11 @@ export function parseJWT(token) { return JSON.parse(atob(base64)); } -function request(method, path, payload, params) { +async function request(method, path, payload, params) { + if (tokenExpired()) { + // If the token expired, try to refresh it with refresh_token + await refreshAccesToken(); + } const rawToken = localStorage.getItem("rawToken"); const headers = rawToken ? { Authorization: `Bearer ${rawToken}` } : {}; return axios @@ -17,6 +22,7 @@ function request(method, path, payload, params) { data: payload, params: params, headers: headers, + withCredentials: true, }) .catch((error) => { if (error.response.status === 401) { diff --git a/src/mqueryfront/src/auth/AuthPage.js b/src/mqueryfront/src/auth/AuthPage.js index 177135c9..c1ea26be 100644 --- a/src/mqueryfront/src/auth/AuthPage.js +++ b/src/mqueryfront/src/auth/AuthPage.js @@ -37,7 +37,7 @@ class AuthPage extends Component { axios .post(this.props.config["openid_url"] + "/token", params) .then((response) => { - this.props.login(response.data["access_token"]); + this.props.login(response.data); }) .catch((error) => { this.setState({ error: error }); diff --git a/src/mqueryfront/src/config/ConfigPage.js b/src/mqueryfront/src/config/ConfigPage.js index e0cc08dc..59591eb5 100644 --- a/src/mqueryfront/src/config/ConfigPage.js +++ b/src/mqueryfront/src/config/ConfigPage.js @@ -14,6 +14,7 @@ class ConfigPage extends Component { } componentDidMount() { + localStorage.setItem("currentLocation", window.location.href); api.get("/config") .then((response) => { this.setState({ config: response.data }); diff --git a/src/mqueryfront/src/query/QueryPage.js b/src/mqueryfront/src/query/QueryPage.js index af99a911..a266cca9 100644 --- a/src/mqueryfront/src/query/QueryPage.js +++ b/src/mqueryfront/src/query/QueryPage.js @@ -38,6 +38,7 @@ class QueryPageInner extends Component { } async componentDidMount() { + localStorage.setItem("currentLocation", window.location.href); if (this.queryHash) { this.fetchJob(); } diff --git a/src/mqueryfront/src/recent/RecentPage.js b/src/mqueryfront/src/recent/RecentPage.js index 9907b9b4..080622c9 100644 --- a/src/mqueryfront/src/recent/RecentPage.js +++ b/src/mqueryfront/src/recent/RecentPage.js @@ -23,6 +23,7 @@ class RecentPage extends Component { } componentDidMount() { + localStorage.setItem("currentLocation", window.location.href); api.get("/job") .then((response) => { const { jobs } = response.data; diff --git a/src/mqueryfront/src/status/StatusPage.js b/src/mqueryfront/src/status/StatusPage.js index 419a6b28..d55b25fb 100644 --- a/src/mqueryfront/src/status/StatusPage.js +++ b/src/mqueryfront/src/status/StatusPage.js @@ -20,6 +20,7 @@ class StatusPage extends Component { } componentDidMount() { + localStorage.setItem("currentLocation", window.location.href); api.get("/backend") .then((response) => { this.setState({ backend: response.data }); diff --git a/src/mqueryfront/src/utils.js b/src/mqueryfront/src/utils.js index 3d054721..750097c6 100644 --- a/src/mqueryfront/src/utils.js +++ b/src/mqueryfront/src/utils.js @@ -1,3 +1,5 @@ +import axios from "axios"; +import api, { parseJWT } from "./api"; export const isStatusFinished = (status) => ["done", "cancelled"].includes(status); @@ -25,3 +27,45 @@ export const openidLoginUrl = (config) => { ); return login_url; }; + +export const storeTokenData = (token) => { + localStorage.setItem("rawToken", token); + const decodedToken = parseJWT(token); + localStorage.setItem("expiresAt", decodedToken.exp * 1000); +}; + +export const refreshAccesToken = async () => { + const rawToken = localStorage.getItem("rawToken"); + const expiresAt = localStorage.getItem("expiresAt"); + if (rawToken) { + const headers = rawToken ? { Authorization: `Bearer ${rawToken}` } : {}; + const response = await axios.request("/api/token/refresh", { + method: "POST", + headers: headers, + withCredentials: true, + }); + if (response.data["token"]) { + storeTokenData(response.data["token"]); + } else { + return; + } + } +}; + +export const clearTokenData = (tokenInterval) => { + clearInterval(tokenInterval); + localStorage.removeItem("expiresAt"); + localStorage.removeItem("rawToken"); +}; + +export const tokenExpired = () => { + const rawToken = localStorage.getItem("rawToken"); + if (rawToken) { + const expiresAt = localStorage.getItem("expiresAt"); + if (Date.now() > expiresAt) { + return true; + } + return false; + } + return false; +}; diff --git a/src/schema.py b/src/schema.py index 5b2ab8ff..7e67fc7a 100644 --- a/src/schema.py +++ b/src/schema.py @@ -105,3 +105,11 @@ class ServerSchema(BaseModel): openid_url: Optional[str] openid_client_id: Optional[str] about: str + + +class LoginSchema(BaseModel): + status: str + + +class TokenSchema(BaseModel): + token: str | None