diff --git a/pyproject.toml b/pyproject.toml index b51657a..48ec79d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ exclude_lines = [ [tool.poetry] name = "cryptojwt" -version = "1.8.0" +version = "1.8.1" description = "Python implementation of JWT, JWE, JWS and JWK" authors = ["Roland Hedberg "] license = "Apache-2.0" diff --git a/src/cryptojwt/jwe/fernet.py b/src/cryptojwt/jwe/fernet.py index 3cca512..90b02d3 100644 --- a/src/cryptojwt/jwe/fernet.py +++ b/src/cryptojwt/jwe/fernet.py @@ -16,26 +16,34 @@ class FernetEncrypter(Encrypter): def __init__( self, - password: str, + password: Optional[str] = None, salt: Optional[bytes] = "", + key: Optional[bytes] = None, hash_alg: Optional[str] = "SHA256", digest_size: Optional[int] = 0, iterations: Optional[int] = DEFAULT_ITERATIONS, ): Encrypter.__init__(self) - if not salt: - salt = os.urandom(16) - else: - salt = as_bytes(salt) - _alg = getattr(hashes, hash_alg) - # A bit special for SHAKE* and BLAKE* hashes - if hash_alg.startswith("SHAKE") or hash_alg.startswith("BLAKE"): - _algorithm = _alg(digest_size) + if password is not None: + _alg = getattr(hashes, hash_alg) + # A bit special for SHAKE* and BLAKE* hashes + if hash_alg.startswith("SHAKE") or hash_alg.startswith("BLAKE"): + _algorithm = _alg(digest_size) + else: + _algorithm = _alg() + salt = as_bytes(salt) if salt else os.urandom(16) + kdf = PBKDF2HMAC(algorithm=_algorithm, length=32, salt=salt, iterations=iterations) + self.key = base64.urlsafe_b64encode(kdf.derive(as_bytes(password))) + elif key is not None: + if not isinstance(key, bytes): + raise TypeError("Raw key must be bytes") + if len(key) != 32: + raise ValueError("Raw key must be 32 bytes") + self.key = base64.urlsafe_b64encode(key) else: - _algorithm = _alg() - kdf = PBKDF2HMAC(algorithm=_algorithm, length=32, salt=salt, iterations=iterations) - self.key = base64.urlsafe_b64encode(kdf.derive(as_bytes(password))) + self.key = Fernet.generate_key() + self.core = Fernet(self.key) def encrypt(self, msg: Union[str, bytes], **kwargs) -> bytes: diff --git a/tests/test_07_jwe.py b/tests/test_07_jwe.py index ed6197b..82a3160 100644 --- a/tests/test_07_jwe.py +++ b/tests/test_07_jwe.py @@ -648,10 +648,46 @@ def test_invalid(): decrypter.decrypt("a.b.c.d.e", keys=[encryption_key]) -def test_fernet(): +def test_fernet_password(): + encrypter = FernetEncrypter(password="DukeofHazardpass") + _token = encrypter.encrypt(plain) + + decrypter = encrypter + resp = decrypter.decrypt(_token) + assert resp == plain + + +def test_fernet_symkey(): encryption_key = SYMKey(use="enc", key="DukeofHazardpass", kid="some-key-id") - encrypter = FernetEncrypter(encryption_key.key) + encrypter = FernetEncrypter(password=encryption_key.key) + _token = encrypter.encrypt(plain) + + decrypter = encrypter + resp = decrypter.decrypt(_token) + assert resp == plain + + +def test_fernet_bad(): + with pytest.raises(TypeError): + encrypter = FernetEncrypter(key="xyzzy") + with pytest.raises(ValueError): + encrypter = FernetEncrypter(key=os.urandom(16)) + + +def test_fernet_bytes(): + key = os.urandom(32) + + encrypter = FernetEncrypter(key=key) + _token = encrypter.encrypt(plain) + + decrypter = encrypter + resp = decrypter.decrypt(_token) + assert resp == plain + + +def test_fernet_default_key(): + encrypter = FernetEncrypter() _token = encrypter.encrypt(plain) decrypter = encrypter @@ -662,7 +698,7 @@ def test_fernet(): def test_fernet_sha512(): encryption_key = SYMKey(use="enc", key="DukeofHazardpass", kid="some-key-id") - encrypter = FernetEncrypter(encryption_key.key, hash_alg="SHA512") + encrypter = FernetEncrypter(password=encryption_key.key, hash_alg="SHA512") _token = encrypter.encrypt(plain) decrypter = encrypter @@ -674,7 +710,7 @@ def test_fernet_blake2s(): encryption_key = SYMKey(use="enc", key="DukeofHazardpass", kid="some-key-id") encrypter = FernetEncrypter( - encryption_key.key, hash_alg="BLAKE2s", digest_size=32, iterations=1000 + password=encryption_key.key, hash_alg="BLAKE2s", digest_size=32, iterations=1000 ) _token = encrypter.encrypt(plain)