Skip to content

Commit

Permalink
Pass oauth state properly
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgolec committed Sep 14, 2024
1 parent 3ae305c commit 1612937
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
8 changes: 5 additions & 3 deletions schwab/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,11 @@ async def oauth_client_update_token(t, *args, **kwargs):
AuthContext = collections.namedtuple(
'AuthContext', ['callback_url', 'authorization_url', 'state'])

def get_auth_context(api_key, callback_url):
def get_auth_context(api_key, callback_url, state=None):
oauth = OAuth2Client(api_key, redirect_uri=callback_url)
authorization_url, state = oauth.create_authorization_url(
'https://api.schwabapi.com/v1/oauth/authorize')
'https://api.schwabapi.com/v1/oauth/authorize',
state=state)

return AuthContext(callback_url, authorization_url, state)

Expand All @@ -597,7 +598,8 @@ def client_from_received_url(
token = oauth.fetch_token(
TOKEN_ENDPOINT,
authorization_response=received_url,
client_id=api_key, auth=(api_key, app_secret))
client_id=api_key, auth=(api_key, app_secret),
state=auth_context.state)

# Don't emit token details in debug logs
register_redactions(token)
Expand Down
48 changes: 48 additions & 0 deletions tests/auth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,54 @@ def token_write_func(token):
API_KEY, _, token_metadata=_, enforce_enums=True)


# Note the client_from_received_url is called internally by the other client
# generation functions, so testing here is kept light
class ClientFromReceivedUrl(unittest.TestCase):

def setUp(self):
self.tmp_dir = tempfile.TemporaryDirectory()
self.token_path = os.path.join(self.tmp_dir.name, 'token.json')
self.raw_token = {'token': 'yes'}

@no_duplicates
@patch('schwab.auth.Client')
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
@patch('time.time', MagicMock(return_value=MOCK_NOW))
def test_success(
self, async_session, sync_session, client):
AUTH_URL = 'https://auth.url.com'

sync_session.return_value = sync_session
sync_session.create_authorization_url.return_value = \
AUTH_URL, 'oauth state'
sync_session.fetch_token.return_value = self.raw_token

auth_context = auth.get_auth_context(API_KEY, CALLBACK_URL)
self.assertEqual(AUTH_URL, auth_context.authorization_url)
self.assertEqual('oauth state', auth_context.state)

client.return_value = 'returned client'
token_capture = []
client = auth.client_from_received_url(
API_KEY, APP_SECRET, auth_context,
'http://redirect.url.com/?data',
lambda token: token_capture.append(token))

# Verify that the oauth state is correctly passed along
sync_session.fetch_token.assert_called_once_with(
_,
authorization_response=_,
client_id=_,
auth=_,
state='oauth state')

self.assertEqual([{
'creation_timestamp': MOCK_NOW,
'token': self.raw_token
}], token_capture)


class ClientFromManualFlow(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 1612937

Please sign in to comment.