diff --git a/alpaca/stream.go b/alpaca/stream.go index e043955..ac832d3 100644 --- a/alpaca/stream.go +++ b/alpaca/stream.go @@ -53,6 +53,7 @@ func (s *Stream) Subscribe(channel string, handler func(msg interface{})) (err e s.handlers.Store(channel, handler) if err = s.sub(channel); err != nil { + s.handlers.Delete(channel) return } default: @@ -83,6 +84,19 @@ func (s *Stream) Close() error { return s.conn.Close() } +func (s *Stream) reconnect() { + s.authenticated.Store(false) + s.conn = openSocket() + if err := s.auth(); err != nil { + return + } + s.handlers.Range(func(key, value interface{}) bool { + // there should be no errors if we've previously successfully connected + s.sub(key.(string)) + return true + }) +} + func (s *Stream) start() { for { msg := ServerMsg{} @@ -111,7 +125,7 @@ func (s *Stream) start() { log.Printf("alpaca stream read error (%v)", err) } - s.conn = openSocket() + s.reconnect() } } } @@ -176,6 +190,8 @@ func (s *Stream) auth() (err error) { return fmt.Errorf("failed to authorize alpaca stream") } + s.authenticated.Store(true) + return } diff --git a/polygon/stream.go b/polygon/stream.go index b707550..a34910c 100644 --- a/polygon/stream.go +++ b/polygon/stream.go @@ -40,12 +40,6 @@ func (s *Stream) Subscribe(channel string, handler func(msg interface{})) (err e s.conn = openSocket() } - // read connection message - msg := []PolgyonServerMsg{} - if err = s.conn.ReadJSON(&msg); err != nil { - return - } - if err = s.auth(); err != nil { return } @@ -54,10 +48,10 @@ func (s *Stream) Subscribe(channel string, handler func(msg interface{})) (err e go s.start() }) - topic := channel[:strings.IndexByte(channel, '.')] - s.handlers.Store(topic, handler) + s.handlers.Store(channel, handler) if err = s.sub(channel); err != nil { + s.handlers.Delete(channel) return } @@ -86,6 +80,19 @@ func (s *Stream) Close() error { return s.conn.Close() } +func (s *Stream) reconnect() { + s.authenticated.Store(false) + s.conn = openSocket() + if err := s.auth(); err != nil { + return + } + s.handlers.Range(func(key, value interface{}) bool { + // there should be no errors if we've previously successfully connected + s.sub(key.(string)) + return true + }) +} + func (s *Stream) handleError(err error) { if websocket.IsCloseError(err) { // if this was a graceful closure, don't reconnect @@ -96,7 +103,7 @@ func (s *Stream) handleError(err error) { log.Printf("polygon stream read error (%v)", err) } - s.conn = openSocket() + s.reconnect() } func (s *Stream) start() { @@ -106,7 +113,13 @@ func (s *Stream) start() { if err := json.Unmarshal(arrayBytes, &msgArray); err == nil { for _, msg := range msgArray { msgMap := msg.(map[string]interface{}) - if v, ok := s.handlers.Load(msgMap["ev"]); ok { + channel := fmt.Sprintf("%s.%s", msgMap["ev"], msgMap["sym"]) + handler, ok := s.handlers.Load(channel) + if !ok { + // see if an "all symbols" handler was registered + handler, ok = s.handlers.Load(fmt.Sprintf("%s.*", msgMap["ev"])) + } + if ok { msgBytes, _ := json.Marshal(msg) switch msgMap["ev"] { case SecondAggs: @@ -114,7 +127,7 @@ func (s *Stream) start() { case MinuteAggs: var minuteAgg StreamAggregate if err := json.Unmarshal(msgBytes, &minuteAgg); err == nil { - h := v.(func(msg interface{})) + h := handler.(func(msg interface{})) h(minuteAgg) } else { s.handleError(err) @@ -122,7 +135,7 @@ func (s *Stream) start() { case Quotes: var quoteUpdate StreamQuote if err := json.Unmarshal(msgBytes, "eUpdate); err == nil { - h := v.(func(msg interface{})) + h := handler.(func(msg interface{})) h(quoteUpdate) } else { s.handleError(err) @@ -130,12 +143,14 @@ func (s *Stream) start() { case Trades: var tradeUpdate StreamTrade if err := json.Unmarshal(msgBytes, &tradeUpdate); err == nil { - h := v.(func(msg interface{})) + h := handler.(func(msg interface{})) h(tradeUpdate) } else { s.handleError(err) } } + } else { + } } } else { @@ -195,9 +210,11 @@ func (s *Stream) auth() (err error) { } if !strings.EqualFold(msg[0].Status, "auth_success") { - return fmt.Errorf("failed to authorize alpaca stream") + return fmt.Errorf("failed to authorize Polygon stream") } + s.authenticated.Store(true) + return } @@ -225,5 +242,10 @@ func openSocket() *websocket.Conn { if err != nil { panic(err) } + // read connection message + msg := []PolgyonServerMsg{} + if err = c.ReadJSON(&msg); err != nil { + panic(err) + } return c }