Skip to content

Commit

Permalink
Merge pull request #102 from maykinmedia/feature/client-that-can-refr…
Browse files Browse the repository at this point in the history
…esh-tokens

Client that can refresh tokens
  • Loading branch information
SilviaAmAm authored Dec 3, 2024
2 parents 74c24ff + 0548042 commit d266e4a
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 3 deletions.
31 changes: 31 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import time

import jwt
import pytest
from freezegun import freeze_time

from zgw_consumers.client import ZGWAuth
from zgw_consumers.constants import AuthTypes
from zgw_consumers.test.factories import ServiceFactory


@pytest.mark.django_db
def test_zgw_auth_refresh_token():
service = ServiceFactory.create(
api_root="https://example.com/",
auth_type=AuthTypes.zgw,
client_id="my-client-id",
secret="my-secret",
)

with freeze_time("2024-11-27T10:00:00+02:00"):
auth = ZGWAuth(service)
token = jwt.decode(auth._token, service.secret, algorithms=["HS256"])

assert token["iat"] == int(time.time())

with freeze_time("2024-11-27T15:00:00+02:00"): # 5 hours later
auth.refresh_token()
token = jwt.decode(auth._token, service.secret, algorithms=["HS256"])

assert token["iat"] == int(time.time())
103 changes: 103 additions & 0 deletions tests/test_clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import time

import jwt
import pytest
import requests_mock
from freezegun import freeze_time

from zgw_consumers.client import build_client
from zgw_consumers.constants import AuthTypes
from zgw_consumers.test.factories import ServiceFactory


@pytest.mark.django_db
def test_retry_request_on_403_auth_zgw():
service = ServiceFactory.create(
api_root="https://example.com/",
auth_type=AuthTypes.zgw,
client_id="my-client-id",
secret="my-secret",
)

with requests_mock.Mocker() as m:
m.get(
"https://example.com/",
status_code=403,
)

with freeze_time("2024-11-27T10:00:00+02:00"):
initial_time = int(time.time())
client = build_client(service)

with freeze_time("2024-11-27T15:00:00+02:00"): # 5h later
later_time = int(time.time())
with client:
client.get("https://example.com/")

history = m.request_history

assert len(history) == 2

first_request = history[0]
first_token = first_request.headers["Authorization"].removeprefix("Bearer ")
time1 = jwt.decode(first_token, service.secret, algorithms=["HS256"])["iat"]

assert time1 == initial_time

second_request = history[1]
second_token = second_request.headers["Authorization"].removeprefix("Bearer ")
time2 = jwt.decode(second_token, service.secret, algorithms=["HS256"])["iat"]

assert time2 == later_time


@pytest.mark.django_db
def test_retry_request_on_403_auth_api_key():
service = ServiceFactory.create(
api_root="https://example.com/",
auth_type=AuthTypes.api_key,
header_key="Some-Auth-Header",
header_value="some-api-key",
)

with requests_mock.Mocker() as m:
m.get(
"https://example.com/",
status_code=403,
)

with freeze_time("2024-11-27T10:00:00+02:00"):
client = build_client(service)

with freeze_time("2024-11-27T15:00:00+02:00"): # 5h later
with client:
client.get("https://example.com/")

history = m.request_history

assert len(history) == 1


@pytest.mark.django_db
def test_retry_request_on_403_no_auth():
service = ServiceFactory.create(
api_root="https://example.com/",
auth_type=AuthTypes.no_auth,
)

with requests_mock.Mocker() as m:
m.get(
"https://example.com/",
status_code=403,
)

with freeze_time("2024-11-27T10:00:00+02:00"):
client = build_client(service)

with freeze_time("2024-11-27T15:00:00+02:00"): # 5h later
with client:
client.get("https://example.com/")

history = m.request_history

assert len(history) == 1
9 changes: 7 additions & 2 deletions zgw_consumers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ class ZGWAuth(AuthBase):
service: Service

def __post_init__(self):
# Generate the JWT Bearer token. Ported from gemma-zds-client ClientAuth.
self._token = self._generate_token()

def _generate_token(self) -> str:
payload = {
# standard claims
"iss": self.service.client_id,
Expand All @@ -107,8 +109,11 @@ def __post_init__(self):
"user_representation": self.service.user_representation,
}

self._token: str = jwt.encode(payload, self.service.secret, algorithm="HS256")
return jwt.encode(payload, self.service.secret, algorithm="HS256")

def __call__(self, request: PreparedRequest):
request.headers["Authorization"] = f"Bearer {self._token}"
return request

def refresh_token(self):
self._token = self._generate_token()
18 changes: 18 additions & 0 deletions zgw_consumers/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from requests import Response


class RefreshTokenMixin:
def request(
self, method: str | bytes, url: str | bytes, *args, **kwargs
) -> Response:
from .client import ZGWAuth # circular import

response = super().request(method, url, *args, **kwargs)

if response.status_code != 403 or not isinstance(self.auth, ZGWAuth):
return response

self.auth.refresh_token()

# Retry with the fresh credentials
return super().request(method, url, *args, **kwargs)
3 changes: 2 additions & 1 deletion zgw_consumers/nlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from requests.models import PreparedRequest, Request, Response
from requests.utils import guess_json_utf

from .mixins import RefreshTokenMixin
from .models import NLXConfig, Service

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -150,7 +151,7 @@ def prepare_request(self, request: Request) -> PreparedRequest:
return prepared_request


class NLXClient(NLXMixin, APIClient):
class NLXClient(NLXMixin, RefreshTokenMixin, APIClient):
"""A :class:`ape_pie.APIClient` implementation that will take care of rewriting
URLs with :external:ref:`an event hook <event-hooks>`.
"""
Expand Down

0 comments on commit d266e4a

Please sign in to comment.