Skip to content

Commit

Permalink
safer way to push data to streams
Browse files Browse the repository at this point in the history
  • Loading branch information
xtaci committed Apr 24, 2019
1 parent 4c827cf commit 401b0da
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 26 deletions.
31 changes: 13 additions & 18 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package smux
import (
"encoding/binary"
"io"
"io/ioutil"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -259,24 +258,20 @@ func (s *Session) recvLoop() {
}
s.streamLock.Unlock()
case cmdPSH:
var written int
var err error
s.streamLock.Lock()
if stream, ok := s.streams[sid]; ok {
if hdr.Length() > 0 {
written, err = stream.receiveBytes(s.conn, int(hdr.Length()))
atomic.AddInt32(&s.bucket, -int32(written))
stream.notifyReadEvent()
if hdr.Length() > 0 {
newbuf := make([]byte, hdr.Length())
if written, err := io.ReadFull(s.conn, newbuf); err == nil {
s.streamLock.Lock()
if stream, ok := s.streams[sid]; ok {
stream.pushBytes(newbuf)
atomic.AddInt32(&s.bucket, -int32(written))
stream.notifyReadEvent()
}
s.streamLock.Unlock()
} else {
s.Close()
return
}
} else { // discard
_, err = io.CopyN(ioutil.Discard, s.conn, int64(hdr.Length()))
}
s.streamLock.Unlock()

// read data error
if err != nil {
s.Close()
return
}
default:
s.Close()
Expand Down
13 changes: 5 additions & 8 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,26 +209,23 @@ func (s *Stream) RemoteAddr() net.Addr {
return nil
}

// receiveBytes receive from the reader and write into the buffer
func (s *Stream) receiveBytes(r io.Reader, sz int) (written int, err error) {
newbuf := make([]byte, sz)
written, err = io.ReadFull(r, newbuf)
// pushBytes append buf to buffers
func (s *Stream) pushBytes(buf []byte) (written int, err error) {
s.bufferLock.Lock()
s.buffers = append(s.buffers, newbuf)
s.buffers = append(s.buffers, buf)
s.bufferLock.Unlock()
return
}

// recycleTokens transform remaining bytes to tokens(will truncate buffer)
func (s *Stream) recycleTokens() (n int) {
total := 0
s.bufferLock.Lock()
for k := range s.buffers {
total += len(s.buffers[k])
n += len(s.buffers[k])
}
s.buffers = nil
s.bufferLock.Unlock()
return total
return
}

// notify read event
Expand Down

0 comments on commit 401b0da

Please sign in to comment.