Skip to content

Commit

Permalink
feat(OAuth): allow for access_token without expiration from the API (#…
Browse files Browse the repository at this point in the history
…324)

Co-authored-by: octavia-squidington-iii <[email protected]>
  • Loading branch information
maxi297 and octavia-squidington-iii authored Feb 7, 2025
1 parent a32fea4 commit 6260248
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
7 changes: 6 additions & 1 deletion airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

from dataclasses import InitVar, dataclass, field
from datetime import timedelta
from datetime import datetime, timedelta
from typing import Any, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
Expand Down Expand Up @@ -232,8 +232,13 @@ def get_refresh_request_headers(self) -> Mapping[str, Any]:
return self._refresh_request_headers.eval(self.config)

def get_token_expiry_date(self) -> AirbyteDateTime:
if not self._has_access_token_been_initialized():
return AirbyteDateTime.from_datetime(datetime.min)
return self._token_expiry_date # type: ignore # _token_expiry_date is an AirbyteDateTime. It is never None despite what mypy thinks

def _has_access_token_been_initialized(self) -> bool:
return self._access_token is not None

def set_token_expiry_date(self, value: Union[str, int]) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim
:return: expiration datetime
"""
if not value and not self.token_has_expired():
# No expiry token was provided but the previous one is not expired so it's fine
return self.get_token_expiry_date()

if self.token_expiry_is_time_of_expiration:
if not self.token_expiry_date_format:
Expand Down
29 changes: 29 additions & 0 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
token_expiry_date=timestamp,
access_token_value="some_access_token",
refresh_token="some_refresh_token",
config={
"refresh_endpoint": "refresh_end",
Expand All @@ -313,6 +314,34 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
assert isinstance(oauth._token_expiry_date, AirbyteDateTime)
assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date)

def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token(
self,
) -> None:
expiry_date = ab_datetime_now().add(timedelta(days=1))
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="https://refresh_endpoint.com/",
client_id="some_client_id",
client_secret="some_client_secret",
token_expiry_date=expiry_date.isoformat(),
refresh_token="some_refresh_token",
config={},
parameters={},
grant_type="client",
)

with HttpMocker() as http_mocker:
http_mocker.post(
HttpRequest(
url="https://refresh_endpoint.com/",
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
),
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
)
oauth.get_access_token()

assert oauth.access_token == "new_access_token"
assert oauth._token_expiry_date == expiry_date

@pytest.mark.parametrize(
"expires_in_response, token_expiry_date_format",
[
Expand Down

0 comments on commit 6260248

Please sign in to comment.