diff --git a/internal/net/message_manager.go b/internal/net/message_manager.go
index 7908be2bd..94fa16678 100644
--- a/internal/net/message_manager.go
+++ b/internal/net/message_manager.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/libp2p/go-libp2p/core/host"
@@ -22,12 +23,11 @@ import (
 	"go.opencensus.io/stats"
 	"go.opencensus.io/tag"
 
-	"github.com/libp2p/go-libp2p-kad-dht/internal"
 	"github.com/libp2p/go-libp2p-kad-dht/metrics"
 	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
 )
 
-var dhtReadMessageTimeout = 10 * time.Second
+const dhtMessageTimeout = 10 * time.Second
 
 // ErrReadTimeout is an error that occurs when no message is read within the timeout period.
 var ErrReadTimeout = fmt.Errorf("timed out reading response")
@@ -60,13 +60,13 @@ func (m *messageSenderImpl) OnDisconnect(ctx context.Context, p peer.ID) {
 	}
 	delete(m.strmap, p)
 
-	// Do this asynchronously as ms.lk can block for a while.
+	// Do this asynchronously as ms.writeLock can block for a while.
 	go func() {
-		if err := ms.lk.Lock(ctx); err != nil {
-			return
+		ms.writeLock <- struct{}{}
+		if p := ms.pipeline; p != nil {
+			p.kill(false)
 		}
-		defer ms.lk.Unlock()
-		ms.invalidate()
+		<-ms.writeLock
 	}()
 }
 
@@ -143,73 +143,74 @@ func (m *messageSenderImpl) messageSenderForPeer(ctx context.Context, p peer.ID)
 		m.smlk.Unlock()
 		return ms, nil
 	}
-	ms = &peerMessageSender{p: p, m: m, lk: internal.NewCtxMutex()}
+	ms = newPeerMessageSender(p, m)
 	m.strmap[p] = ms
 	m.smlk.Unlock()
 
-	if err := ms.prepOrInvalidate(ctx); err != nil {
-		m.smlk.Lock()
-		defer m.smlk.Unlock()
-
-		if msCur, ok := m.strmap[p]; ok {
-			// Changed. Use the new one, old one is invalid and
-			// not in the map so we can just throw it away.
-			if ms != msCur {
-				return msCur, nil
-			}
-			// Not changed, remove the now invalid stream from the
-			// map.
-			delete(m.strmap, p)
-		}
-		// Invalid but not in map. Must have been removed by a disconnect.
-		return nil, err
-	}
-
-	// All ready to go.
 	return ms, nil
 }
 
 // peerMessageSender is responsible for sending requests and messages to a particular peer
+// It applies backpressure by pipelining messages.
 type peerMessageSender struct {
-	s  network.Stream
-	r  msgio.ReadCloser
-	lk internal.CtxMutex
-	p  peer.ID
-	m  *messageSenderImpl
-
-	invalid   bool
-	singleMes int
+	pipeline  *pipeline
+	writeLock chan struct{} // use a chan so we can select against ctx.Done()
+
+	p peer.ID
+	m *messageSenderImpl
+
+	waiting uint
+	invalid bool
 }
 
-// invalidate is called before this peerMessageSender is removed from the strmap.
-// It prevents the peerMessageSender from being reused/reinitialized and then
-// forgotten (leaving the stream open).
-func (ms *peerMessageSender) invalidate() {
-	ms.invalid = true
-	if ms.s != nil {
-		_ = ms.s.Reset()
-		ms.s = nil
+func newPeerMessageSender(p peer.ID, m *messageSenderImpl) *peerMessageSender {
+	return &peerMessageSender{
+		writeLock: make(chan struct{}, 1),
+		p:         p,
+		m:         m,
 	}
 }
 
-func (ms *peerMessageSender) prepOrInvalidate(ctx context.Context) error {
-	if err := ms.lk.Lock(ctx); err != nil {
-		return err
-	}
-	defer ms.lk.Unlock()
+type pipeline struct {
+	hasError      atomic.Bool
+	hasBeenClosed bool
+	nextReader    chan struct{}
+	s             network.Stream
+	r             msgio.ReadCloser
+	bufW          bufio.Writer
+	w             protoio.Writer
+	readError     error
+	closeLk       sync.Mutex
+	waiting       uint
+}
 
-	if err := ms.prep(ctx); err != nil {
-		ms.invalidate()
-		return err
+func (p *pipeline) kill(wasWaiting bool) {
+	p.closeLk.Lock()
+	defer p.closeLk.Unlock()
+
+	p.hasError.Store(true)
+	w := p.waiting
+	if wasWaiting {
+		w--
+	}
+	if w == 0 && !p.hasBeenClosed {
+		p.hasBeenClosed = true
+		p.s.Reset()
+		p.r.Close()
+	} else {
+		p.waiting = w
 	}
-	return nil
 }
 
-func (ms *peerMessageSender) prep(ctx context.Context) error {
+func (ms *peerMessageSender) getStream(ctx context.Context) error {
 	if ms.invalid {
 		return fmt.Errorf("message sender has been invalidated")
 	}
-	if ms.s != nil {
-		return nil
+	if ms.pipeline != nil {
+		if !ms.pipeline.hasError.Load() {
+			return nil
+		}
+		ms.pipeline = nil
 	}
 
 	// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
@@ -220,133 +221,161 @@ func (ms *peerMessageSender) prep(ctx context.Context) error {
 		return err
 	}
 
-	ms.r = msgio.NewVarintReaderSize(nstr, network.MessageSizeMax)
-	ms.s = nstr
+	nr := make(chan struct{})
+	close(nr)
+	ms.pipeline = &pipeline{
+		nextReader: nr,
+		s:          nstr,
+		r:          msgio.NewVarintReaderSize(nstr, network.MessageSizeMax),
+	}
+	ms.pipeline.bufW.Reset(nstr)
+	ms.pipeline.w = protoio.NewDelimitedWriter(&ms.pipeline.bufW)
 
 	return nil
 }
 
-// streamReuseTries is the number of times we will try to reuse a stream to a
-// given peer before giving up and reverting to the old one-message-per-stream
-// behaviour.
-const streamReuseTries = 3
-
 func (ms *peerMessageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
-	if err := ms.lk.Lock(ctx); err != nil {
-		return err
+	select {
+	case ms.writeLock <- struct{}{}:
+		defer func() { <-ms.writeLock }()
+	case <-ctx.Done():
+		return ctx.Err()
 	}
-	defer ms.lk.Unlock()
 
-	retry := false
-	for {
-		if err := ms.prep(ctx); err != nil {
-			return err
-		}
+	_, _, cancel, err := ms.sendMessage(ctx, pmes)
+	cancel()
+	return err
+}
 
-		if err := ms.writeMsg(pmes); err != nil {
-			_ = ms.s.Reset()
-			ms.s = nil
-
-			if retry {
-				logger.Debugw("error writing message", "error", err)
-				return err
-			}
-			logger.Debugw("error writing message", "error", err, "retrying", true)
-			retry = true
-			continue
-		}
+
+// sendMessage leaves handling of [ms.writeLock] to the caller.
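+// It dials a fresh pipeline if needed, stamps the write with a
+// dhtMessageTimeout deadline, and returns that deadline together with the
+// derived context and its cancel func; the caller must always call cancel.
+// When err is non-nil the pipeline has already been killed and cleared.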
+func (ms *peerMessageSender) sendMessage(ictx context.Context, pmes *pb.Message) (deadline time.Time, ctx context.Context, cancel context.CancelFunc, err error) {
+	ctx, cancel = context.WithTimeout(ictx, dhtMessageTimeout)
 
-		var err error
-		if ms.singleMes > streamReuseTries {
-			err = ms.s.Close()
-			ms.s = nil
-		} else if retry {
-			ms.singleMes++
+	err = ms.getStream(ctx)
+	if err != nil {
+		return
+	}
+	p := ms.pipeline
+	var good bool
+	defer func() {
+		if !good {
+			p.kill(false)
+			ms.pipeline = nil
 		}
+	}()
 
-		return err
+	var ok bool
+	deadline, ok = ctx.Deadline()
+	if !ok {
+		panic("context.WithTimeout did not set a deadline")
+	}
+	err = p.s.SetWriteDeadline(deadline)
+	if err != nil {
+		return
+	}
+
+	err = p.w.WriteMsg(pmes)
+	if err != nil {
+		return
+	}
+	err = p.bufW.Flush()
+	if err != nil {
+		return
+	}
+
+	good = true
+	return
+}
+
+func (p *pipeline) decrement() {
+	p.closeLk.Lock()
+	defer p.closeLk.Unlock()
+
+	w := p.waiting - 1
+	if w == 0 && p.hasError.Load() && !p.hasBeenClosed {
+		p.hasBeenClosed = true
+		p.s.Reset()
+		p.r.Close()
+	} else {
+		p.waiting = w
 	}
 }
 
+func (p *pipeline) increment() {
+	p.closeLk.Lock()
+	defer p.closeLk.Unlock()
+	p.waiting++
+}
+
 func (ms *peerMessageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
-	if err := ms.lk.Lock(ctx); err != nil {
+	select {
+	case ms.writeLock <- struct{}{}:
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+
+	deadline, ctx, cancel, err := ms.sendMessage(ctx, pmes)
+	defer cancel()
+	if err != nil {
+		<-ms.writeLock
 		return nil, err
 	}
-	defer ms.lk.Unlock()
 
-	retry := false
-	for {
-		if err := ms.prep(ctx); err != nil {
-			return nil, err
+	p := ms.pipeline
+	nextNext := make(chan struct{})
+	next := p.nextReader
+	p.nextReader = nextNext
+	p.increment()
+	<-ms.writeLock
+	var good bool
+	defer func() {
+		if !good {
+			p.kill(true)
 		}
+	}()
 
-		if err := ms.writeMsg(pmes); err != nil {
-			_ = ms.s.Reset()
-			ms.s = nil
-
-			if retry {
-				logger.Debugw("error writing message", "error", err)
-				return nil, err
-			}
-			logger.Debugw("error writing message", "error", err, "retrying", true)
-			retry = true
-			continue
-		}
+	// FIXME: handle these context failures more gracefully;
+	// we could start a background goroutine to clean up the stream, or let the next goroutine do it.
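+
+	// Wait for our turn to read. Each in-flight request takes the previous
+	// request's nextReader channel and installs a fresh one for its successor;
+	// the previous reader closes its channel once its response has been
+	// consumed, so responses are matched to requests in FIFO order.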
 
-		mes := new(pb.Message)
-		if err := ms.ctxReadMsg(ctx, mes); err != nil {
-			_ = ms.s.Reset()
-			ms.s = nil
-
-			if retry {
-				logger.Debugw("error reading message", "error", err)
-				return nil, err
-			}
-			logger.Debugw("error reading message", "error", err, "retrying", true)
-			retry = true
-			continue
-		}
+	select {
+	case <-next:
+	case <-ctx.Done():
		go func() {
+			// pass the read token along the chain
+			<-next
+			close(nextNext)
+		}()
+		return nil, ctx.Err()
+	}
 
-		var err error
-		if ms.singleMes > streamReuseTries {
-			err = ms.s.Close()
-			ms.s = nil
-		} else if retry {
-			ms.singleMes++
-		}
+	defer close(nextNext)
 
-		return mes, err
+	if rerr := p.readError; rerr != nil {
+		return nil, rerr
 	}
-}
 
-func (ms *peerMessageSender) writeMsg(pmes *pb.Message) error {
-	return WriteMsg(ms.s, pmes)
-}
+	err = p.s.SetReadDeadline(deadline)
+	if err != nil {
+		return nil, err
+	}
 
-func (ms *peerMessageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error {
-	errc := make(chan error, 1)
-	go func(r msgio.ReadCloser) {
-		defer close(errc)
-		bytes, err := r.ReadMsg()
-		defer r.ReleaseMsg(bytes)
-		if err != nil {
-			errc <- err
-			return
-		}
-		errc <- mes.Unmarshal(bytes)
-	}(ms.r)
+	bytes, err := p.r.ReadMsg()
+	if err != nil {
+		p.readError = err
+		return nil, err
+	}
+	mes := new(pb.Message)
+	err = mes.Unmarshal(bytes)
+	p.r.ReleaseMsg(bytes)
+	if err != nil {
+		p.readError = err
+		return nil, err
+	}
 
-	t := time.NewTimer(dhtReadMessageTimeout)
-	defer t.Stop()
+	p.decrement()
+	good = true
 
-	select {
-	case err := <-errc:
-		return err
-	case <-ctx.Done():
-		return ctx.Err()
-	case <-t.C:
-		return ErrReadTimeout
-	}
+	return mes, nil
 }
 
 // The Protobuf writer performs multiple small writes when writing a message.
diff --git a/internal/net/message_manager_test.go b/internal/net/message_manager_test.go
deleted file mode 100644
index 5c61ec2de..000000000
--- a/internal/net/message_manager_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package net
-
-import (
-	"context"
-	"testing"
-
-	"github.com/libp2p/go-libp2p/core/peer"
-	"github.com/libp2p/go-libp2p/core/protocol"
-
-	bhost "github.com/libp2p/go-libp2p/p2p/host/basic"
-	swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
-
-	"github.com/stretchr/testify/require"
-)
-
-func TestInvalidMessageSenderTracking(t *testing.T) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	foo := peer.ID("asdasd")
-
-	h, err := bhost.NewHost(swarmt.GenSwarm(t, swarmt.OptDisableReuseport), new(bhost.HostOpts))
-	require.NoError(t, err)
-	h.Start()
-	defer h.Close()
-
-	msgSender := NewMessageSenderImpl(h, []protocol.ID{"/test/kad/1.0.0"}).(*messageSenderImpl)
-
-	_, err = msgSender.messageSenderForPeer(ctx, foo)
-	require.Error(t, err, "should have failed to find message sender")
-
-	msgSender.smlk.Lock()
-	mscnt := len(msgSender.strmap)
-	msgSender.smlk.Unlock()
-
-	if mscnt > 0 {
-		t.Fatal("should have no message senders in map")
-	}
-}
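
Reviewer note: below is a minimal, standalone sketch of the chained-channel technique SendRequest uses to keep concurrent readers in FIFO order. The chain type, newChain, and enter are illustrative names that are not part of this patch, and a plain mutex stands in for the writeLock channel; it is a sketch of the idea, not the patch's implementation.

package main

import (
	"fmt"
	"sync"
)

// chain hands out read turns in the order enter is called, the same way the
// pipeline threads nextReader through concurrent SendRequest calls.
type chain struct {
	mu   sync.Mutex
	next chan struct{}
}

func newChain() *chain {
	c := &chain{next: make(chan struct{})}
	close(c.next) // the first caller may read immediately
	return c
}

// enter reserves a slot: wait is closed when it is our turn to read, and we
// must close done to pass the turn to the next caller.
func (c *chain) enter() (wait <-chan struct{}, done chan struct{}) {
	c.mu.Lock()
	defer c.mu.Unlock()
	wait = c.next
	done = make(chan struct{})
	c.next = done
	return wait, done
}

func main() {
	c := newChain()
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wait, done := c.enter() // slots are reserved in send order
		wg.Add(1)
		go func(i int, wait <-chan struct{}, done chan struct{}) {
			defer wg.Done()
			<-wait                          // previous reader has finished
			fmt.Println("read response", i) // always prints 0,1,2,3,4 in order
			close(done)                     // pass the turn down the chain
		}(i, wait, done)
	}
	wg.Wait()
}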