diff --git a/src/cryptojwt/__init__.py b/src/cryptojwt/__init__.py index 1ae167c..4444478 100644 --- a/src/cryptojwt/__init__.py +++ b/src/cryptojwt/__init__.py @@ -21,7 +21,7 @@ except ImportError: pass -__version__ = "1.2.0" +__version__ = "1.3.0" logger = logging.getLogger(__name__) diff --git a/src/cryptojwt/exception.py b/src/cryptojwt/exception.py index b56fd92..74a9d8c 100644 --- a/src/cryptojwt/exception.py +++ b/src/cryptojwt/exception.py @@ -63,10 +63,6 @@ class UpdateFailed(KeyIOError): pass -class UnknownKeytype(Invalid): - """An unknown key type""" - - class JWKException(JWKESTException): pass diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 9b2f200..4a30faf 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -4,6 +4,7 @@ import logging import os import time +from datetime import datetime from functools import cmp_to_key import requests @@ -156,6 +157,7 @@ def __init__( keys=None, source="", cache_time=300, + ignore_errors_period=0, fileformat="jwks", keytype="RSA", keyusage=None, @@ -188,6 +190,8 @@ def __init__( self.remote = False self.local = False self.cache_time = cache_time + self.ignore_errors_period = ignore_errors_period + self.ignore_errors_until = None # UNIX timestamp of last error self.time_out = 0 self.etag = "" self.source = None @@ -314,7 +318,11 @@ def do_local_jwk(self, filename): Load a JWKS from a local file :param filename: Name of the file from which the JWKS should be loaded + :return: True if load was successful or False if file hasn't been modified """ + if not self._local_update_required(): + return False + LOGGER.info("Reading local JWKS from %s", filename) with open(filename) as input_file: _info = json.load(input_file) @@ -324,6 +332,7 @@ def do_local_jwk(self, filename): self.do_keys([_info]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time + return True def do_local_der(self, filename, keytype, keyusage=None, kid=""): """ @@ -332,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""): :param filename: Name of the file :param keytype: Presently 'rsa' and 'ec' supported :param keyusage: encryption ('enc') or signing ('sig') or both + :return: True if load was successful or False if file hasn't been modified """ + if not self._local_update_required(): + return False + LOGGER.info("Reading local DER from %s", filename) key_args = {} _kty = keytype.lower() @@ -355,16 +368,25 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""): self.do_keys([key_args]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time + return True def do_remote(self): """ Load a JWKS from a webpage. - :return: True or False if load was successful + :return: True if load was successful or False if remote hasn't been modified """ # if self.verify_ssl is not None: # self.httpc_params["verify"] = self.verify_ssl + if self.ignore_errors_until and time.time() < self.ignore_errors_until: + LOGGER.warning( + "Not reading remote JWKS from %s (in error holddown until %s)", + self.source, + datetime.fromtimestamp(self.ignore_errors_until), + ) + return False + LOGGER.info("Reading remote JWKS from %s", self.source) try: LOGGER.debug("KeyBundle fetch keys from: %s", self.source) @@ -378,7 +400,10 @@ def do_remote(self): LOGGER.error(err) raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err))) - if _http_resp.status_code == 200: # New content + load_successful = _http_resp.status_code == 200 + not_modified = _http_resp.status_code == 304 + + if load_successful: self.time_out = time.time() + self.cache_time self.imp_jwks = self._parse_remote_response(_http_resp) @@ -390,25 +415,27 @@ def do_remote(self): self.do_keys(self.imp_jwks["keys"]) except KeyError: LOGGER.error("No 'keys' keyword in JWKS") + self.ignore_errors_until = time.time() + self.ignore_errors_period raise UpdateFailed(MALFORMED.format(self.source)) if hasattr(_http_resp, "headers"): headers = getattr(_http_resp, "headers") self.last_remote = headers.get("last-modified") or headers.get("date") - - elif _http_resp.status_code == 304: # Not modified + elif not_modified: LOGGER.debug("%s not modified since %s", self.source, self.last_remote) self.time_out = time.time() + self.cache_time - else: LOGGER.warning( "HTTP status %d reading remote JWKS from %s", _http_resp.status_code, self.source, ) + self.ignore_errors_until = time.time() + self.ignore_errors_period raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code)) + self.last_updated = time.time() - return True + self.ignore_errors_until = None + return load_successful def _parse_remote_response(self, response): """ @@ -433,14 +460,10 @@ def _parse_remote_response(self, response): return None def _uptodate(self): - res = False if self.remote or self.local: if time.time() > self.time_out: - if self.local and not self._local_update_required(): - res = True - elif self.update(): - res = True - return res + return self.update() + return False def update(self): """ @@ -448,8 +471,9 @@ def update(self): This is a forced update, will happen even if cache time has not elapsed. Replaced keys will be marked as inactive and not removed. + + :return: True if update was ok or False if we encountered an error during update. """ - res = True # An update was successful if self.source: _old_keys = self._keys # just in case @@ -459,24 +483,27 @@ def update(self): try: if self.local: if self.fileformat in ["jwks", "jwk"]: - self.do_local_jwk(self.source) + updated = self.do_local_jwk(self.source) elif self.fileformat == "der": - self.do_local_der(self.source, self.keytype, self.keyusage) + updated = self.do_local_der(self.source, self.keytype, self.keyusage) elif self.remote: - res = self.do_remote() + updated = self.do_remote() except Exception as err: LOGGER.error("Key bundle update failed: %s", err) self._keys = _old_keys # restore return False - now = time.time() - for _key in _old_keys: - if _key not in self._keys: - if not _key.inactive_since: # If already marked don't mess - _key.inactive_since = now - self._keys.append(_key) + if updated: + now = time.time() + for _key in _old_keys: + if _key not in self._keys: + if not _key.inactive_since: # If already marked don't mess + _key.inactive_since = now + self._keys.append(_key) + else: + self._keys = _old_keys - return res + return True def get(self, typ="", only_active=True): """ diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index 7d25b39..7d12026 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -17,6 +17,7 @@ from cryptojwt.jwk.rsa import import_rsa_key_from_cert_file from cryptojwt.jwk.rsa import new_rsa_key from cryptojwt.key_bundle import KeyBundle +from cryptojwt.key_bundle import UpdateFailed from cryptojwt.key_bundle import build_key_bundle from cryptojwt.key_bundle import dump_jwks from cryptojwt.key_bundle import init_key @@ -566,6 +567,7 @@ def test_update_2(): ec_key = new_ec_key(crv="P-256", key_ops=["sign"]) _jwks = {"keys": [rsa_key.serialize(), ec_key.serialize()]} + time.sleep(0.5) with open(fname, "w") as fp: fp.write(json.dumps(_jwks)) @@ -1008,7 +1010,7 @@ def test_remote_not_modified(): with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, status=304, headers=headers) - assert kb.do_remote() + assert not kb.do_remote() assert kb.last_remote == headers.get("Last-Modified") timeout2 = kb.time_out @@ -1018,9 +1020,50 @@ def test_remote_not_modified(): kb2 = KeyBundle().load(exp) assert kb2.source == source assert len(kb2.keys()) == 3 + assert len(kb2.active_keys()) == 3 assert len(kb2.get("rsa")) == 1 assert len(kb2.get("oct")) == 1 assert len(kb2.get("ec")) == 1 assert kb2.httpc_params == {"timeout": (2, 2)} assert kb2.imp_jwks assert kb2.last_updated + + +def test_ignore_errors_period(): + source_good = "https://example.com/keys.json" + source_bad = "https://example.com/keys-bad.json" + ignore_errors_period = 1 + # Mock response + with responses.RequestsMock() as rsps: + rsps.add(method="GET", url=source_good, json=JWKS_DICT, status=200) + rsps.add(method="GET", url=source_bad, json=JWKS_DICT, status=500) + httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds + kb = KeyBundle( + source=source_good, + httpc=requests.request, + httpc_params=httpc_params, + ignore_errors_period=ignore_errors_period, + ) + res = kb.do_remote() + assert res == True + assert kb.ignore_errors_until is None + + # refetch, but fail by using a bad source + kb.source = source_bad + try: + res = kb.do_remote() + except UpdateFailed: + pass + + # retry should fail silently as we're in holddown + res = kb.do_remote() + assert kb.ignore_errors_until is not None + assert res == False + + # wait until holddown + time.sleep(ignore_errors_period + 1) + + # try again + kb.source = source_good + res = kb.do_remote() + assert res == True diff --git a/tests/test_04_key_jar.py b/tests/test_04_key_jar.py index b31e5ba..53cb5ef 100755 --- a/tests/test_04_key_jar.py +++ b/tests/test_04_key_jar.py @@ -746,6 +746,12 @@ def test_aud(self): keys = self.bob_keyjar.get_jwt_verify_keys(_jwt.jwt, no_kid_issuer=no_kid_issuer) assert len(keys) == 1 + def test_inactive_verify_key(self): + _jwt = factory(self.sjwt_b) + self.alice_keyjar.return_issuer("Bob")[0].mark_all_as_inactive() + keys = self.alice_keyjar.get_jwt_verify_keys(_jwt.jwt) + assert len(keys) == 0 + def test_copy(): kj = KeyJar() diff --git a/tox.ini b/tox.ini index def7f92..272756a 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ envlist = py{36,37,38},quality [testenv] passenv = CI TRAVIS TRAVIS_* commands = - py.test --cov=cryptojwt --isort --black {posargs} + pytest -vvv -ra --cov=cryptojwt --isort --black {posargs} codecov extras = testing deps =