diff --git a/tests/auth_test.py b/tests/auth_test.py index df19d1b..8f0a190 100644 --- a/tests/auth_test.py +++ b/tests/auth_test.py @@ -547,11 +547,12 @@ def setUp(self): @no_duplicates @patch('schwab.auth.Client') + @patch('schwab.auth.AsyncClient') @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('time.time', MagicMock(return_value=MOCK_NOW)) - def test_success( - self, async_session, sync_session, client): + def test_success_sync( + self, async_session, sync_session, async_client, client): AUTH_URL = 'https://auth.url.com' sync_session.return_value = sync_session @@ -565,11 +566,56 @@ def test_success( client.return_value = 'returned client' token_capture = [] - client = auth.client_from_received_url( + auth.client_from_received_url( API_KEY, APP_SECRET, auth_context, 'http://redirect.url.com/?data', lambda token: token_capture.append(token)) + client.assert_called_once() + async_client.assert_not_called() + + # Verify that the oauth state is correctly passed along + sync_session.fetch_token.assert_called_once_with( + _, + authorization_response=_, + client_id=_, + auth=_, + state='oauth state') + + self.assertEqual([{ + 'creation_timestamp': MOCK_NOW, + 'token': self.raw_token + }], token_capture) + + + @no_duplicates + @patch('schwab.auth.Client') + @patch('schwab.auth.AsyncClient') + @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) + def test_success_async( + self, async_session, sync_session, async_client, client): + AUTH_URL = 'https://auth.url.com' + + sync_session.return_value = sync_session + sync_session.create_authorization_url.return_value = \ + AUTH_URL, 'oauth state' + sync_session.fetch_token.return_value = self.raw_token + + auth_context = auth.get_auth_context(API_KEY, CALLBACK_URL) + + client.return_value = 'returned client' + token_capture = [] + auth.client_from_received_url( + API_KEY, APP_SECRET, auth_context, + 'http://redirect.url.com/?data', + lambda token: token_capture.append(token), + asyncio=True) + + async_client.assert_called_once() + client.assert_not_called() + # Verify that the oauth state is correctly passed along sync_session.fetch_token.assert_called_once_with( _,