diff --git a/session.go b/session.go index 99f4451..fba94ef 100644 --- a/session.go +++ b/session.go @@ -259,15 +259,17 @@ func (s *Session) recvLoop() { } s.streamLock.Unlock() case cmdPSH: - var written int64 + var written int var err error s.streamLock.Lock() if stream, ok := s.streams[sid]; ok { - written, err = stream.receiveBytes(s.conn, int64(hdr.Length())) - atomic.AddInt32(&s.bucket, -int32(written)) - stream.notifyReadEvent() + if hdr.Length() > 0 { + written, err = stream.receiveBytes(s.conn, int(hdr.Length())) + atomic.AddInt32(&s.bucket, -int32(written)) + stream.notifyReadEvent() + } } else { // discard - written, err = io.CopyN(ioutil.Discard, s.conn, int64(hdr.Length())) + _, err = io.CopyN(ioutil.Discard, s.conn, int64(hdr.Length())) } s.streamLock.Unlock() diff --git a/stream.go b/stream.go index a77e799..8323a06 100644 --- a/stream.go +++ b/stream.go @@ -1,7 +1,6 @@ package smux import ( - "bytes" "io" "net" "sync" @@ -16,7 +15,7 @@ type Stream struct { id uint32 rstflag int32 sess *Session - buffer *bytes.Buffer + buffers [][]byte bufferLock sync.Mutex frameSize int chReadEvent chan struct{} // notify a read event @@ -34,7 +33,6 @@ func newStream(id uint32, frameSize int, sess *Session) *Stream { s.frameSize = frameSize s.sess = sess s.die = make(chan struct{}) - s.buffer = new(bytes.Buffer) return s } @@ -63,7 +61,14 @@ func (s *Stream) Read(b []byte) (n int, err error) { READ: s.bufferLock.Lock() - n, _ = s.buffer.Read(b) + if len(s.buffers) > 0 { + n = copy(b, s.buffers[0]) + s.buffers[0] = s.buffers[0][n:] + if len(s.buffers[0]) == 0 { + s.buffers[0] = nil + s.buffers = s.buffers[1:] + } + } s.bufferLock.Unlock() if n > 0 { @@ -205,20 +210,24 @@ func (s *Stream) RemoteAddr() net.Addr { } // receiveBytes receive from the reader and write into the buffer -func (s *Stream) receiveBytes(r io.Reader, sz int64) (written int64, err error) { +func (s *Stream) receiveBytes(r io.Reader, sz int) (written int, err error) { + newbuf := make([]byte, sz) + written, err = io.ReadFull(r, newbuf) s.bufferLock.Lock() - written, err = io.CopyN(s.buffer, r, sz) + s.buffers = append(s.buffers, newbuf) s.bufferLock.Unlock() return } // recycleTokens transform remaining bytes to tokens(will truncate buffer) func (s *Stream) recycleTokens() (n int) { + total := 0 s.bufferLock.Lock() - n = s.buffer.Len() - s.buffer.Reset() + for k := range s.buffers { + total += len(s.buffers[k]) + } s.bufferLock.Unlock() - return + return total } // notify read event