From 0205bbe83c0455caf7fe0ec7712eee5db7b63c5c Mon Sep 17 00:00:00 2001 From: fumoboy007 <2100868+fumoboy007@users.noreply.github.com> Date: Fri, 24 Feb 2023 00:02:18 -0800 Subject: [PATCH] Close the WebSocket connection immediately when the stream is stopped. Currently, when the stream is stopped, we set the stream status accordingly and then wait for the `_consume` loop to check the stream status and close the WebSocket connection. The `_consume` loop calls `self._ws.recv()` with a timeout of 5 seconds, so it can take up to 5 seconds for the WebSocket connection to be closed after the stream is stopped. This is unnecessarily inefficient and complicated. Instead, we could close the WebSocket connection immediately when the stream is stopped. The `_consume` loop would still be broken out of properly because `self._ws.recv()` would raise a `ConnectionClosed` error. --- alpaca_trade_api/stream.py | 71 ++++++++++---------------------------- 1 file changed, 19 insertions(+), 52 deletions(-) diff --git a/alpaca_trade_api/stream.py b/alpaca_trade_api/stream.py index d23a2221..8bcc78f9 100644 --- a/alpaca_trade_api/stream.py +++ b/alpaca_trade_api/stream.py @@ -6,7 +6,6 @@ import msgpack import re import websockets -import queue from .common import get_base_url, get_data_stream_url, get_credentials, URL from .entity import Entity @@ -59,7 +58,6 @@ def __init__(self, self._running = False self._loop = None self._raw_data = raw_data - self._stop_stream_queue = queue.Queue() self._handlers = { 'trades': {}, 'quotes': {}, @@ -113,26 +111,14 @@ async def close(self): async def stop_ws(self): self._should_run = False - if self._stop_stream_queue.empty(): - self._stop_stream_queue.put_nowait({"should_stop": True}) + await self.close() async def _consume(self): while True: - if not self._stop_stream_queue.empty(): - self._stop_stream_queue.get(timeout=1) - await self.close() - break - else: - try: - r = await asyncio.wait_for(self._ws.recv(), 5) - msgs = msgpack.unpackb(r) - for msg in msgs: - await self._dispatch(msg) - except asyncio.TimeoutError: - # ws.recv is hanging when no data is received. by using - # wait_for we break when no data is received, allowing us - # to break the loop when needed - pass + r = await self._ws.recv() + msgs = msgpack.unpackb(r) + for msg in msgs: + await self._dispatch(msg) def _cast(self, msg_type, msg): result = msg @@ -230,14 +216,10 @@ async def _run_forever(self): v for k, v in self._handlers.items() if k not in ("cancelErrors", "corrections") ): - if not self._stop_stream_queue.empty(): - # the ws was signaled to stop before starting the loop so - # we break - self._stop_stream_queue.get(timeout=1) + if not self._should_run: return await asyncio.sleep(0.1) log.info(f'started {self._name} stream') - self._should_run = True self._running = False while True: try: @@ -253,10 +235,10 @@ async def _run_forever(self): self._running = True await self._consume() except websockets.WebSocketException as wse: - await self.close() - self._running = False - log.warn('data websocket error, restarting connection: ' + - str(wse)) + if self._should_run: + await self.close() + log.warn('data websocket error, restarting connection: ' + + str(wse)) except Exception as e: log.exception('error during websocket ' 'communication: {}'.format(str(e))) @@ -621,7 +603,6 @@ def __init__(self, self._running = False self._loop = None self._raw_data = raw_data - self._stop_stream_queue = queue.Queue() self._should_run = True self._websocket_params = websocket_params @@ -686,31 +667,18 @@ async def _start_ws(self): async def _consume(self): while True: - if not self._stop_stream_queue.empty(): - self._stop_stream_queue.get(timeout=1) - await self.close() - break - else: - try: - r = await asyncio.wait_for(self._ws.recv(), 5) - msg = json.loads(r) - await self._dispatch(msg) - except asyncio.TimeoutError: - # ws.recv is hanging when no data is received. by using - # wait_for we break when no data is received, allowing us - # to break the loop when needed - pass + r = await self._ws.recv() + msg = json.loads(r) + await self._dispatch(msg) async def _run_forever(self): self._loop = asyncio.get_running_loop() # do not start the websocket connection until we subscribe to something while not self._trade_updates_handler: - if not self._stop_stream_queue.empty(): - self._stop_stream_queue.get(timeout=1) + if not self._should_run: return await asyncio.sleep(0.1) log.info('started trading stream') - self._should_run = True self._running = False while True: try: @@ -723,10 +691,10 @@ async def _run_forever(self): self._running = True await self._consume() except websockets.WebSocketException as wse: - await self.close() - self._running = False - log.warn('trading stream websocket error, restarting ' + - ' connection: ' + str(wse)) + if self._should_run: + await self.close() + log.warn('trading stream websocket error, restarting ' + + ' connection: ' + str(wse)) except Exception as e: log.exception('error during websocket ' 'communication: {}'.format(str(e))) @@ -741,8 +709,7 @@ async def close(self): async def stop_ws(self): self._should_run = False - if self._stop_stream_queue.empty(): - self._stop_stream_queue.put_nowait({"should_stop": True}) + await self.close() def stop(self): if self._loop.is_running():