diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index 3ca5f9b59..bc609e42e 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -3,7 +3,7 @@ # from dataclasses import InitVar, dataclass, field -from datetime import timedelta +from datetime import datetime, timedelta from typing import Any, List, Mapping, MutableMapping, Optional, Union from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator @@ -232,8 +232,13 @@ def get_refresh_request_headers(self) -> Mapping[str, Any]: return self._refresh_request_headers.eval(self.config) def get_token_expiry_date(self) -> AirbyteDateTime: + if not self._has_access_token_been_initialized(): + return AirbyteDateTime.from_datetime(datetime.min) return self._token_expiry_date # type: ignore # _token_expiry_date is an AirbyteDateTime. It is never None despite what mypy thinks + def _has_access_token_been_initialized(self) -> bool: + return self._access_token is not None + def set_token_expiry_date(self, value: Union[str, int]) -> None: self._token_expiry_date = self._parse_token_expiration_date(value) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 0a9b15bc0..df936f8b6 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -261,6 +261,9 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim :return: expiration datetime """ + if not value and not self.token_has_expired(): + # No expiry token was provided but the previous one is not expired so it's fine + return self.get_token_expiry_date() if self.token_expiry_is_time_of_expiration: if not self.token_expiry_date_format: diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index c1ce9326e..c54b9982f 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -301,6 +301,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp( client_id="{{ config['client_id'] }}", client_secret="{{ config['client_secret'] }}", token_expiry_date=timestamp, + access_token_value="some_access_token", refresh_token="some_refresh_token", config={ "refresh_endpoint": "refresh_end", @@ -313,6 +314,34 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp( assert isinstance(oauth._token_expiry_date, AirbyteDateTime) assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date) + def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token( + self, + ) -> None: + expiry_date = ab_datetime_now().add(timedelta(days=1)) + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="https://refresh_endpoint.com/", + client_id="some_client_id", + client_secret="some_client_secret", + token_expiry_date=expiry_date.isoformat(), + refresh_token="some_refresh_token", + config={}, + parameters={}, + grant_type="client", + ) + + with HttpMocker() as http_mocker: + http_mocker.post( + HttpRequest( + url="https://refresh_endpoint.com/", + body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token", + ), + HttpResponse(body=json.dumps({"access_token": "new_access_token"})), + ) + oauth.get_access_token() + + assert oauth.access_token == "new_access_token" + assert oauth._token_expiry_date == expiry_date + @pytest.mark.parametrize( "expires_in_response, token_expiry_date_format", [