Skip to content

Commit

Permalink
Merge pull request #49 from alpacahq/feature/ws-disconnect-fix
Browse files Browse the repository at this point in the history
Implement re-subscription on websocket reconnection
  • Loading branch information
ttt733 authored Aug 14, 2019
2 parents 2b653ad + d430452 commit 1d6480a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 15 deletions.
18 changes: 17 additions & 1 deletion alpaca/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func (s *Stream) Subscribe(channel string, handler func(msg interface{})) (err e
s.handlers.Store(channel, handler)

if err = s.sub(channel); err != nil {
s.handlers.Delete(channel)
return
}
default:
Expand Down Expand Up @@ -83,6 +84,19 @@ func (s *Stream) Close() error {
return s.conn.Close()
}

func (s *Stream) reconnect() {
s.authenticated.Store(false)
s.conn = openSocket()
if err := s.auth(); err != nil {
return
}
s.handlers.Range(func(key, value interface{}) bool {
// there should be no errors if we've previously successfully connected
s.sub(key.(string))
return true
})
}

func (s *Stream) start() {
for {
msg := ServerMsg{}
Expand Down Expand Up @@ -111,7 +125,7 @@ func (s *Stream) start() {
log.Printf("alpaca stream read error (%v)", err)
}

s.conn = openSocket()
s.reconnect()
}
}
}
Expand Down Expand Up @@ -176,6 +190,8 @@ func (s *Stream) auth() (err error) {
return fmt.Errorf("failed to authorize alpaca stream")
}

s.authenticated.Store(true)

return
}

Expand Down
50 changes: 36 additions & 14 deletions polygon/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ func (s *Stream) Subscribe(channel string, handler func(msg interface{})) (err e
s.conn = openSocket()
}

// read connection message
msg := []PolgyonServerMsg{}
if err = s.conn.ReadJSON(&msg); err != nil {
return
}

if err = s.auth(); err != nil {
return
}
Expand All @@ -54,10 +48,10 @@ func (s *Stream) Subscribe(channel string, handler func(msg interface{})) (err e
go s.start()
})

topic := channel[:strings.IndexByte(channel, '.')]
s.handlers.Store(topic, handler)
s.handlers.Store(channel, handler)

if err = s.sub(channel); err != nil {
s.handlers.Delete(channel)
return
}

Expand Down Expand Up @@ -86,6 +80,19 @@ func (s *Stream) Close() error {
return s.conn.Close()
}

func (s *Stream) reconnect() {
s.authenticated.Store(false)
s.conn = openSocket()
if err := s.auth(); err != nil {
return
}
s.handlers.Range(func(key, value interface{}) bool {
// there should be no errors if we've previously successfully connected
s.sub(key.(string))
return true
})
}

func (s *Stream) handleError(err error) {
if websocket.IsCloseError(err) {
// if this was a graceful closure, don't reconnect
Expand All @@ -96,7 +103,7 @@ func (s *Stream) handleError(err error) {
log.Printf("polygon stream read error (%v)", err)
}

s.conn = openSocket()
s.reconnect()
}

func (s *Stream) start() {
Expand All @@ -106,36 +113,44 @@ func (s *Stream) start() {
if err := json.Unmarshal(arrayBytes, &msgArray); err == nil {
for _, msg := range msgArray {
msgMap := msg.(map[string]interface{})
if v, ok := s.handlers.Load(msgMap["ev"]); ok {
channel := fmt.Sprintf("%s.%s", msgMap["ev"], msgMap["sym"])
handler, ok := s.handlers.Load(channel)
if !ok {
// see if an "all symbols" handler was registered
handler, ok = s.handlers.Load(fmt.Sprintf("%s.*", msgMap["ev"]))
}
if ok {
msgBytes, _ := json.Marshal(msg)
switch msgMap["ev"] {
case SecondAggs:
fallthrough
case MinuteAggs:
var minuteAgg StreamAggregate
if err := json.Unmarshal(msgBytes, &minuteAgg); err == nil {
h := v.(func(msg interface{}))
h := handler.(func(msg interface{}))
h(minuteAgg)
} else {
s.handleError(err)
}
case Quotes:
var quoteUpdate StreamQuote
if err := json.Unmarshal(msgBytes, &quoteUpdate); err == nil {
h := v.(func(msg interface{}))
h := handler.(func(msg interface{}))
h(quoteUpdate)
} else {
s.handleError(err)
}
case Trades:
var tradeUpdate StreamTrade
if err := json.Unmarshal(msgBytes, &tradeUpdate); err == nil {
h := v.(func(msg interface{}))
h := handler.(func(msg interface{}))
h(tradeUpdate)
} else {
s.handleError(err)
}
}
} else {

}
}
} else {
Expand Down Expand Up @@ -195,9 +210,11 @@ func (s *Stream) auth() (err error) {
}

if !strings.EqualFold(msg[0].Status, "auth_success") {
return fmt.Errorf("failed to authorize alpaca stream")
return fmt.Errorf("failed to authorize Polygon stream")
}

s.authenticated.Store(true)

return
}

Expand Down Expand Up @@ -225,5 +242,10 @@ func openSocket() *websocket.Conn {
if err != nil {
panic(err)
}
// read connection message
msg := []PolgyonServerMsg{}
if err = c.ReadJSON(&msg); err != nil {
panic(err)
}
return c
}

0 comments on commit 1d6480a

Please sign in to comment.