Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close the WebSocket connection immediately when the stream is stopped #685

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 19 additions & 52 deletions alpaca_trade_api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import msgpack
import re
import websockets
import queue

from .common import get_base_url, get_data_stream_url, get_credentials, URL
from .entity import Entity
Expand Down Expand Up @@ -59,7 +58,6 @@ def __init__(self,
self._running = False
self._loop = None
self._raw_data = raw_data
self._stop_stream_queue = queue.Queue()
self._handlers = {
'trades': {},
'quotes': {},
Expand Down Expand Up @@ -113,26 +111,14 @@ async def close(self):

async def stop_ws(self):
self._should_run = False
if self._stop_stream_queue.empty():
self._stop_stream_queue.put_nowait({"should_stop": True})
await self.close()

async def _consume(self):
while True:
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get(timeout=1)
await self.close()
break
else:
try:
r = await asyncio.wait_for(self._ws.recv(), 5)
msgs = msgpack.unpackb(r)
for msg in msgs:
await self._dispatch(msg)
except asyncio.TimeoutError:
# ws.recv is hanging when no data is received. by using
# wait_for we break when no data is received, allowing us
# to break the loop when needed
pass
r = await self._ws.recv()
msgs = msgpack.unpackb(r)
for msg in msgs:
await self._dispatch(msg)

def _cast(self, msg_type, msg):
result = msg
Expand Down Expand Up @@ -230,14 +216,10 @@ async def _run_forever(self):
v for k, v in self._handlers.items()
if k not in ("cancelErrors", "corrections")
):
if not self._stop_stream_queue.empty():
# the ws was signaled to stop before starting the loop so
# we break
self._stop_stream_queue.get(timeout=1)
if not self._should_run:
return
await asyncio.sleep(0.1)
log.info(f'started {self._name} stream')
self._should_run = True
self._running = False
while True:
try:
Expand All @@ -253,10 +235,10 @@ async def _run_forever(self):
self._running = True
await self._consume()
except websockets.WebSocketException as wse:
await self.close()
self._running = False
log.warn('data websocket error, restarting connection: ' +
str(wse))
if self._should_run:
await self.close()
log.warn('data websocket error, restarting connection: ' +
str(wse))
except Exception as e:
log.exception('error during websocket '
'communication: {}'.format(str(e)))
Expand Down Expand Up @@ -621,7 +603,6 @@ def __init__(self,
self._running = False
self._loop = None
self._raw_data = raw_data
self._stop_stream_queue = queue.Queue()
self._should_run = True
self._websocket_params = websocket_params

Expand Down Expand Up @@ -686,31 +667,18 @@ async def _start_ws(self):

async def _consume(self):
while True:
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get(timeout=1)
await self.close()
break
else:
try:
r = await asyncio.wait_for(self._ws.recv(), 5)
msg = json.loads(r)
await self._dispatch(msg)
except asyncio.TimeoutError:
# ws.recv is hanging when no data is received. by using
# wait_for we break when no data is received, allowing us
# to break the loop when needed
pass
r = await self._ws.recv()
msg = json.loads(r)
await self._dispatch(msg)

async def _run_forever(self):
self._loop = asyncio.get_running_loop()
# do not start the websocket connection until we subscribe to something
while not self._trade_updates_handler:
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get(timeout=1)
if not self._should_run:
return
await asyncio.sleep(0.1)
log.info('started trading stream')
self._should_run = True
self._running = False
while True:
try:
Expand All @@ -723,10 +691,10 @@ async def _run_forever(self):
self._running = True
await self._consume()
except websockets.WebSocketException as wse:
await self.close()
self._running = False
log.warn('trading stream websocket error, restarting ' +
' connection: ' + str(wse))
if self._should_run:
await self.close()
log.warn('trading stream websocket error, restarting ' +
' connection: ' + str(wse))
except Exception as e:
log.exception('error during websocket '
'communication: {}'.format(str(e)))
Expand All @@ -741,8 +709,7 @@ async def close(self):

async def stop_ws(self):
self._should_run = False
if self._stop_stream_queue.empty():
self._stop_stream_queue.put_nowait({"should_stop": True})
await self.close()

def stop(self):
if self._loop.is_running():
Expand Down