From d34228acaf64943ded7c8fd97700e3edc9b65383 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Mon, 6 Jan 2025 09:11:50 +0100 Subject: [PATCH 1/7] Add new API --- jwt_rsa/token.py | 201 +++++++++++++++++++++++++---------------------- 1 file changed, 106 insertions(+), 95 deletions(-) diff --git a/jwt_rsa/token.py b/jwt_rsa/token.py index 0340939..4db49a0 100644 --- a/jwt_rsa/token.py +++ b/jwt_rsa/token.py @@ -1,8 +1,9 @@ import time +from dataclasses import dataclass, field from datetime import datetime, timedelta from operator import add, sub from typing import ( - TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union, + TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union, overload, ) from jwt import PyJWT @@ -19,103 +20,113 @@ R = TypeVar("R") DAY = 86400 - - -class JWT: - __slots__ = ( - "__private_key", "__public_key", "__jwt", - "__expires", "__nbf_delta", "__algorithm", - "__algorithms", - ) - - DEFAULT_EXPIRATION = timedelta(days=31).total_seconds() - NBF_DELTA = 20 - ALGORITHMS = tuple(AlgorithmType.__args__) +DEFAULT_EXPIRATION = timedelta(days=31).total_seconds() +NBF_DELTA = 20 +ALGORITHMS = tuple(AlgorithmType.__args__) + + +def date_to_timestamp( + value: DateType, + default: Callable[[], R], + timedelta_func: Callable[[float, float], int] = add, +) -> Union[int, float, R]: + if isinstance(value, timedelta): + return timedelta_func(time.time(), value.total_seconds()) + elif isinstance(value, datetime): + return value.timestamp() + elif isinstance(value, (int, float)): + return value + elif value is Ellipsis: + return default() + + raise ValueError(type(value)) + + +@dataclass(frozen=True, init=False) +class JWTDecoder: + jwt: PyJWT = field(repr=False, compare=False) + public_key: RSAPublicKey = field(repr=False, compare=False) + expires: Union[int, float] + nbf_delta: Union[int, float] + algorithm: AlgorithmType + algorithms: Sequence[AlgorithmType] def __init__( self, - key: Optional[Union[RSAPrivateKey, RSAPublicKey]], - *, expires: Optional[int] = None, - nbf_delta: Optional[int] = None, + key: RSAPublicKey, + *, options: dict[str, Any] | None = None, + expires: int | float = DEFAULT_EXPIRATION, + nbf_delta: int | float = NBF_DELTA, algorithm: AlgorithmType = "RS512", algorithms: Sequence[AlgorithmType] = ALGORITHMS, - options: Optional[Dict[str, Any]] = None, ): - self.__public_key: RSAPublicKey - self.__private_key: Optional[RSAPrivateKey] - - if isinstance(key, RSAPrivateKey): - self.__public_key = key.public_key() - self.__private_key = key - elif isinstance(key, RSAPublicKey): - self.__public_key = key - self.__private_key = None - else: - raise ValueError("You must provide either a public or private key") - - self.__jwt = PyJWT(options) - self.__expires = expires or self.DEFAULT_EXPIRATION - self.__nbf_delta = nbf_delta or self.NBF_DELTA - self.__algorithm = algorithm - self.__algorithms = list(algorithms) - - @staticmethod - def _date_to_timestamp( - value: DateType, - default: Callable[[], R], - timedelta_func: Callable[[float, float], int] = add, - ) -> Union[int, float, R]: - if isinstance(value, timedelta): - return timedelta_func(time.time(), value.total_seconds()) - elif isinstance(value, datetime): - return value.timestamp() - elif isinstance(value, (int, float)): - return value - elif value is Ellipsis: - return default() - - raise ValueError(type(value)) - - def encode( - self, - expired: DateType = ..., - nbf: DateType = ..., - **claims: Any, - ) -> str: - if not self.__private_key: - raise RuntimeError("Can't encode without private key") - - claims.update( - dict( - exp=int( - self._date_to_timestamp( - expired, - lambda: time.time() + self.__expires, - ), - ), - nbf=int( - self._date_to_timestamp( - nbf, - lambda: time.time() - self.__nbf_delta, - timedelta_func=sub, - ), - ), - ), - ) - - return self.__jwt.encode( - claims, - self.__private_key, - algorithm=self.__algorithm, - ) - - def decode( - self, token: str, verify: bool = True, **kwargs: Any, - ) -> Dict[str, Any]: - return self.__jwt.decode( - token, - key=self.__public_key, - verify=verify, - algorithms=self.__algorithms, - **kwargs, - ) + super().__setattr__('public_key', key) + super().__setattr__('jwt', PyJWT(options)) + super().__setattr__('expires', expires) + super().__setattr__('nbf_delta', nbf_delta) + super().__setattr__('algorithm', algorithm) + super().__setattr__('algorithms', algorithms) + + def decode(self, token: str, verify: bool = True, **kwargs: Any) -> Dict[str, Any]: + return self.jwt.decode(token, key=self.public_key, verify=verify, algorithms=self.algorithms, **kwargs) + + +@dataclass(frozen=True, init=False) +class JWTSigner(JWTDecoder): + private_key: RSAPrivateKey = field(repr=False, compare=False) + + def __init__(self, key: RSAPrivateKey, *, options: Optional[Dict[str, Any]] = None, **kwargs: Any): + super(JWTDecoder, self).__setattr__('private_key', key) + super().__init__(key.public_key(), options=options, **kwargs) + + def encode(self, expired: DateType = ..., nbf: DateType = ..., **claims: Any) -> str: + claims.setdefault('exp', int(date_to_timestamp(expired, lambda: time.time() + self.expires))) + claims.setdefault('nbf', int(date_to_timestamp(nbf, lambda: time.time() - self.nbf_delta, timedelta_func=sub))) + return self.jwt.encode(claims, self.private_key, algorithm=self.algorithm) + + +@overload +def JWT( + key: RSAPrivateKey, *, + options: dict[str, Any] | None = None, + expires: int | float = DEFAULT_EXPIRATION, + nbf_delta: int | float = NBF_DELTA, + algorithm: AlgorithmType = "RS512", + algorithms: Sequence[AlgorithmType] = ALGORITHMS, +) -> JWTSigner: ... + + +@overload +def JWT( # type: ignore[overload-cannot-match] + key: RSAPublicKey, *, + options: dict[str, Any] | None = None, + expires: int | float = DEFAULT_EXPIRATION, + nbf_delta: int | float = NBF_DELTA, + algorithm: AlgorithmType = "RS512", + algorithms: Sequence[AlgorithmType] = ALGORITHMS, +) -> JWTDecoder: ... + + +def JWT( + key: Union[RSAPrivateKey, RSAPublicKey], + *, + options: dict[str, Any] | None = None, + expires: int | float = DEFAULT_EXPIRATION, + nbf_delta: int | float = NBF_DELTA, + algorithm: AlgorithmType = "RS512", + algorithms: Sequence[AlgorithmType] = ALGORITHMS, +) -> Union[JWTSigner, JWTDecoder]: + kwargs = dict( + expires=expires, + nbf_delta=nbf_delta, + algorithm=algorithm, + algorithms=algorithms, + options=options, + ) + + if isinstance(key, RSAPrivateKey): + return JWTSigner(key, **kwargs) + elif isinstance(key, RSAPublicKey): + return JWTDecoder(key, **kwargs) + else: + raise TypeError(f"Invalid key type: {type(key)}") From cc0c634bedc2a4e9a4d34d4e6524b09854569ea6 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Mon, 6 Jan 2025 09:12:00 +0100 Subject: [PATCH 2/7] More tests --- tests/test_cli.py | 50 ++++++++++++++++++++--------------------------- tests/test_rsa.py | 4 ++-- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 32a5a8e..20bec3f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,8 +1,6 @@ import io import json import os -from io import StringIO -from unittest import mock import pytest from cryptography.exceptions import InvalidSignature @@ -13,14 +11,11 @@ from jwt_rsa.cli import parser from jwt_rsa.key_tester import main as verify from jwt_rsa.keygen import main as keygen -from jwt_rsa.rsa import ( - generate_rsa, load_private_key, load_public_key, rsa_to_jwk, -) +from jwt_rsa.rsa import load_private_key, load_public_key def test_rsa_keygen(capsys): - with mock.patch("sys.argv", ["jwt-rsa", "keygen", "--raw", "-o", "jwk"]): - keygen(parser.parse_args()) + keygen(parser.parse_args(["keygen", "--raw", "-o", "jwk"])) stdout, stderr = capsys.readouterr() @@ -57,8 +52,7 @@ def test_rsa_keygen(capsys): def test_pem_format(capsys): - with mock.patch("sys.argv", ["jwt-rsa", "keygen", "-o", "pem"]): - keygen(parser.parse_args()) + keygen(parser.parse_args(["keygen", "-o", "pem"])) stdout, stderr = capsys.readouterr() @@ -160,29 +154,27 @@ def test_keygen_public_key_auto_naming(capsys, tmp_path): assert private_content != private_path.read_text() -@pytest.mark.skip(reason="TODO") -def test_rsa_verify(capsys): - with mock.patch("sys.argv", ["jwt-rsa", "keygen"]): - keygen(parser.parse_args()) +@pytest.mark.parametrize("fmt", ["jwk", "pem", "base64"]) +def test_rsa_verify(fmt, capsys, tmp_path): + private_path = tmp_path / "private" + public_path = tmp_path / "public" + keygen(parser.parse_args(["keygen", "-o", fmt, "-K", str(private_path), "-k", str(public_path)])) + verify(parser.parse_args(["testkey", "-K", str(private_path), "-k", str(public_path)])) stdout, stderr = capsys.readouterr() + assert "Signing OK" in stderr + assert "Verifying OK" in stderr - with mock.patch("sys.stdin", StringIO(stdout)): - verify(parser.parse_args()) - -@pytest.mark.skip(reason="TODO") -def test_rsa_verify_bad_key(): - private1, public1 = generate_rsa() - private2, public2 = generate_rsa() - - data = json.dumps( - { - "private_jwk": rsa_to_jwk(private1), - "public_jwk": rsa_to_jwk(public2), - }, indent=" ", sort_keys=True, +@pytest.mark.parametrize("fmt", ["jwk", "pem", "base64"]) +def test_rsa_verify_bad_key(fmt, capsys, tmp_path): + keys = ( + (tmp_path / "private1", tmp_path / "public1"), + (tmp_path / "private2", tmp_path / "public2"), ) - with mock.patch("sys.stdin", StringIO(data)): - with pytest.raises(InvalidSignature): - verify(parser.parse_args()) + for private_path, public_path in keys: + keygen(parser.parse_args(["keygen", "-o", fmt, "-K", str(private_path), "-k", str(public_path)])) + + with pytest.raises(InvalidSignature): + verify(parser.parse_args(["testkey", "-K", str(keys[0][0]), "-k", str(keys[1][1])])) diff --git a/tests/test_rsa.py b/tests/test_rsa.py index 2d6203f..fe8d49c 100644 --- a/tests/test_rsa.py +++ b/tests/test_rsa.py @@ -116,7 +116,7 @@ def test_decode_only_ability(): jwt = JWT(public) assert "foo" in jwt.decode(token) - with pytest.raises(RuntimeError): + with pytest.raises(AttributeError): jwt.encode(foo=None) @@ -128,7 +128,7 @@ def test_jwt_init(): assert JWT(public) - with pytest.raises(ValueError): + with pytest.raises(TypeError): JWT(None) From 207dd0f1d34d9bb5599274f9f628f2802d732641 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Mon, 6 Jan 2025 09:14:51 +0100 Subject: [PATCH 3/7] Add to init --- jwt_rsa/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jwt_rsa/__init__.py b/jwt_rsa/__init__.py index d6af352..1172aa9 100644 --- a/jwt_rsa/__init__.py +++ b/jwt_rsa/__init__.py @@ -2,12 +2,14 @@ RSAJWKPrivateKey, RSAJWKPublicKey, generate_rsa, load_private_key, load_public_key, rsa_to_jwk, ) -from .token import JWT +from .token import JWT, JWTDecoder, JWTSigner from .types import RSAPrivateKey, RSAPublicKey __all__ = ( "JWT", + "JWTDecoder", + "JWTSigner", "RSAJWKPrivateKey", "RSAJWKPublicKey", "RSAPrivateKey", From 49b1847d5e3a1ec81bd733b72cfbcbb4e15d82d3 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Mon, 6 Jan 2025 09:23:41 +0100 Subject: [PATCH 4/7] Comparable keys --- jwt_rsa/token.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jwt_rsa/token.py b/jwt_rsa/token.py index 4db49a0..a0ae2d1 100644 --- a/jwt_rsa/token.py +++ b/jwt_rsa/token.py @@ -45,7 +45,7 @@ def date_to_timestamp( @dataclass(frozen=True, init=False) class JWTDecoder: jwt: PyJWT = field(repr=False, compare=False) - public_key: RSAPublicKey = field(repr=False, compare=False) + public_key: RSAPublicKey = field(repr=False, compare=True) expires: Union[int, float] nbf_delta: Union[int, float] algorithm: AlgorithmType @@ -73,7 +73,7 @@ def decode(self, token: str, verify: bool = True, **kwargs: Any) -> Dict[str, An @dataclass(frozen=True, init=False) class JWTSigner(JWTDecoder): - private_key: RSAPrivateKey = field(repr=False, compare=False) + private_key: RSAPrivateKey = field(repr=False, compare=True) def __init__(self, key: RSAPrivateKey, *, options: Optional[Dict[str, Any]] = None, **kwargs: Any): super(JWTDecoder, self).__setattr__('private_key', key) From b2293934da13f3f1496d85ee86b2675e489ce19c Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Mon, 6 Jan 2025 09:43:17 +0100 Subject: [PATCH 5/7] Separate mypy linter rule --- .github/workflows/pythonpackage.yml | 26 +++++++++++++++++++++ jwt_rsa/cli.py | 5 ++--- jwt_rsa/issue.py | 2 +- jwt_rsa/rsa.py | 33 +++++++++++++++++---------- jwt_rsa/token.py | 35 +++++++++++------------------ jwt_rsa/types.py | 4 +++- jwt_rsa/verify.py | 10 +++++---- pyproject.toml | 2 +- 8 files changed, 73 insertions(+), 44 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index e0fc894..7f63c32 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -7,6 +7,32 @@ on: branches: [ master ] jobs: + mypy: + runs-on: ubuntu-latest + strategy: + fail-fast: false + + steps: + - uses: actions/checkout@v2 + + - name: Setup python3.10 + uses: actions/setup-python@v2 + with: + python-version: "3.10" + + - name: Install poetry + run: python -m pip install poetry + + - name: Install dependencies + run: poetry install + env: + FORCE_COLOR: yes + + - name: Run mypy + run: poetry run mypy jwt_rsa + env: + FORCE_COLOR: yes + tests: runs-on: ubuntu-latest strategy: diff --git a/jwt_rsa/cli.py b/jwt_rsa/cli.py index 4d02d5a..188007f 100644 --- a/jwt_rsa/cli.py +++ b/jwt_rsa/cli.py @@ -2,9 +2,8 @@ from argparse import ArgumentParser from pathlib import Path -from jwt_rsa.types import AlgorithmType - from . import convert, issue, key_tester, keygen, pubkey, verify +from .token import ALGORITHMS parser = ArgumentParser() @@ -20,7 +19,7 @@ "--kid", dest="kid", type=str, default="", help="Key ID, will be generated if missing", ) keygen_parser.add_argument( - "-a", "--algorithm", choices=AlgorithmType.__args__, + "-a", "--algorithm", choices=ALGORITHMS, help="Key ID, will be generated if missing", default="RS512", ) keygen_parser.add_argument("-u", "--use", dest="use", type=str, default="sig", choices=["sig", "enc"]) diff --git a/jwt_rsa/issue.py b/jwt_rsa/issue.py index 1f610cc..d96a64a 100644 --- a/jwt_rsa/issue.py +++ b/jwt_rsa/issue.py @@ -69,7 +69,7 @@ def main(arguments: SimpleNamespace) -> None: - jwt = JWT(private_key=load_private_key(arguments.private_key)) + jwt = JWT(load_private_key(arguments.private_key)) whoami = pwd.getpwuid(os.getuid()) diff --git a/jwt_rsa/rsa.py b/jwt_rsa/rsa.py index 08f1bd7..c965b27 100644 --- a/jwt_rsa/rsa.py +++ b/jwt_rsa/rsa.py @@ -1,7 +1,7 @@ import base64 import json from pathlib import Path -from typing import NamedTuple, Optional, TypedDict, Union, overload +from typing import NamedTuple, Optional, TypedDict, overload from cryptography.hazmat.backends import default_backend @@ -11,6 +11,11 @@ class KeyPair(NamedTuple): + private: RSAPrivateKey + public: RSAPublicKey + + +class JWKKeyPair(NamedTuple): private: Optional[RSAPrivateKey] public: RSAPublicKey @@ -80,8 +85,8 @@ def load_jwk_private_key(jwk: RSAJWKPrivateKey) -> RSAPrivateKey: return private_numbers.private_key(default_backend()) -def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair: - jwk_dict: Union[RSAJWKPublicKey, RSAJWKPrivateKey] +def load_jwk(jwk: RSAJWKPublicKey | RSAJWKPrivateKey | str) -> JWKKeyPair: + jwk_dict: RSAJWKPublicKey | RSAJWKPrivateKey if isinstance(jwk, str): jwk_dict = json.loads(jwk) @@ -92,10 +97,10 @@ def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair: private_key = load_jwk_private_key(jwk_dict) # type: ignore public_key = private_key.public_key() else: # Public key - public_key = load_jwk_public_key(jwk_dict) # type: ignore + public_key = load_jwk_public_key(jwk_dict) private_key = None - return KeyPair(private=private_key, public=public_key) + return JWKKeyPair(private=private_key, public=public_key) def int_to_base64url(value: int) -> str: @@ -111,19 +116,19 @@ def rsa_to_jwk( @overload -def rsa_to_jwk( # type: ignore[overload-cannot-match] +def rsa_to_jwk( key: RSAPrivateKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig", ) -> RSAJWKPrivateKey: ... def rsa_to_jwk( - key: Union[RSAPrivateKey, RSAPublicKey], + key: RSAPrivateKey | RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig", kty: str = "RSA", -) -> Union[RSAJWKPublicKey, RSAJWKPrivateKey]: +) -> RSAJWKPublicKey | RSAJWKPrivateKey: if isinstance(key, RSAPublicKey): public_numbers = key.public_numbers() private_numbers = None @@ -161,12 +166,14 @@ def rsa_to_jwk( ) -def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey: +def load_private_key(data: str | RSAJWKPrivateKey | Path) -> RSAPrivateKey: if isinstance(data, Path): data = data.read_text() if isinstance(data, str): if data.startswith("-----BEGIN "): - return serialization.load_pem_private_key(data.encode(), None, default_backend()) + result = serialization.load_pem_private_key(data.encode(), None, default_backend()) + assert isinstance(result, RSAPrivateKey) + return result if data.strip().startswith("{"): return load_jwk_private_key(json.loads(data)) if isinstance(data, dict): @@ -177,12 +184,14 @@ def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey: return key -def load_public_key(data: Union[str, RSAJWKPublicKey, Path]) -> RSAPublicKey: +def load_public_key(data: str | RSAJWKPublicKey | Path) -> RSAPublicKey: if isinstance(data, Path): data = data.read_text() if isinstance(data, str): if data.startswith("-----BEGIN "): - return serialization.load_pem_public_key(data.encode(), default_backend()) + result = serialization.load_pem_public_key(data.encode(), default_backend()) + assert isinstance(result, RSAPublicKey) + return result if data.strip().startswith("{"): return load_jwk_public_key(json.loads(data)) if isinstance(data, dict): diff --git a/jwt_rsa/token.py b/jwt_rsa/token.py index a0ae2d1..92f487a 100644 --- a/jwt_rsa/token.py +++ b/jwt_rsa/token.py @@ -2,34 +2,25 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from operator import add, sub -from typing import ( - TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union, overload, -) +from types import EllipsisType +from typing import Any, Callable, Dict, Optional, Sequence, TypeVar, overload from jwt import PyJWT -from .types import AlgorithmType, RSAPrivateKey, RSAPublicKey - - -if TYPE_CHECKING: - # pylama:ignore=E0602 - DateType = Union[timedelta, datetime, float, int, ellipsis] -else: - DateType = Union[timedelta, datetime, float, int, type(Ellipsis)] - +from .types import AlgorithmType, RSAPrivateKey, RSAPublicKey, DateType R = TypeVar("R") DAY = 86400 DEFAULT_EXPIRATION = timedelta(days=31).total_seconds() NBF_DELTA = 20 -ALGORITHMS = tuple(AlgorithmType.__args__) +ALGORITHMS: Sequence[AlgorithmType] = ("RS256", "RS384", "RS512") def date_to_timestamp( - value: DateType, + value: DateType | EllipsisType, default: Callable[[], R], timedelta_func: Callable[[float, float], int] = add, -) -> Union[int, float, R]: +) -> int | float | R: if isinstance(value, timedelta): return timedelta_func(time.time(), value.total_seconds()) elif isinstance(value, datetime): @@ -46,8 +37,8 @@ def date_to_timestamp( class JWTDecoder: jwt: PyJWT = field(repr=False, compare=False) public_key: RSAPublicKey = field(repr=False, compare=True) - expires: Union[int, float] - nbf_delta: Union[int, float] + expires: int | float + nbf_delta: int | float algorithm: AlgorithmType algorithms: Sequence[AlgorithmType] @@ -79,7 +70,7 @@ def __init__(self, key: RSAPrivateKey, *, options: Optional[Dict[str, Any]] = No super(JWTDecoder, self).__setattr__('private_key', key) super().__init__(key.public_key(), options=options, **kwargs) - def encode(self, expired: DateType = ..., nbf: DateType = ..., **claims: Any) -> str: + def encode(self, expired: DateType | EllipsisType = ..., nbf: DateType | EllipsisType = ..., **claims: Any) -> str: claims.setdefault('exp', int(date_to_timestamp(expired, lambda: time.time() + self.expires))) claims.setdefault('nbf', int(date_to_timestamp(nbf, lambda: time.time() - self.nbf_delta, timedelta_func=sub))) return self.jwt.encode(claims, self.private_key, algorithm=self.algorithm) @@ -97,7 +88,7 @@ def JWT( @overload -def JWT( # type: ignore[overload-cannot-match] +def JWT( key: RSAPublicKey, *, options: dict[str, Any] | None = None, expires: int | float = DEFAULT_EXPIRATION, @@ -108,15 +99,15 @@ def JWT( # type: ignore[overload-cannot-match] def JWT( - key: Union[RSAPrivateKey, RSAPublicKey], + key: RSAPrivateKey | RSAPublicKey, *, options: dict[str, Any] | None = None, expires: int | float = DEFAULT_EXPIRATION, nbf_delta: int | float = NBF_DELTA, algorithm: AlgorithmType = "RS512", algorithms: Sequence[AlgorithmType] = ALGORITHMS, -) -> Union[JWTSigner, JWTDecoder]: - kwargs = dict( +) -> JWTSigner | JWTDecoder: + kwargs: dict[str, Any] = dict( expires=expires, nbf_delta=nbf_delta, algorithm=algorithm, diff --git a/jwt_rsa/types.py b/jwt_rsa/types.py index 4ab9a89..4dadefb 100644 --- a/jwt_rsa/types.py +++ b/jwt_rsa/types.py @@ -1,3 +1,4 @@ +from datetime import timedelta, datetime from typing import Literal from cryptography.hazmat.primitives import serialization @@ -11,10 +12,11 @@ AlgorithmType = Literal["RS256", "RS384", "RS512"] - +DateType = timedelta | datetime | float | int __all__ = ( "AlgorithmType", + "DateType", "RSAPrivateKey", "RSAPublicKey", "serialization", diff --git a/jwt_rsa/verify.py b/jwt_rsa/verify.py index 51bba03..150fdad 100644 --- a/jwt_rsa/verify.py +++ b/jwt_rsa/verify.py @@ -3,16 +3,18 @@ from types import SimpleNamespace from .rsa import generate_rsa, load_private_key, load_public_key -from .token import JWT +from .token import JWT, JWTSigner, JWTDecoder def main(arguments: SimpleNamespace) -> None: + jwt: JWTSigner | JWTDecoder if arguments.private_key: - jwt = JWT(private_key=load_private_key(arguments.private_key)) + jwt = JWT(load_private_key(arguments.private_key)) elif arguments.public_key: - jwt = JWT(public_key=load_public_key(arguments.public_key)) + jwt = JWT(load_public_key(arguments.public_key)) elif not arguments.verify: - jwt = JWT(*generate_rsa(1024)) + key_pair = generate_rsa(1024) + jwt = JWT(key_pair.private) else: print("Either private or public key must be provided", file=sys.stderr) exit(1) diff --git a/pyproject.toml b/pyproject.toml index 5c95337..1ad19ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.pylama] -linters = "pycodestyle,pyflakes,mccabe,mccabe,mypy" +linters = "pycodestyle,pyflakes,mccabe,mccabe" [tool.pylama.linter.pycodestyle] max_line_length = 119 From 2fc1ce520f0f18b3c17e52e80da304a4e30a1531 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Mon, 6 Jan 2025 09:43:28 +0100 Subject: [PATCH 6/7] reformat --- jwt_rsa/rsa.py | 2 +- jwt_rsa/token.py | 21 +++++++++++---------- jwt_rsa/types.py | 2 +- jwt_rsa/verify.py | 2 +- tests/test_cli.py | 6 +++--- tests/test_rsa.py | 24 ++++++++++++++---------- 6 files changed, 31 insertions(+), 26 deletions(-) diff --git a/jwt_rsa/rsa.py b/jwt_rsa/rsa.py index c965b27..570f154 100644 --- a/jwt_rsa/rsa.py +++ b/jwt_rsa/rsa.py @@ -111,7 +111,7 @@ def int_to_base64url(value: int) -> str: @overload def rsa_to_jwk( - key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig" + key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig", ) -> RSAJWKPublicKey: ... diff --git a/jwt_rsa/token.py b/jwt_rsa/token.py index 92f487a..bb232d0 100644 --- a/jwt_rsa/token.py +++ b/jwt_rsa/token.py @@ -7,7 +7,8 @@ from jwt import PyJWT -from .types import AlgorithmType, RSAPrivateKey, RSAPublicKey, DateType +from .types import AlgorithmType, DateType, RSAPrivateKey, RSAPublicKey + R = TypeVar("R") DAY = 86400 @@ -51,12 +52,12 @@ def __init__( algorithm: AlgorithmType = "RS512", algorithms: Sequence[AlgorithmType] = ALGORITHMS, ): - super().__setattr__('public_key', key) - super().__setattr__('jwt', PyJWT(options)) - super().__setattr__('expires', expires) - super().__setattr__('nbf_delta', nbf_delta) - super().__setattr__('algorithm', algorithm) - super().__setattr__('algorithms', algorithms) + super().__setattr__("public_key", key) + super().__setattr__("jwt", PyJWT(options)) + super().__setattr__("expires", expires) + super().__setattr__("nbf_delta", nbf_delta) + super().__setattr__("algorithm", algorithm) + super().__setattr__("algorithms", algorithms) def decode(self, token: str, verify: bool = True, **kwargs: Any) -> Dict[str, Any]: return self.jwt.decode(token, key=self.public_key, verify=verify, algorithms=self.algorithms, **kwargs) @@ -67,12 +68,12 @@ class JWTSigner(JWTDecoder): private_key: RSAPrivateKey = field(repr=False, compare=True) def __init__(self, key: RSAPrivateKey, *, options: Optional[Dict[str, Any]] = None, **kwargs: Any): - super(JWTDecoder, self).__setattr__('private_key', key) + super(JWTDecoder, self).__setattr__("private_key", key) super().__init__(key.public_key(), options=options, **kwargs) def encode(self, expired: DateType | EllipsisType = ..., nbf: DateType | EllipsisType = ..., **claims: Any) -> str: - claims.setdefault('exp', int(date_to_timestamp(expired, lambda: time.time() + self.expires))) - claims.setdefault('nbf', int(date_to_timestamp(nbf, lambda: time.time() - self.nbf_delta, timedelta_func=sub))) + claims.setdefault("exp", int(date_to_timestamp(expired, lambda: time.time() + self.expires))) + claims.setdefault("nbf", int(date_to_timestamp(nbf, lambda: time.time() - self.nbf_delta, timedelta_func=sub))) return self.jwt.encode(claims, self.private_key, algorithm=self.algorithm) diff --git a/jwt_rsa/types.py b/jwt_rsa/types.py index 4dadefb..9cd1131 100644 --- a/jwt_rsa/types.py +++ b/jwt_rsa/types.py @@ -1,4 +1,4 @@ -from datetime import timedelta, datetime +from datetime import datetime, timedelta from typing import Literal from cryptography.hazmat.primitives import serialization diff --git a/jwt_rsa/verify.py b/jwt_rsa/verify.py index 150fdad..a8ce266 100644 --- a/jwt_rsa/verify.py +++ b/jwt_rsa/verify.py @@ -3,7 +3,7 @@ from types import SimpleNamespace from .rsa import generate_rsa, load_private_key, load_public_key -from .token import JWT, JWTSigner, JWTDecoder +from .token import JWT, JWTDecoder, JWTSigner def main(arguments: SimpleNamespace) -> None: diff --git a/tests/test_cli.py b/tests/test_cli.py index 20bec3f..df905b5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -95,7 +95,7 @@ def test_keygen_no_force(capsys, tmp_path): parser.parse_args([ "keygen", "-o", "pem", "-K", str(private_path), "-k", str(public_path), - ]) + ]), ) assert private_path.exists() @@ -112,7 +112,7 @@ def test_keygen_no_force(capsys, tmp_path): parser.parse_args([ "keygen", "-o", "pem", "-K", str(private_path), "-k", str(public_path), - ]) + ]), ) assert public_content == public_path.read_text() @@ -122,7 +122,7 @@ def test_keygen_no_force(capsys, tmp_path): parser.parse_args([ "keygen", "-o", "pem", "-f", "-K", str(private_path), "-k", str(public_path), - ]) + ]), ) assert public_content != public_path.read_text() diff --git a/tests/test_rsa.py b/tests/test_rsa.py index fe8d49c..24aa2d9 100644 --- a/tests/test_rsa.py +++ b/tests/test_rsa.py @@ -11,8 +11,8 @@ from jwt.exceptions import InvalidAlgorithmError, InvalidSignatureError from jwt_rsa import rsa -from jwt_rsa.types import serialization from jwt_rsa.token import JWT +from jwt_rsa.types import serialization def test_rsa_sign(): @@ -173,17 +173,21 @@ def test_load_public_key(tmp_path): private_path = tmp_path / "private.pem" with open(public_path, "wb") as fp: - fp.write(public.public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - )) + fp.write( + public.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ), + ) with open(private_path, "wb") as fp: - fp.write(key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - )) + fp.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ), + ) assert rsa.load_public_key(public_path) assert rsa.load_public_key(public_path.read_text()) From a213617b718d3b8f5f52a5aee47d81b2b307f2b4 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Mon, 6 Jan 2025 09:44:27 +0100 Subject: [PATCH 7/7] Bump to 1.1.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1ad19ca..643675a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyjwt-rsa" -version = "1.0.1" +version = "1.1.0" description = "RSA helpers for PyJWT" authors = ["Dmitry Orlov "] license = "MIT"