From 16129375d15fdc3c3c93c7a39a58cc8436d02b85 Mon Sep 17 00:00:00 2001 From: Alex Golec Date: Sat, 14 Sep 2024 08:38:02 -0400 Subject: [PATCH] Pass oauth state properly --- schwab/auth.py | 8 +++++--- tests/auth_test.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/schwab/auth.py b/schwab/auth.py index fb4c4df..fa32b46 100644 --- a/schwab/auth.py +++ b/schwab/auth.py @@ -578,10 +578,11 @@ async def oauth_client_update_token(t, *args, **kwargs): AuthContext = collections.namedtuple( 'AuthContext', ['callback_url', 'authorization_url', 'state']) -def get_auth_context(api_key, callback_url): +def get_auth_context(api_key, callback_url, state=None): oauth = OAuth2Client(api_key, redirect_uri=callback_url) authorization_url, state = oauth.create_authorization_url( - 'https://api.schwabapi.com/v1/oauth/authorize') + 'https://api.schwabapi.com/v1/oauth/authorize', + state=state) return AuthContext(callback_url, authorization_url, state) @@ -597,7 +598,8 @@ def client_from_received_url( token = oauth.fetch_token( TOKEN_ENDPOINT, authorization_response=received_url, - client_id=api_key, auth=(api_key, app_secret)) + client_id=api_key, auth=(api_key, app_secret), + state=auth_context.state) # Don't emit token details in debug logs register_redactions(token) diff --git a/tests/auth_test.py b/tests/auth_test.py index f1bb954..df19d1b 100644 --- a/tests/auth_test.py +++ b/tests/auth_test.py @@ -536,6 +536,54 @@ def token_write_func(token): API_KEY, _, token_metadata=_, enforce_enums=True) +# Note the client_from_received_url is called internally by the other client +# generation functions, so testing here is kept light +class ClientFromReceivedUrl(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.TemporaryDirectory() + self.token_path = os.path.join(self.tmp_dir.name, 'token.json') + self.raw_token = {'token': 'yes'} + + @no_duplicates + @patch('schwab.auth.Client') + @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) + def test_success( + self, async_session, sync_session, client): + AUTH_URL = 'https://auth.url.com' + + sync_session.return_value = sync_session + sync_session.create_authorization_url.return_value = \ + AUTH_URL, 'oauth state' + sync_session.fetch_token.return_value = self.raw_token + + auth_context = auth.get_auth_context(API_KEY, CALLBACK_URL) + self.assertEqual(AUTH_URL, auth_context.authorization_url) + self.assertEqual('oauth state', auth_context.state) + + client.return_value = 'returned client' + token_capture = [] + client = auth.client_from_received_url( + API_KEY, APP_SECRET, auth_context, + 'http://redirect.url.com/?data', + lambda token: token_capture.append(token)) + + # Verify that the oauth state is correctly passed along + sync_session.fetch_token.assert_called_once_with( + _, + authorization_response=_, + client_id=_, + auth=_, + state='oauth state') + + self.assertEqual([{ + 'creation_timestamp': MOCK_NOW, + 'token': self.raw_token + }], token_capture) + + class ClientFromManualFlow(unittest.TestCase): def setUp(self):