Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor JWT Class into JWTDecoder and JWTSigner, Improve Type Safety and Test Coverage #4

Merged
merged 7 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,32 @@ on:
branches: [ master ]

jobs:
mypy:
runs-on: ubuntu-latest
strategy:
fail-fast: false

steps:
- uses: actions/checkout@v2

- name: Setup python3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"

- name: Install poetry
run: python -m pip install poetry

- name: Install dependencies
run: poetry install
env:
FORCE_COLOR: yes

- name: Run mypy
run: poetry run mypy jwt_rsa
env:
FORCE_COLOR: yes

tests:
runs-on: ubuntu-latest
strategy:
Expand Down
4 changes: 3 additions & 1 deletion jwt_rsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
RSAJWKPrivateKey, RSAJWKPublicKey, generate_rsa, load_private_key,
load_public_key, rsa_to_jwk,
)
from .token import JWT
from .token import JWT, JWTDecoder, JWTSigner
from .types import RSAPrivateKey, RSAPublicKey


__all__ = (
"JWT",
"JWTDecoder",
"JWTSigner",
"RSAJWKPrivateKey",
"RSAJWKPublicKey",
"RSAPrivateKey",
Expand Down
5 changes: 2 additions & 3 deletions jwt_rsa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from argparse import ArgumentParser
from pathlib import Path

from jwt_rsa.types import AlgorithmType

from . import convert, issue, key_tester, keygen, pubkey, verify
from .token import ALGORITHMS


parser = ArgumentParser()
Expand All @@ -20,7 +19,7 @@
"--kid", dest="kid", type=str, default="", help="Key ID, will be generated if missing",
)
keygen_parser.add_argument(
"-a", "--algorithm", choices=AlgorithmType.__args__,
"-a", "--algorithm", choices=ALGORITHMS,
help="Key ID, will be generated if missing", default="RS512",
)
keygen_parser.add_argument("-u", "--use", dest="use", type=str, default="sig", choices=["sig", "enc"])
Expand Down
2 changes: 1 addition & 1 deletion jwt_rsa/issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


def main(arguments: SimpleNamespace) -> None:
jwt = JWT(private_key=load_private_key(arguments.private_key))
jwt = JWT(load_private_key(arguments.private_key))

whoami = pwd.getpwuid(os.getuid())

Expand Down
35 changes: 22 additions & 13 deletions jwt_rsa/rsa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import json
from pathlib import Path
from typing import NamedTuple, Optional, TypedDict, Union, overload
from typing import NamedTuple, Optional, TypedDict, overload

from cryptography.hazmat.backends import default_backend

Expand All @@ -11,6 +11,11 @@


class KeyPair(NamedTuple):
private: RSAPrivateKey
public: RSAPublicKey


class JWKKeyPair(NamedTuple):
private: Optional[RSAPrivateKey]
public: RSAPublicKey

Expand Down Expand Up @@ -80,8 +85,8 @@ def load_jwk_private_key(jwk: RSAJWKPrivateKey) -> RSAPrivateKey:
return private_numbers.private_key(default_backend())


def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
jwk_dict: Union[RSAJWKPublicKey, RSAJWKPrivateKey]
def load_jwk(jwk: RSAJWKPublicKey | RSAJWKPrivateKey | str) -> JWKKeyPair:
jwk_dict: RSAJWKPublicKey | RSAJWKPrivateKey

if isinstance(jwk, str):
jwk_dict = json.loads(jwk)
Expand All @@ -92,10 +97,10 @@ def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
private_key = load_jwk_private_key(jwk_dict) # type: ignore
public_key = private_key.public_key()
else: # Public key
public_key = load_jwk_public_key(jwk_dict) # type: ignore
public_key = load_jwk_public_key(jwk_dict)
private_key = None

return KeyPair(private=private_key, public=public_key)
return JWKKeyPair(private=private_key, public=public_key)


def int_to_base64url(value: int) -> str:
Expand All @@ -106,24 +111,24 @@ def int_to_base64url(value: int) -> str:

@overload
def rsa_to_jwk(
key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig"
key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig",
) -> RSAJWKPublicKey: ...


@overload
def rsa_to_jwk( # type: ignore[overload-cannot-match]
def rsa_to_jwk(
key: RSAPrivateKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig",
) -> RSAJWKPrivateKey: ...


def rsa_to_jwk(
key: Union[RSAPrivateKey, RSAPublicKey],
key: RSAPrivateKey | RSAPublicKey,
*,
kid: str = "",
alg: AlgorithmType = "RS256",
use: str = "sig",
kty: str = "RSA",
) -> Union[RSAJWKPublicKey, RSAJWKPrivateKey]:
) -> RSAJWKPublicKey | RSAJWKPrivateKey:
if isinstance(key, RSAPublicKey):
public_numbers = key.public_numbers()
private_numbers = None
Expand Down Expand Up @@ -161,12 +166,14 @@ def rsa_to_jwk(
)


def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
def load_private_key(data: str | RSAJWKPrivateKey | Path) -> RSAPrivateKey:
if isinstance(data, Path):
data = data.read_text()
if isinstance(data, str):
if data.startswith("-----BEGIN "):
return serialization.load_pem_private_key(data.encode(), None, default_backend())
result = serialization.load_pem_private_key(data.encode(), None, default_backend())
assert isinstance(result, RSAPrivateKey)
return result
if data.strip().startswith("{"):
return load_jwk_private_key(json.loads(data))
if isinstance(data, dict):
Expand All @@ -177,12 +184,14 @@ def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
return key


def load_public_key(data: Union[str, RSAJWKPublicKey, Path]) -> RSAPublicKey:
def load_public_key(data: str | RSAJWKPublicKey | Path) -> RSAPublicKey:
if isinstance(data, Path):
data = data.read_text()
if isinstance(data, str):
if data.startswith("-----BEGIN "):
return serialization.load_pem_public_key(data.encode(), default_backend())
result = serialization.load_pem_public_key(data.encode(), default_backend())
assert isinstance(result, RSAPublicKey)
return result
if data.strip().startswith("{"):
return load_jwk_public_key(json.loads(data))
if isinstance(data, dict):
Expand Down
Loading
Loading