diff --git a/ssh/channel.go b/ssh/channel.go index c0834c00df..cc0bb7ab64 100644 --- a/ssh/channel.go +++ b/ssh/channel.go @@ -187,9 +187,11 @@ type channel struct { pending *buffer extPending *buffer - // windowMu protects myWindow, the flow-control window. - windowMu sync.Mutex - myWindow uint32 + // windowMu protects myWindow, the flow-control window, and myConsumed, + // the number of bytes consumed since we last increased myWindow + windowMu sync.Mutex + myWindow uint32 + myConsumed uint32 // writeMu serializes calls to mux.conn.writePacket() and // protects sentClose and packetPool. This mutex must be @@ -332,14 +334,24 @@ func (ch *channel) handleData(packet []byte) error { return nil } -func (c *channel) adjustWindow(n uint32) error { +func (c *channel) adjustWindow(adj uint32) error { c.windowMu.Lock() - // Since myWindow is managed on our side, and can never exceed - // the initial window setting, we don't worry about overflow. - c.myWindow += uint32(n) + // Since myConsumed and myWindow are managed on our side, and can never + // exceed the initial window setting, we don't worry about overflow. + c.myConsumed += adj + var sendAdj uint32 + if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) || + (c.myWindow < channelWindowSize/2) { + sendAdj = c.myConsumed + c.myConsumed = 0 + c.myWindow += sendAdj + } c.windowMu.Unlock() + if sendAdj == 0 { + return nil + } return c.sendMessage(windowAdjustMsg{ - AdditionalBytes: uint32(n), + AdditionalBytes: sendAdj, }) } diff --git a/ssh/mempipe_test.go b/ssh/mempipe_test.go index 8697cd6140..f27339c51a 100644 --- a/ssh/mempipe_test.go +++ b/ssh/mempipe_test.go @@ -13,9 +13,10 @@ import ( // An in-memory packetConn. It is safe to call Close and writePacket // from different goroutines. type memTransport struct { - eof bool - pending [][]byte - write *memTransport + eof bool + pending [][]byte + write *memTransport + writeCount uint64 sync.Mutex *sync.Cond } @@ -63,9 +64,16 @@ func (t *memTransport) writePacket(p []byte) error { copy(c, p) t.write.pending = append(t.write.pending, c) t.write.Cond.Signal() + t.writeCount++ return nil } +func (t *memTransport) getWriteCount() uint64 { + t.write.Lock() + defer t.write.Unlock() + return t.writeCount +} + func memPipe() (a, b packetConn) { t1 := memTransport{} t2 := memTransport{} @@ -81,6 +89,9 @@ func TestMemPipe(t *testing.T) { if err := a.writePacket([]byte{42}); err != nil { t.Fatalf("writePacket: %v", err) } + if wc := a.(*memTransport).getWriteCount(); wc != 1 { + t.Fatalf("got %v, want 1", wc) + } if err := a.Close(); err != nil { t.Fatal("Close: ", err) } @@ -95,6 +106,9 @@ func TestMemPipe(t *testing.T) { if err != io.EOF { t.Fatalf("got %v, %v, want EOF", p, err) } + if wc := b.(*memTransport).getWriteCount(); wc != 0 { + t.Fatalf("got %v, want 0", wc) + } } func TestDoubleClose(t *testing.T) { diff --git a/ssh/mux_test.go b/ssh/mux_test.go index eae637d5e2..21f0ac3e32 100644 --- a/ssh/mux_test.go +++ b/ssh/mux_test.go @@ -182,6 +182,40 @@ func TestMuxChannelOverflow(t *testing.T) { } } +func TestMuxChannelReadUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != nil { + t.Errorf("Write: %v", err) + } + writer.Close() + }() + + writer.remoteWin.waitWriterBlocked() + + buf := make([]byte, 32768) + for { + _, err := reader.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Read: %v", err) + } + } +} + func TestMuxChannelCloseWriteUnblock(t *testing.T) { reader, writer, mux := channelPair(t) defer reader.Close() @@ -754,6 +788,43 @@ func TestMuxMaxPacketSize(t *testing.T) { } } +func TestMuxChannelWindowDeferredUpdates(t *testing.T) { + s, c, mux := channelPair(t) + cTransport := mux.conn.(*memTransport) + defer s.Close() + defer c.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + + data := make([]byte, 1024) + + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Write(data) + if err != nil { + t.Errorf("Write: %v", err) + return + } + }() + cWritesInit := cTransport.getWriteCount() + buf := make([]byte, 1) + for i := 0; i < len(data); i++ { + n, err := c.Read(buf) + if n != len(buf) || err != nil { + t.Fatalf("Read: %v, %v", n, err) + } + } + cWrites := cTransport.getWriteCount() - cWritesInit + // reading 1 KiB should not cause any window updates to be sent, but allow + // for some unexpected writes + if cWrites > 30 { + t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites) + } +} + // Don't ship code with debug=true. func TestDebug(t *testing.T) { if debugMux {