Skip to content

Commit

Permalink
Removed error prone key looping
Browse files Browse the repository at this point in the history
  • Loading branch information
kuranium committed Feb 12, 2025
1 parent ec2ca6f commit a07b9e1
Showing 1 changed file with 31 additions and 48 deletions.
79 changes: 31 additions & 48 deletions redash/authentication/jwt_auth.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,65 @@
import json
import logging

import jwt
import requests
from jwt import PyJWKClient


logger = logging.getLogger("jwt_auth")

FILE_SCHEME_PREFIX = "file://"


def get_public_key_from_file(url):
def get_signing_key_from_file(url):
file_path = url[len(FILE_SCHEME_PREFIX) :]
with open(file_path) as key_file:
key_str = key_file.read()

get_public_keys.key_cache[url] = [key_str]
get_signing_key.key_cache[url] = key_str
return key_str


def get_public_key_from_net(url):
r = requests.get(url)
r.raise_for_status()
data = r.json()
if "keys" in data:
public_keys = []
for key_dict in data["keys"]:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key_dict))
public_keys.append(public_key)

get_public_keys.key_cache[url] = public_keys
return public_keys
else:
get_public_keys.key_cache[url] = data
return data
def get_signing_key_from_net(url, jwt_token):
optional_custom_headers = {"User-agent": "redash"}
client = PyJWKClient(url, headers=optional_custom_headers)
# Gets the matching signing key from the JWKS endpoint
signing_key = client.get_signing_key_from_jwt(jwt_token)
get_signing_key.key_cache[url] = signing_key
return signing_key


def get_public_keys(url):
def get_signing_key(url, jwt_token):
"""
Returns:
List of RSA public keys usable by PyJWT.
Signing key for given jwt_token.
"""
key_cache = get_public_keys.key_cache
keys = {}
key_cache = get_signing_key.key_cache
key = {}
if url in key_cache:
keys = key_cache[url]
key = key_cache[url]
else:
if url.startswith(FILE_SCHEME_PREFIX):
keys = [get_public_key_from_file(url)]
key = [get_signing_key_from_file(url)]
else:
keys = get_public_key_from_net(url)
return keys


get_public_keys.key_cache = {}
key = get_signing_key_from_net(url, jwt_token)
return key

#This cache shoud have a lifespan
get_signing_key.key_cache = {}

def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms, public_certs_url):
# https://developers.cloudflare.com/access/setting-up-access/validate-jwt-tokens/
# https://cloud.google.com/iap/docs/signed-headers-howto
# Loop through the keys since we can't pass the key set to the decoder
keys = get_public_keys(public_certs_url)

key_id = jwt.get_unverified_header(jwt_token).get("kid", "")
if key_id and isinstance(keys, dict):
keys = [keys.get(key_id)]

key = get_signing_key(public_certs_url, jwt_token)
valid_token = False
payload = None
for key in keys:
try:
# decode returns the claims which has the email if you need it
payload = jwt.decode(jwt_token, key=key, audience=expected_audience, algorithms=algorithms)
issuer = payload["iss"]
if issuer != expected_issuer:
raise Exception("Wrong issuer: {}".format(issuer))
valid_token = True
break
except Exception as e:
logging.exception(e)
try:
# decode returns the claims which has the email if you need it
payload = jwt.decode(jwt_token, key=key, audience=expected_audience, algorithms=algorithms)
issuer = payload["iss"]
if issuer != expected_issuer:
raise Exception("Wrong issuer: {}".format(issuer))
valid_token = True
except Exception as e:
logging.exception(e)

return payload, valid_token

0 comments on commit a07b9e1

Please sign in to comment.