diff --git a/schwab/auth.py b/schwab/auth.py index 0895431..a488a20 100644 --- a/schwab/auth.py +++ b/schwab/auth.py @@ -1,9 +1,7 @@ - - from authlib.integrations.httpx_client import AsyncOAuth2Client, OAuth2Client from prompt_toolkit import prompt -import urllib +import contextlib import json import logging import multiprocessing @@ -13,6 +11,7 @@ import requests import sys import time +import urllib import urllib3 import warnings import webbrowser @@ -165,6 +164,10 @@ async def oauth_client_update_token(t, *args, **kwargs): token_metadata=metadata_manager, enforce_enums=enforce_enums) +################################################################################ +# client_from_login_flow + + # This runs in a separate process and is invisible to coverage def __run_client_from_login_flow_server( q, callback_port, callback_path): # pragma: no cover @@ -184,7 +187,16 @@ def handle_token(): def status(): return 'running' - app.run(port=callback_port, ssl_context='adhoc') + # Wrap this call in some hackery to suppress the flask startup messages + with open(os.devnull, 'w') as devnull: + import logging + log = logging.getLogger('werkzeug') + log.setLevel(logging.ERROR) + + old_stdout = sys.stdout + sys.stdout = devnull + app.run(port=callback_port, ssl_context='adhoc') + sys.stdout = old_stdout class RedirectTimeoutError(Exception): @@ -193,16 +205,21 @@ class RedirectTimeoutError(Exception): class RedirectServerExitedError(Exception): pass +# Capture the real time.time so that we can use it in server initialization +# while simultaneously mocking it in testing +__TIME_TIME = time.time def client_from_login_flow(api_key, app_secret, callback_url, token_path, asyncio=False, enforce_enums=False, - token_write_func=None, callback_timeout=300.0): + token_write_func=None, callback_timeout=300.0, + interactive=True): # TODO: documentation # Start the server parsed = urllib.parse.urlparse(callback_url) if parsed.hostname != '127.0.0.1': + # TODO: document this error raise ValueError( ('disallowed hostname {}. client_from_login_flow only allows '+ 'callback URLs with hostname 127.0.0.1').format( @@ -217,71 +234,110 @@ def client_from_login_flow(api_key, app_secret, callback_url, token_path, target=__run_client_from_login_flow_server, args=(output_queue, callback_port, callback_path)) - print('Running a server to intercept the callback. Please ignore the ' + - 'following debug messages:') - print() - server.start() - - # Wait until the server successfully starts - while True: - # Check if the server is still alive - if server.exitcode is not None: - raise RedirectServerExitedError( - 'Redirect server exited. Are you attempting to use a ' + - 'callback URL without a port number specified?') + # Context manager to kill the server upon completion + @contextlib.contextmanager + def callback_server(): + server.start() - import traceback - - # Attempt to send a request to the server try: - with warnings.catch_warnings(): - warnings.filterwarnings( - 'ignore', category=urllib3.exceptions.InsecureRequestWarning) - - resp = requests.get( - 'https://127.0.0.1:{}/schwab-py-internal/status'.format( - callback_port), verify=False) - break - except requests.exceptions.ConnectionError as e: - pass + yield + finally: + try: + psutil.Process(server.pid).kill() + except psutil.NoSuchProcess: + pass + + with callback_server(): + # Wait until the server successfully starts + while True: + # Check if the server is still alive + if server.exitcode is not None: + # TODO: document this error + raise RedirectServerExitedError( + 'Redirect server exited. Are you attempting to use a ' + + 'callback URL without a port number specified?') + + import traceback + + # Attempt to send a request to the server + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + category=urllib3.exceptions.InsecureRequestWarning) + + resp = requests.get( + 'https://127.0.0.1:{}/schwab-py-internal/status'.format( + callback_port), verify=False) + break + except requests.exceptions.ConnectionError as e: + pass + + time.sleep(0.1) + + # Open the browser + oauth = OAuth2Client(api_key, redirect_uri=callback_url) + authorization_url, state = oauth.create_authorization_url( + 'https://api.schwabapi.com/v1/oauth/authorize') + + if interactive: + print() + print('**************************************************************') + print() + print('This is the browser-assisted login and token creation flow for') + print('schwab-py. This flow automatically opens the login page on your') + print('browser, captures the resulting OAuth callback, and creates a token') + print('using the result.') + print() + print('IMPORTANT: Your browser will give you a security warning about an') + print('invalid certificate prior to issuing the redirect. This is because') + print('schwab-py has started a server on your machine to receive the OAuth') + print('redirect using a self-signed SSL certificate. You can ignore that') + print('warning, but make sure to first check that the URL matches your') + print('callback URL. As a reminder, your callback URL is:') + print() + print('>>',callback_url) + print() + print('See here to learn more: TODO') + print() + print('If you encounter any issues, see here for troubleshooting:') + print('https://schwab-py.readthedocs.io/en/latest/auth.html#troubleshooting') + print('\n**************************************************************') + print() + prompt('Press ENTER to open the browser. Note you can run ' + + 'client_from_login_flow with interactive=False to skip this input') + + webbrowser.open(authorization_url) + + # Wait for a response + now = __TIME_TIME() + timeout_time = now + callback_timeout + received_url = None + while now < timeout_time: + # Attempt to fetch from the queue + try: + received_url = output_queue.get( + timeout=min(timeout_time - now, 0.1)) + break + except queue.Empty: + pass + + now = __TIME_TIME() + + if not received_url: + # TODO: document this error + raise RedirectTimeoutError( + 'Timed out waiting for a post-authorization callback. You '+ + 'can set a longer timeout by passing a value of ' + + 'callback_timeout to client_from_login_flow.') - time.sleep(0.1) - - # Open the browser - oauth = OAuth2Client(api_key, redirect_uri=callback_url) - authorization_url, state = oauth.create_authorization_url( - 'https://api.schwabapi.com/v1/oauth/authorize') - - webbrowser.open(authorization_url) - - # Wait for a response - now = time.time() - timeout_time = now + callback_timeout - callback_url = None - while now < timeout_time: - # Attempt to fetch from the queue - try: - callback_url = output_queue.get( - timeout=min(timeout_time - now, 0.1)) - break - except queue.Empty: - pass - - now = time.time() - - # Clean up and create the client - psutil.Process(server.pid).kill() - - if callback_url: return __fetch_and_register_token_from_redirect( - oauth, callback_url, api_key, app_secret, token_path, token_write_func, - asyncio, enforce_enums=enforce_enums) - else: - raise RedirectTimeoutError( - 'Timed out waiting for a post-authorization callback. You '+ - 'can set a longer timeout by passing a value of ' + - 'callback_timeout to client_from_login_flow.') + oauth, received_url, api_key, app_secret, token_path, + token_write_func, asyncio, enforce_enums=enforce_enums) + +################################################################################ +# client_from_token_path def client_from_token_file(token_path, api_key, app_secret, asyncio=False, @@ -313,6 +369,10 @@ def client_from_token_file(token_path, api_key, app_secret, asyncio=False, enforce_enums=enforce_enums) +################################################################################ +# client_from_manual_flow + + def client_from_manual_flow(api_key, app_secret, callback_url, token_path, asyncio=False, token_write_func=None, enforce_enums=True): @@ -385,6 +445,10 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path, asyncio, enforce_enums=enforce_enums) +################################################################################ +# client_from_access_functions + + def client_from_access_functions(api_key, app_secret, token_read_func, token_write_func, asyncio=False, enforce_enums=True): diff --git a/tests/auth_test.py b/tests/auth_test.py index 63fdfe3..41395c5 100644 --- a/tests/auth_test.py +++ b/tests/auth_test.py @@ -14,8 +14,6 @@ import tempfile import unittest -import schwab - API_KEY = 'APIKEY' APP_SECRET = '0x5EC07' @@ -39,7 +37,8 @@ def setUp(self): @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.webbrowser.open', new_callable=MagicMock) - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('schwab.auth.prompt', MagicMock(return_value='')) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_create_token_file( self, mock_webbrowser_open, async_session, sync_session, client): AUTH_URL = 'https://auth.url.com' @@ -70,7 +69,44 @@ def test_create_token_file( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.webbrowser.open', new_callable=MagicMock) - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('schwab.auth.prompt') + @patch('time.time', MagicMock(return_value=MOCK_NOW)) + def test_create_token_file_not_interactive( + self, mock_prompt,mock_webbrowser_open, async_session, sync_session, + client): + AUTH_URL = 'https://auth.url.com' + + sync_session.return_value = sync_session + sync_session.create_authorization_url.return_value = AUTH_URL, None + sync_session.fetch_token.return_value = self.raw_token + + callback_url = 'https://127.0.0.1:6969/callback' + + mock_webbrowser_open.side_effect = \ + lambda auth_url: requests.get( + 'https://127.0.0.1:6969/callback', verify=False) + + client.return_value = 'returned client' + + auth.client_from_login_flow( + API_KEY, APP_SECRET, callback_url, self.token_path, + interactive=False) + + with open(self.token_path, 'r') as f: + self.assertEqual({ + 'creation_timestamp': MOCK_NOW, + 'token': self.raw_token + }, json.load(f)) + + mock_prompt.assert_not_called() + + + @patch('schwab.auth.Client') + @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + @patch('schwab.auth.webbrowser.open', new_callable=MagicMock) + @patch('schwab.auth.prompt', MagicMock(return_value='')) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_create_token_file_root_callback_url( self, mock_webbrowser_open, async_session, sync_session, client): AUTH_URL = 'https://auth.url.com' @@ -101,12 +137,13 @@ def test_create_token_file_root_callback_url( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.webbrowser.open', new_callable=MagicMock) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_disallowed_hostname( self, mock_webbrowser_open, async_session, sync_session, client): callback_url = 'https://example.com/callback' with self.assertRaisesRegex( - ValueError,'disallowed hostname example.com'): + ValueError, 'disallowed hostname example.com'): auth.client_from_login_flow( API_KEY, APP_SECRET, callback_url, self.token_path) @@ -115,12 +152,13 @@ def test_disallowed_hostname( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.webbrowser.open', new_callable=MagicMock) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_disallowed_hostname_with_port( self, mock_webbrowser_open, async_session, sync_session, client): callback_url = 'https://example.com:8080/callback' with self.assertRaisesRegex( - ValueError,'disallowed hostname example.com'): + ValueError, 'disallowed hostname example.com'): auth.client_from_login_flow( API_KEY, APP_SECRET, callback_url, self.token_path) @@ -129,11 +167,12 @@ def test_disallowed_hostname_with_port( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.webbrowser.open', new_callable=MagicMock) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_unprivileged_start_on_port_80( self, mock_webbrowser_open, async_session, sync_session, client): callback_url = 'https://127.0.0.1/callback' - with self.assertRaisesRegex(schwab.auth.RedirectServerExitedError, + with self.assertRaisesRegex(auth.RedirectServerExitedError, 'callback URL without a port number'): auth.client_from_login_flow( API_KEY, APP_SECRET, callback_url, self.token_path) @@ -143,6 +182,8 @@ def test_unprivileged_start_on_port_80( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.webbrowser.open', new_callable=MagicMock) + @patch('schwab.auth.prompt', MagicMock(return_value='')) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_time_out_waiting_for_request( self, mock_webbrowser_open, async_session, sync_session, client): AUTH_URL = 'https://auth.url.com' @@ -153,7 +194,7 @@ def test_time_out_waiting_for_request( callback_url = 'https://127.0.0.1:6969/callback' - with self.assertRaisesRegex(schwab.auth.RedirectTimeoutError, + with self.assertRaisesRegex(auth.RedirectTimeoutError, 'Timed out waiting'): auth.client_from_login_flow( API_KEY, APP_SECRET, callback_url, self.token_path, @@ -412,7 +453,7 @@ def setUp(self): @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.prompt') - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_no_token_file( self, prompt_func, async_session, sync_session, client): AUTH_URL = 'https://auth.url.com' @@ -439,7 +480,7 @@ def test_no_token_file( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.prompt') - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_custom_token_write_func( self, prompt_func, async_session, sync_session, client): AUTH_URL = 'https://auth.url.com' @@ -476,7 +517,7 @@ def dummy_token_write_func(token): @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.prompt') @patch('builtins.print') - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_print_warning_on_http_redirect_uri( self, print_func, prompt_func, async_session, sync_session, client): auth_url = 'https://auth.url.com' @@ -507,7 +548,7 @@ def test_print_warning_on_http_redirect_uri( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.prompt') - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_enforce_enums_disabled( self, prompt_func, async_session, sync_session, client): auth_url = 'https://auth.url.com' @@ -532,7 +573,7 @@ def test_enforce_enums_disabled( @patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient) @patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) @patch('schwab.auth.prompt') - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_enforce_enums_enabled( self, prompt_func, async_session, sync_session, client): auth_url = 'https://auth.url.com' @@ -589,7 +630,7 @@ def test_reject_tokens_without_creation_timestamp(self): @no_duplicates - @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + @patch('time.time', MagicMock(return_value=MOCK_NOW)) def test_token_age(self): token = {'token': 'yes', 'creation_timestamp': TOKEN_CREATION_TIMESTAMP}