diff --git a/examples/seznam.py b/examples/seznam.py index d6dcb2e..76eaf81 100644 --- a/examples/seznam.py +++ b/examples/seznam.py @@ -31,7 +31,9 @@ async def auth_init(): async def auth_callback(request: Request): """Verify login""" with sso: - user = await sso.verify_and_process(request, params={"client_secret": CLIENT_SECRET}) # <- "client_secret" parameter is needed! + user = await sso.verify_and_process( + request, params={"client_secret": CLIENT_SECRET} + ) # <- "client_secret" parameter is needed! return user diff --git a/fastapi_sso/sso/base.py b/fastapi_sso/sso/base.py index c9e6c85..1f88fe4 100644 --- a/fastapi_sso/sso/base.py +++ b/fastapi_sso/sso/base.py @@ -7,7 +7,7 @@ import sys import warnings from types import TracebackType -from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, TypeVar, Union, overload +from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, TypeVar, Union, overload import httpx import pydantic @@ -110,12 +110,14 @@ def __init__( allow_insecure_http: bool = False, use_state: bool = False, scope: Optional[List[str]] = None, + get_async_client: Optional[Callable[[], httpx.AsyncClient]] = None, ): """Base class (mixin) for all SSO providers.""" self.client_id: str = client_id self.client_secret: str = client_secret self.redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = redirect_uri self.allow_insecure_http: bool = allow_insecure_http + self.get_async_client: Callable[[], httpx.AsyncClient] = get_async_client or httpx.AsyncClient self._login_lock = asyncio.Lock() self._in_stack = False self._oauth_client: Optional[WebApplicationClient] = None @@ -330,10 +332,10 @@ async def verify_and_process( self, request: Request, *, - params: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None, - redirect_uri: Optional[str] = None, - convert_response: Literal[True] = True, + params: Optional[Dict[str, Any]], + headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], + convert_response: Literal[True], ) -> Optional[OpenID]: ... @overload @@ -458,11 +460,11 @@ async def process_login( code: str, request: Request, *, - params: Optional[Dict[str, Any]] = None, - additional_headers: Optional[Dict[str, Any]] = None, - redirect_uri: Optional[str] = None, - pkce_code_verifier: Optional[str] = None, - convert_response: Literal[True] = True, + params: Optional[Dict[str, Any]], + additional_headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], + pkce_code_verifier: Optional[str], + convert_response: Literal[True], ) -> Optional[OpenID]: ... @overload @@ -471,10 +473,10 @@ async def process_login( code: str, request: Request, *, - params: Optional[Dict[str, Any]] = None, - additional_headers: Optional[Dict[str, Any]] = None, - redirect_uri: Optional[str] = None, - pkce_code_verifier: Optional[str] = None, + params: Optional[Dict[str, Any]], + additional_headers: Optional[Dict[str, Any]], + redirect_uri: Optional[str], + pkce_code_verifier: Optional[str], convert_response: Literal[False], ) -> Optional[Dict[str, Any]]: ... @@ -552,7 +554,7 @@ async def process_login( auth = httpx.BasicAuth(self.client_id, self.client_secret) - async with httpx.AsyncClient() as session: + async with self.get_async_client() as session: response = await session.post(token_url, headers=headers, content=body, auth=auth) content = response.json() self._refresh_token = content.get("refresh_token") diff --git a/tests/test_providers.py b/tests/test_providers.py index fb06956..ac0830c 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -151,8 +151,8 @@ async def test_process_login(self, Provider: Type[SSOBase], monkeypatch: pytest. async def fake_openid_from_response(_, __): return OpenID(id="test", email="email@example.com", display_name="Test") - async with sso: - monkeypatch.setattr("httpx.AsyncClient", FakeAsyncClient) + with sso: + monkeypatch.setattr(sso, "get_async_client", FakeAsyncClient) monkeypatch.setattr(sso, "openid_from_response", fake_openid_from_response) request = Request(url="https://localhost?code=code&state=unique") await sso.process_login("code", request) diff --git a/tests/test_race_condition.py b/tests/test_race_condition.py index dda8333..4569e6b 100644 --- a/tests/test_race_condition.py +++ b/tests/test_race_condition.py @@ -54,29 +54,26 @@ async def get(self, *args, **kwargs): await asyncio.sleep(0) return Response(token="") - with patch("fastapi_sso.sso.base.httpx") as httpx: - httpx.AsyncClient = AsyncClient - - first_response = Response(token="first_token") # noqa: S106 - second_response = Response(token="second_token") # noqa: S106 - - AsyncClient.post_responses = [second_response, first_response] # reversed order because of `pop` - - async def process_login(): - # this coro will be executed concurrently. - # completely not caring about the params - request = Mock() - request.url = URL("https://url.com?state=state&code=code") - async with provider: - await provider.process_login( - code="code", request=request, params=dict(state="state"), convert_response=False - ) - return provider.access_token - - # process login concurrently twice - tasks = [process_login(), process_login()] - results = await asyncio.gather(*tasks) - - # we would want to get the first and second tokens, - # but we see that the first request actually obtained the second token as well - assert results == [first_response.token, second_response.token] + first_response = Response(token="first_token") # noqa: S106 + second_response = Response(token="second_token") # noqa: S106 + AsyncClient.post_responses = [second_response, first_response] # reversed order because of `pop` + provider.get_async_client = AsyncClient + + async def process_login(): + # this coro will be executed concurrently. + # completely not caring about the params + request = Mock() + request.url = URL("https://url.com?state=state&code=code") + async with provider: + await provider.process_login( + code="code", request=request, params=dict(state="state"), convert_response=False + ) + return provider.access_token + + # process login concurrently twice + tasks = [process_login(), process_login()] + results = await asyncio.gather(*tasks) + + # we would want to get the first and second tokens, + # but we see that the first request actually obtained the second token as well + assert results == [first_response.token, second_response.token]