diff --git a/liteclient/pool.go b/liteclient/pool.go index 09168372..761e8475 100644 --- a/liteclient/pool.go +++ b/liteclient/pool.go @@ -81,7 +81,7 @@ func NewConnectionPoolWithAuth(key ed25519.PrivateKey) *ConnectionPool { // // In case if sticky node goes down, default balancer will be used as fallback func (c *ConnectionPool) StickyContext(ctx context.Context) context.Context { - if ctx.Value(_StickyCtxKey) != nil { + if c.StickyNodeID(ctx) != 0 { return ctx } @@ -94,11 +94,11 @@ func (c *ConnectionPool) StickyContext(ctx context.Context) context.Context { } c.nodesMx.RUnlock() - return context.WithValue(ctx, _StickyCtxKey, id) + return stickyContextWithNodeID(ctx, id) } func (c *ConnectionPool) StickyContextNextNode(ctx context.Context) (context.Context, error) { - nodeID, _ := ctx.Value(_StickyCtxKey).(uint32) + nodeID := c.StickyNodeID(ctx) usedNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32) if nodeID > 0 { usedNodes = append(usedNodes, nodeID) @@ -115,13 +115,51 @@ iter: } } - return context.WithValue(context.WithValue(ctx, _StickyCtxKey, node.id), _StickyCtxUsedNodesKey, usedNodes), nil + return context.WithValue(stickyContextWithNodeID(ctx, node.id), _StickyCtxUsedNodesKey, usedNodes), nil + } + + return ctx, fmt.Errorf("no more active nodes left") +} + +func (c *ConnectionPool) StickyContextExcludeNode(ctx context.Context) (context.Context, error) { + nodeID := c.StickyNodeID(ctx) + if nodeID == 0 { + return ctx, fmt.Errorf("no node to exclude") + } + + usedNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32) + usedNodes = append(usedNodes, nodeID) + + c.nodesMx.RLock() + defer c.nodesMx.RUnlock() + + if len(c.activeNodes) < len(usedNodes) { + return context.WithValue(stickyContextWithNodeID(ctx, 0), _StickyCtxUsedNodesKey, usedNodes), nil } return ctx, fmt.Errorf("no more active nodes left") } func (c *ConnectionPool) StickyContextWithNodeID(ctx context.Context, nodeId uint32) context.Context { + usedNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32) + if len(usedNodes) == 0 { + return context.WithValue(ctx, _StickyCtxKey, nodeId) + } + + nodes := make([]uint32, 0, len(usedNodes)) + for _, node := range usedNodes { + if node != nodeId { + nodes = append(nodes, node) + } + } + if len(nodes) == len(usedNodes) { + return stickyContextWithNodeID(ctx, nodeId) + } + + return context.WithValue(stickyContextWithNodeID(ctx, nodeId), _StickyCtxUsedNodesKey, usedNodes) +} + +func stickyContextWithNodeID(ctx context.Context, nodeId uint32) context.Context { return context.WithValue(ctx, _StickyCtxKey, nodeId) } @@ -185,13 +223,14 @@ func (c *ConnectionPool) QueryADNL(ctx context.Context, request tl.Serializable, tm := time.Now() var node *connection - if nodeID, ok := ctx.Value(_StickyCtxKey).(uint32); ok && nodeID > 0 { + excludeNodes, _ := ctx.Value(_StickyCtxUsedNodesKey).([]uint32) + if nodeID := c.StickyNodeID(ctx); nodeID > 0 { node, err = c.querySticky(nodeID, req) if err != nil { return err } } else { - node, err = c.queryWithSmartBalancer(req) + node, err = c.queryWithSmartBalancer(req, excludeNodes...) if err != nil { return err } @@ -238,11 +277,23 @@ func (c *ConnectionPool) querySticky(id uint32, req *ADNLRequest) (*connection, return c.queryWithSmartBalancer(req) } -func (c *ConnectionPool) queryWithSmartBalancer(req *ADNLRequest) (*connection, error) { +func (c *ConnectionPool) queryWithSmartBalancer(req *ADNLRequest, excludeNodes ...uint32) (*connection, error) { var reqNode *connection c.nodesMx.RLock() + + if len(c.activeNodes) == 0 { + c.nodesMx.RUnlock() + return nil, ErrNoActiveConnections + } + +iter: for _, node := range c.activeNodes { + for _, excludeNode := range excludeNodes { + if node.id == excludeNode { + continue iter + } + } if reqNode == nil { reqNode = node continue @@ -256,6 +307,9 @@ func (c *ConnectionPool) queryWithSmartBalancer(req *ADNLRequest) (*connection, c.nodesMx.RUnlock() if reqNode == nil { + if len(excludeNodes) > 0 { + return c.queryWithSmartBalancer(req) + } return nil, ErrNoActiveConnections } diff --git a/ton/api.go b/ton/api.go index e86c74af..d592f9ff 100644 --- a/ton/api.go +++ b/ton/api.go @@ -34,6 +34,7 @@ type LiteClient interface { QueryLiteserver(ctx context.Context, payload tl.Serializable, result tl.Serializable) error StickyContext(ctx context.Context) context.Context StickyContextNextNode(ctx context.Context) (context.Context, error) + StickyContextExcludeNode(ctx context.Context) (context.Context, error) StickyNodeID(ctx context.Context) uint32 } diff --git a/ton/retrier.go b/ton/retrier.go index 94edba32..e91d5666 100644 --- a/ton/retrier.go +++ b/ton/retrier.go @@ -25,7 +25,7 @@ func (w *retryClient) QueryLiteserver(ctx context.Context, payload tl.Serializab tries++ if err != nil { - if !errors.Is(err, liteclient.ErrADNLReqTimeout) && !errors.Is(err, context.DeadlineExceeded){ + if !errors.Is(err, liteclient.ErrADNLReqTimeout) && !errors.Is(err, context.DeadlineExceeded) { return err } @@ -69,6 +69,10 @@ func (w *retryClient) StickyNodeID(ctx context.Context) uint32 { return w.original.StickyNodeID(ctx) } +func (w *retryClient) StickyContextExcludeNode(ctx context.Context) (context.Context, error) { + return w.original.StickyContextExcludeNode(ctx) +} + func (w *retryClient) StickyContextNextNode(ctx context.Context) (context.Context, error) { return w.original.StickyContextNextNode(ctx) } diff --git a/ton/timeouter.go b/ton/timeouter.go index cd617437..2c08e795 100644 --- a/ton/timeouter.go +++ b/ton/timeouter.go @@ -27,6 +27,10 @@ func (c *timeoutClient) StickyNodeID(ctx context.Context) uint32 { return c.original.StickyNodeID(ctx) } +func (w *timeoutClient) StickyContextExcludeNode(ctx context.Context) (context.Context, error) { + return w.original.StickyContextExcludeNode(ctx) +} + func (c *timeoutClient) StickyContextNextNode(ctx context.Context) (context.Context, error) { return c.original.StickyContextNextNode(ctx) } diff --git a/ton/waiter.go b/ton/waiter.go index ae36d584..c008747f 100644 --- a/ton/waiter.go +++ b/ton/waiter.go @@ -49,3 +49,7 @@ func (w *waiterClient) StickyNodeID(ctx context.Context) uint32 { func (w *waiterClient) StickyContextNextNode(ctx context.Context) (context.Context, error) { return w.original.StickyContextNextNode(ctx) } + +func (w *waiterClient) StickyContextExcludeNode(ctx context.Context) (context.Context, error) { + return w.original.StickyContextExcludeNode(ctx) +}